Skip to content

Commit 4fd0296

Browse files
committed
api: fix buffering with multiple conditions
1 parent 54c5e49 commit 4fd0296

4 files changed

Lines changed: 68 additions & 17 deletions

File tree

devito/ir/equations/equation.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def __new__(cls, *args, **kwargs):
213213
relations=ordering.relations, mode='partial')
214214
ispace = IterationSpace(intervals, iterators)
215215

216-
# Construct the conditionals and replace the ConditionalDimensions in `expr`
216+
# Construct the conditionals
217217
conditionals = {}
218218
for d in ordering:
219219
if not d.is_Conditional:
@@ -225,14 +225,6 @@ def __new__(cls, *args, **kwargs):
225225
if d._factor is not None:
226226
cond = d.relation(cond, GuardFactor(d))
227227
conditionals[d] = cond
228-
# Replace dimension with index
229-
index = d.index
230-
if d.condition is not None and d in expr.free_symbols:
231-
index = index - relational_min(d.condition, d.parent)
232-
shift = relational_shift(d.condition, d.parent)
233-
else:
234-
shift = 0
235-
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
236228

237229
# Merge conditionals when possible. E.g if we have an implicit_dim
238230
# and there is a dimension with the same parent, we ca merged
@@ -249,6 +241,14 @@ def __new__(cls, *args, **kwargs):
249241

250242
conditionals = frozendict(conditionals)
251243

244+
# Replace the ConditionalDimensions in `expr`
245+
for d, cond in conditionals.items():
246+
# Replace dimension with index
247+
index = d.index
248+
index = index - relational_min(cond, d.parent)
249+
shift = relational_shift(cond, d.parent)
250+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
251+
252252
# Lower all Differentiable operations into SymPy operations
253253
rhs = diff2sympy(expr.rhs)
254254

devito/ir/support/vector.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def __lt__(self, other):
128128
return True
129129
elif q_positive(i):
130130
return False
131+
131132
raise TypeError("Non-comparable index functions") from e
132133

133134
return False
@@ -164,6 +165,7 @@ def __gt__(self, other):
164165
return True
165166
elif q_negative(i):
166167
return False
168+
167169
raise TypeError("Non-comparable index functions") from e
168170

169171
return False
@@ -203,6 +205,7 @@ def __le__(self, other):
203205
return True
204206
elif q_positive(i):
205207
return False
208+
206209
raise TypeError("Non-comparable index functions") from e
207210

208211
# Note: unlike `__lt__`, if we end up here, then *it is* <=. For example,

devito/types/relational.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,22 @@ def relational_shift(expr, s):
302302
if not expr.has(s):
303303
return 0
304304

305-
try:
306-
return expr._as_min
307-
except (TypeError, AttributeError):
305+
return _relational_shift(expr, s)
306+
307+
308+
@singledispatch
309+
def _relational_shift(s, expr):
310+
return 0
311+
312+
313+
@_relational_shift.register(sympy.Or)
314+
@_relational_shift.register(sympy.And)
315+
def _(expr, s):
316+
return sum([_relational_shift(e, s) for e in expr.args])
317+
318+
319+
@_relational_shift.register(sympy.Eq)
320+
def _(expr, s):
321+
if isinstance(expr.lhs, sympy.Mod):
308322
return 0
323+
return expr._as_min

tests/test_buffering.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ def test_buffer_reuse():
754754
assert all(np.all(vsave.data[i-1] == i + 1) for i in range(1, nt + 1))
755755

756756

757-
def test_multi_cond():
757+
def test_multi_cond_v0():
758758
grid = Grid((3, 3))
759759
nt = 5
760760

@@ -774,14 +774,47 @@ def test_multi_cond():
774774
T = TimeFunction(grid=grid, name='T', time_order=0, space_order=0)
775775

776776
eqs = [Eq(T, grid.time_dim)]
777-
# this to save times from 0 to nt - 2
777+
# This saves
778+
# - All subsampled times since ct1 is the dimension of f
779+
# - The last time step (ntmod - 2) through ctend (since it's set as ct1 or ctend)
780+
eqs.append(Eq(f, T, implicit_dims=ctend))
781+
782+
# run operator with buffering
783+
op = Operator(eqs, opt='buffering')
784+
op.apply(time_m=0, time_M=ntmod-2)
785+
786+
for i in range(nt-1):
787+
assert np.allclose(f.data[i], i*2)
788+
assert np.allclose(f.data[nt-1], ntmod - 2)
789+
790+
791+
def test_multi_cond_v1():
792+
grid = Grid((3, 3))
793+
nt = 5
794+
795+
x, y = grid.dimensions
796+
797+
factor = 2
798+
ntmod = (nt - 1) * factor + 1
799+
800+
ct1 = ConditionalDimension(name="ct1", parent=grid.time_dim,
801+
factor=factor, relation=Or,
802+
condition=CondEq(grid.time_dim, ntmod - 2))
803+
804+
f = TimeFunction(grid=grid, name='f', time_order=0,
805+
space_order=0, save=nt, time_dim=ct1)
806+
T = TimeFunction(grid=grid, name='T', time_order=0, space_order=0)
807+
808+
eqs = [Eq(T, grid.time_dim)]
809+
# This saves
810+
# - All subsampled times since ct1 is the dimension of f with factor 2
811+
# - The last time step (ntmod - 2) since ct1 also has the condition for ntmod - 2
778812
eqs.append(Eq(f, T))
779-
# this to save the last time sample nt - 1
780-
eqs.append(Eq(f.forward, T+1, implicit_dims=ctend))
781813

782814
# run operator with buffering
783815
op = Operator(eqs, opt='buffering')
784816
op.apply(time_m=0, time_M=ntmod-2)
785817

786-
for i in range(nt):
818+
for i in range(nt-1):
787819
assert np.allclose(f.data[i], i*2)
820+
assert np.allclose(f.data[nt-1], ntmod - 2)

0 commit comments

Comments
 (0)