Skip to content

Commit 342be17

Browse files
committed
compiler: fix multi-cond for multi-layer
1 parent c7bec04 commit 342be17

5 files changed

Lines changed: 29 additions & 11 deletions

File tree

devito/ir/equations/equation.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
)
1212
from devito.symbolics import IntDiv, limits_mapper, uxreplace
1313
from devito.tools import Pickable, Tag, frozendict
14-
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
14+
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min, relational_shift
1515

1616
__all__ = [
1717
'ClusterizedEq',
@@ -229,13 +229,10 @@ def __new__(cls, *args, **kwargs):
229229
index = d.index
230230
if d.condition is not None and d in expr.free_symbols:
231231
index = index - relational_min(d.condition, d.parent)
232-
# If there is a condition we might access on a non-factor
233-
# index and need to make sure we don't overwrite the previous
234-
# index
235-
num = index + d.symbolic_factor - 1
232+
shift = relational_shift(d.condition, d.parent)
236233
else:
237-
num = index
238-
expr = uxreplace(expr, {d: IntDiv(num, d.symbolic_factor)})
234+
shift = 0
235+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})
239236

240237
conditionals = frozendict(conditionals)
241238

devito/passes/clusters/asynchrony.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def _actions_from_update_memcpy(c, d, clusters, actions, sregistry):
240240

241241
fetch = e.rhs.indices[d]
242242
fshift = {Forward: 1, Backward: -1}.get(direction, 0)
243-
findex = fetch._subs(pd, pd + fshift)
243+
findex = fetch + fshift if fetch.find(IntDiv) else fetch._subs(pd, pd + fshift)
244244

245245
# If fetching into e.g. `ub[t1]` we might need to prefetch into e.g. `ub[t0]`
246246
tindex0 = e.lhs.indices[d]

devito/passes/clusters/buffering.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from itertools import chain
44

55
import numpy as np
6-
from sympy import S
6+
from sympy import S, simplify
77

88
from devito.exceptions import CompilationError
99
from devito.ir import (
@@ -775,7 +775,7 @@ def infer_buffer_size(f, dim, clusters):
775775
slots = [Vector(i) for i in slots]
776776
size = int((vmax(*slots) - vmin(*slots) + 1)[0])
777777

778-
return size
778+
return simplify(size)
779779

780780

781781
def offset_from_centre(d, indices):

devito/symbolics/extended_sympy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from devito.types import Symbol
1919
from devito.types.basic import Basic
20+
from devito.types.relational import Ge
2021

2122
__all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'BitwiseAnd', # noqa
2223
'LeftShift', 'RightShift', 'IntDiv', 'CallFromPointer',
@@ -46,6 +47,11 @@ def canonical(self):
4647
def negated(self):
4748
return CondNe(*self.args, evaluate=False)
4849

50+
@property
51+
def _as_min(self):
52+
from devito.symbolics.extended_dtypes import INT
53+
return INT(Ge(*self.args))
54+
4955

5056
class CondNe(sympy.Ne):
5157

devito/types/relational.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
import sympy
55

6-
__all__ = ['Ge', 'Gt', 'Le', 'Lt', 'Ne', 'relational_max', 'relational_min']
6+
__all__ = ['Ge', 'Gt', 'Le', 'Lt', 'Ne', 'relational_max', 'relational_min',
7+
'relational_shift']
78

89

910
class AbstractRel:
@@ -291,3 +292,17 @@ def _(expr, s):
291292
return expr.gts
292293
else:
293294
return sympy.S.Infinity
295+
296+
297+
def relational_shift(expr, s):
298+
"""
299+
Infer shift incurred by the expression. Generally only
300+
applies when a CondEq is used as it adds a single value.
301+
"""
302+
if not expr.has(s):
303+
return 0
304+
305+
try:
306+
return expr._as_min
307+
except (TypeError, AttributeError):
308+
return 0

0 commit comments

Comments
 (0)