Skip to content

Commit 2dd24f4

Browse files
JDBetteridgemloubout
authored andcommitted
dsl: Add a case to differentiate dimension type
1 parent 0d90d06 commit 2dd24f4

1 file changed

Lines changed: 28 additions & 3 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .rsfd import d45
1212
from devito.tools import (as_mapper, as_tuple, frozendict, is_integer,
1313
Pickable)
14+
from devito.types.dimension import Dimension
1415
from devito.types.utils import DimensionTuple
1516
from devito.warnings import warn
1617

@@ -101,14 +102,38 @@ def __new__(cls, expr, *dims, **kwargs):
101102
# Count the derivatives w.r.t. each variable
102103
dcounter = cls._count_derivatives(deriv_order, dims)
103104

104-
# It's possible that the expr is a `sympy.Number` at this point, which
105+
# It is possible that the expr is a `sympy.Number` at this point, which
105106
# has derivative 0, unless we're taking a 0th derivative.
106107
if isinstance(expr, sympy.Number):
107108
if any(dcounter.values()):
108109
return 0
109110
else:
110111
return expr
111112

113+
# It is also possible that the expression itself is just a
114+
# `devito.Dimension` type which is:
115+
# - derivative 1 if the dimension coincides and the number of derivatives
116+
# is 1 ie: `Derivative(x, (x, 1)) == 1`.
117+
# - derivative 0 if the dimension coincides and the total number of
118+
# derivatives is greater than 1 ie: `Derivative(x, (x, 2)) == 0` and
119+
# `Derivative(x, x, y) == 0`.
120+
# - An error otherwise.
121+
if isinstance(expr, Dimension):
122+
if expr in dcounter.keys():
123+
if dcounter[expr] == 0:
124+
raise ValueError(
125+
f'Cannot interpolate a dimension `{expr}` onto itself'
126+
)
127+
elif dcounter.pop(expr) == 1 and not dcounter:
128+
return 1
129+
else:
130+
return 0
131+
else:
132+
raise ValueError(
133+
f'Cannot differentiate one dimension `{expr}` with respect to'
134+
f' another {tuple(dcounter.keys())}'
135+
)
136+
112137
# Validate the finite difference order `fd_order`
113138
fd_order = cls._validate_fd_order(kwargs.get('fd_order'), expr, dims, dcounter)
114139

@@ -157,12 +182,12 @@ def _validate_expr(expr):
157182
convertible to "differentiable" type.
158183
"""
159184
if type(expr) is sympy.Derivative:
160-
raise ValueError("Cannot nest sympy.Derivative with devito.Derivative")
185+
raise ValueError('Cannot nest sympy.Derivative with devito.Derivative')
161186
if not isinstance(expr, Differentiable):
162187
try:
163188
expr = diffify(expr)
164189
except Exception as e:
165-
raise ValueError("`expr` must be a `Differentiable` type object") from e
190+
raise ValueError('`expr` must be a `Differentiable` type object') from e
166191
return expr
167192

168193
@staticmethod

0 commit comments

Comments
 (0)