Skip to content

Commit c7bec04

Browse files
committed
compiler: fix multi-cond async
1 parent 828497e commit c7bec04

3 files changed

Lines changed: 10 additions & 2 deletions

File tree

devito/finite_differences/finite_difference.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ def first_derivative(expr, dim, fd_order, **kwargs):
157157

158158
def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coefficients,
159159
expand, weights=None):
160+
if deriv_order == 0 and not expr.is_Add:
161+
print(expr, dim, fd_order)
160162
# Always expand time derivatives to avoid issue with buffering and streaming.
161163
# Time derivative are almost always short stencils and won't benefit from
162164
# unexpansion in the rare case the derivative is not evaluated for time stepping.

devito/ir/equations/equation.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,13 @@ def __new__(cls, *args, **kwargs):
229229
index = d.index
230230
if d.condition is not None and d in expr.free_symbols:
231231
index = index - relational_min(d.condition, d.parent)
232-
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor)})
232+
# If there is a condition we might access on a non-factor
233+
# index and need to make sure we don't overwrite the previous
234+
# index
235+
num = index + d.symbolic_factor - 1
236+
else:
237+
num = index
238+
expr = uxreplace(expr, {d: IntDiv(num, d.symbolic_factor)})
233239

234240
conditionals = frozendict(conditionals)
235241

devito/passes/clusters/asynchrony.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def _actions_from_update_memcpy(c, d, clusters, actions, sregistry):
240240

241241
fetch = e.rhs.indices[d]
242242
fshift = {Forward: 1, Backward: -1}.get(direction, 0)
243-
findex = fetch + fshift if fetch.find(IntDiv) else fetch._subs(pd, pd + fshift)
243+
findex = fetch._subs(pd, pd + fshift)
244244

245245
# If fetching into e.g. `ub[t1]` we might need to prefetch into e.g. `ub[t0]`
246246
tindex0 = e.lhs.indices[d]

0 commit comments

Comments
 (0)