Skip to content

Commit 5f6cbf2

Browse files
committed
api: prevent sympy warnings with diag
1 parent aa7620e commit 5f6cbf2

3 files changed

Lines changed: 7 additions & 4 deletions

File tree

devito/finite_differences/elementary.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from packaging.version import Version
33

44
from devito.finite_differences.differentiable import DifferentiableFunction, diffify
5-
from devito.types.lazy import Evaluable
65

76

87
class factorial(DifferentiableFunction, sympy.factorial):
@@ -89,13 +88,13 @@ def root(x):
8988
return diffify(sympy.root(x))
9089

9190

92-
class Min(sympy.Min, Evaluable):
91+
class Min(sympy.Min, DifferentiableFunction):
9392

9493
def _evaluate(self, **kwargs):
9594
return self.func(*self._evaluate_args(**kwargs), evaluate=False)
9695

9796

98-
class Max(sympy.Max, Evaluable):
97+
class Max(sympy.Max, DifferentiableFunction):
9998

10099
def _evaluate(self, **kwargs):
101100
return self.func(*self._evaluate_args(**kwargs), evaluate=False)

devito/finite_differences/operators.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from sympy import S
2+
3+
14
def div(func, shift=None, order=None, method='FD', side=None, **kwargs):
25
"""
36
Divergence of the input Function.
@@ -194,6 +197,6 @@ def diag(func, size=None):
194197
to = getattr(func, 'time_order', 0)
195198

196199
tens_func = TensorTimeFunction if func.is_TimeDependent else TensorFunction
197-
comps = [[func if i == j else 0 for i in range(dim)] for j in range(dim)]
200+
comps = [[func if i == j else S.Zero for i in range(dim)] for j in range(dim)]
198201
return tens_func(name='diag', grid=func.grid, space_order=func.space_order,
199202
components=comps, time_order=to, diagonal=True)

devito/passes/clusters/implicit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def key(tkn):
189189

190190
# Turn the reduced mapper into a list of equations
191191
processed = []
192+
192193
for bunch in found.values():
193194
exprs = make_implicit_exprs(bunch.mapper)
194195

0 commit comments

Comments
 (0)