Skip to content

Commit 8107646

Browse files
committed
compiler: Introduce Terminal mixin for SymPy subclasses
1 parent 4e3ac51 commit 8107646

2 files changed

Lines changed: 22 additions & 16 deletions

File tree

devito/symbolics/extended_sympy.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from devito.types.basic import Basic
1919

2020
__all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'BitwiseAnd', # noqa
21-
'LeftShift', 'RightShift', 'IntDiv', 'CallFromPointer',
21+
'LeftShift', 'RightShift', 'IntDiv', 'Terminal', 'CallFromPointer',
2222
'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite',
2323
'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction',
2424
'MathFunction', 'InlineIf', 'Reserved', 'ReservedWord', 'Keyword',
@@ -147,6 +147,17 @@ def __mul__(self, other):
147147
return super().__mul__(other)
148148

149149

150+
class Terminal:
151+
152+
"""
153+
Abstract base class for all terminal objects, that is, those objects
154+
collected by `retrieve_terminals` in addition to all other SymPy atoms
155+
such as `Symbol`, `Number`, etc.
156+
"""
157+
158+
pass
159+
160+
150161
class BasicWrapperMixin:
151162

152163
"""
@@ -188,7 +199,7 @@ def _sympystr(self, printer):
188199
return str(self)
189200

190201

191-
class CallFromPointer(sympy.Expr, Pickable, BasicWrapperMixin):
202+
class CallFromPointer(Expr, Pickable, BasicWrapperMixin, Terminal):
192203

193204
"""
194205
Symbolic representation of the C notation ``pointer->call(params)``.
@@ -256,7 +267,7 @@ def free_symbols(self):
256267
__reduce_ex__ = Pickable.__reduce_ex__
257268

258269

259-
class CallFromComposite(CallFromPointer, Pickable):
270+
class CallFromComposite(CallFromPointer):
260271

261272
"""
262273
Symbolic representation of the C notation ``composite.call(params)``.
@@ -269,7 +280,7 @@ def __str__(self):
269280
__repr__ = __str__
270281

271282

272-
class FieldFromPointer(CallFromPointer, Pickable):
283+
class FieldFromPointer(CallFromPointer):
273284

274285
"""
275286
Symbolic representation of the C notation ``pointer->field``.
@@ -290,7 +301,7 @@ def field(self):
290301
__repr__ = __str__
291302

292303

293-
class FieldFromComposite(CallFromPointer, Pickable):
304+
class FieldFromComposite(CallFromPointer):
294305

295306
"""
296307
Symbolic representation of the C notation ``composite.field``,
@@ -352,7 +363,7 @@ def is_numeric(self):
352363
__reduce_ex__ = Pickable.__reduce_ex__
353364

354365

355-
class UnaryOp(sympy.Expr, Pickable, BasicWrapperMixin):
366+
class UnaryOp(Expr, Pickable, BasicWrapperMixin):
356367

357368
"""
358369
Symbolic representation of a unary C operator.
@@ -490,7 +501,7 @@ def __str__(self):
490501
return f"{self._op}{self.base}"
491502

492503

493-
class IndexedPointer(sympy.Expr, Pickable, BasicWrapperMixin):
504+
class IndexedPointer(Expr, Pickable, BasicWrapperMixin, Terminal):
494505

495506
"""
496507
Symbolic representation of the C notation ``symbol[...]``

devito/symbolics/queries.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from sympy import Eq, IndexedBase, Mod, S, diff, nan
22

3-
from devito.symbolics.extended_sympy import (FieldFromComposite, FieldFromPointer,
4-
IndexedPointer, IntDiv)
3+
from devito.symbolics.extended_sympy import IntDiv, Terminal
54
from devito.tools import as_tuple, is_integer
65
from devito.types.basic import AbstractFunction
76
from devito.types.constant import Constant
@@ -16,13 +15,9 @@
1615
'q_dimension', 'q_positive', 'q_negative']
1716

1817

19-
# The following SymPy objects are considered tree leaves:
20-
#
21-
# * Number
22-
# * Symbol
23-
# * Indexed
24-
extra_leaves = (FieldFromPointer, FieldFromComposite, IndexedBase, AbstractObject,
25-
IndexedPointer)
18+
# The following SymPy objects are considered tree leaves in addition to the classic
19+
# SymPy atoms such as Number, Symbol, Indexed, etc
20+
extra_leaves = (IndexedBase, AbstractObject, Terminal)
2621

2722

2823
def q_symbol(expr):

0 commit comments

Comments
 (0)