55"""
66
77from sympy import And , Ge , Gt , Le , Lt , Mul , true
8+ from sympy .logic .boolalg import BooleanFunction
9+ import numpy as np
810
911from devito .ir .support .space import Forward , IterationDirection
1012from devito .symbolics import CondEq , CondNe
1113from devito .tools import Pickable , as_tuple , frozendict , split
12- from devito .types import Dimension
14+ from devito .types import Dimension , LocalObject
1315
1416__all__ = ['GuardFactor' , 'GuardBound' , 'GuardBoundNext' , 'BaseGuardBound' ,
15- 'BaseGuardBoundNext' , 'GuardOverflow' , 'Guards' ]
17+ 'BaseGuardBoundNext' , 'GuardOverflow' , 'Guards' , 'GuardExpr' ]
1618
1719
1820class Guard :
@@ -273,6 +275,14 @@ def filter(self, key):
273275 return Guards (m )
274276
275277
278+ class GuardExpr (LocalObject , BooleanFunction ):
279+
280+ dtype = np .bool
281+
282+ def __init__ (self , name , liveness = 'eager' , ** kwargs ):
283+ super ().__init__ (name , liveness = liveness , ** kwargs )
284+
285+
276286# *** Utils
277287
278288def simplify_and (relation , v ):
@@ -294,21 +304,23 @@ def simplify_and(relation, v):
294304 covered = False
295305 new_args = []
296306 for a in candidates :
297- if a .lhs is v .lhs :
298- covered = True
299- try :
300- if type (a ) in (Gt , Ge ) and v .rhs > a .rhs :
301- new_args .append (v )
302- elif type (a ) in (Lt , Le ) and v .rhs < a .rhs :
303- new_args .append (v )
304- else :
305- new_args .append (a )
306- except TypeError :
307- # E.g., `v.rhs = const + z_M` and `a.rhs = z_M`, so the inequalities
308- # above are not evaluable to True/False
307+ if isinstance (a , GuardExpr ) or a .lhs is not v .lhs :
308+ new_args .append (a )
309+ continue
310+
311+ covered = True
312+ try :
313+ if type (a ) in (Gt , Ge ) and v .rhs > a .rhs :
314+ new_args .append (v )
315+ elif type (a ) in (Lt , Le ) and v .rhs < a .rhs :
316+ new_args .append (v )
317+ else :
309318 new_args .append (a )
310- else :
319+ except TypeError :
320+ # E.g., `v.rhs = const + z_M` and `a.rhs = z_M`, so the inequalities
321+ # above are not evaluable to True/False
311322 new_args .append (a )
323+
312324 if not covered :
313325 new_args .append (v )
314326
0 commit comments