Skip to content

Commit cd17375

Browse files
committed
compiler: Add GuardExpr utility class
1 parent 8fd3824 commit cd17375

1 file changed

Lines changed: 27 additions & 15 deletions

File tree

devito/ir/support/guards.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55
"""
66

77
from sympy import And, Ge, Gt, Le, Lt, Mul, true
8+
from sympy.logic.boolalg import BooleanFunction
9+
import numpy as np
810

911
from devito.ir.support.space import Forward, IterationDirection
1012
from devito.symbolics import CondEq, CondNe
1113
from devito.tools import Pickable, as_tuple, frozendict, split
12-
from devito.types import Dimension
14+
from devito.types import Dimension, LocalObject
1315

1416
__all__ = ['GuardFactor', 'GuardBound', 'GuardBoundNext', 'BaseGuardBound',
15-
'BaseGuardBoundNext', 'GuardOverflow', 'Guards']
17+
'BaseGuardBoundNext', 'GuardOverflow', 'Guards', 'GuardExpr']
1618

1719

1820
class Guard:
@@ -273,6 +275,14 @@ def filter(self, key):
273275
return Guards(m)
274276

275277

278+
class GuardExpr(LocalObject, BooleanFunction):
279+
280+
dtype = np.bool
281+
282+
def __init__(self, name, liveness='eager', **kwargs):
283+
super().__init__(name, liveness=liveness, **kwargs)
284+
285+
276286
# *** Utils
277287

278288
def simplify_and(relation, v):
@@ -294,21 +304,23 @@ def simplify_and(relation, v):
294304
covered = False
295305
new_args = []
296306
for a in candidates:
297-
if a.lhs is v.lhs:
298-
covered = True
299-
try:
300-
if type(a) in (Gt, Ge) and v.rhs > a.rhs:
301-
new_args.append(v)
302-
elif type(a) in (Lt, Le) and v.rhs < a.rhs:
303-
new_args.append(v)
304-
else:
305-
new_args.append(a)
306-
except TypeError:
307-
# E.g., `v.rhs = const + z_M` and `a.rhs = z_M`, so the inequalities
308-
# above are not evaluable to True/False
307+
if isinstance(a, GuardExpr) or a.lhs is not v.lhs:
308+
new_args.append(a)
309+
continue
310+
311+
covered = True
312+
try:
313+
if type(a) in (Gt, Ge) and v.rhs > a.rhs:
314+
new_args.append(v)
315+
elif type(a) in (Lt, Le) and v.rhs < a.rhs:
316+
new_args.append(v)
317+
else:
309318
new_args.append(a)
310-
else:
319+
except TypeError:
320+
# E.g., `v.rhs = const + z_M` and `a.rhs = z_M`, so the inequalities
321+
# above are not evaluable to True/False
311322
new_args.append(a)
323+
312324
if not covered:
313325
new_args.append(v)
314326

0 commit comments

Comments
 (0)