Skip to content

Commit dc0e8f5

Browse files
committed
api: prevent evaluated derivatives to be re-evaluted
1 parent c31bee8 commit dc0e8f5

2 files changed

Lines changed: 12 additions & 4 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,6 +952,11 @@ def _evaluate(self, **kwargs):
952952
class DiffDerivative(IndexDerivative, DifferentiableOp):
953953
pass
954954

955+
def _eval_at(self, func):
956+
# Like EvalDerivative, a DiffDerivative must have already been evaluated
957+
# at a valid x0 and should not be re-evaluated at a different location
958+
return self
959+
955960

956961
# SymPy args ordering is the same for Derivatives and IndexDerivatives
957962
for i in ('DiffDerivative', 'IndexDerivative'):
@@ -998,6 +1003,11 @@ def _new_rawargs(self, *args, **kwargs):
9981003
kwargs.pop('is_commutative', None)
9991004
return self.func(*args, **kwargs)
10001005

1006+
def _eval_at(self, func):
1007+
# An EvalDerivative must have already been evaluated at a valid x0
1008+
# and should not be re-evaluated at a different location
1009+
return self
1010+
10011011

10021012
class diffify:
10031013

tests/test_derivatives.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -805,13 +805,11 @@ def test_param_stagg_add(self):
805805
eq1 = Eq(vx, (c11 * txx).dy)
806806
eq2 = Eq(vx, (c11 * txx + c66 * txy).dy)
807807

808-
# C66 is a paramater. Expects to evaluate c66 at xp then the derivative at yp
809-
# and the derivative will interpolate txy at xp
808+
# Expects to evaluate c66 at xp then the derivative at yp
810809
expect0 = (c66.subs({x: xp, y: yp}).evaluate * txy).dy.evaluate
811810
assert simplify(eq0.evaluate.rhs - expect0) == 0
812811

813-
# C11 is a paramater and txy is staggered in x.
814-
# Expects to evaluate c11 and txy xp then the derivative at yp
812+
# Expects to evaluate c11 and txy at xp then the derivative at yp
815813
expect1 = (c11._subs(x, xp).evaluate * txx._subs(x, xp).evaluate).dy.evaluate
816814
assert simplify(eq1.evaluate.rhs - expect1) == 0
817815

0 commit comments

Comments
 (0)