|
11 | 11 | from .rsfd import d45 |
12 | 12 | from devito.tools import (as_mapper, as_tuple, frozendict, is_integer, |
13 | 13 | Pickable) |
| 14 | +from devito.types.dimension import Dimension |
14 | 15 | from devito.types.utils import DimensionTuple |
15 | 16 | from devito.warnings import warn |
16 | 17 |
|
@@ -101,14 +102,38 @@ def __new__(cls, expr, *dims, **kwargs): |
101 | 102 | # Count the derivatives w.r.t. each variable |
102 | 103 | dcounter = cls._count_derivatives(deriv_order, dims) |
103 | 104 |
|
104 | | - # It's possible that the expr is a `sympy.Number` at this point, which |
| 105 | + # It is possible that the expr is a `sympy.Number` at this point, which |
105 | 106 | # has derivative 0, unless we're taking a 0th derivative. |
106 | 107 | if isinstance(expr, sympy.Number): |
107 | 108 | if any(dcounter.values()): |
108 | 109 | return 0 |
109 | 110 | else: |
110 | 111 | return expr |
111 | 112 |
|
| 113 | + # It is also possible that the expression itself is just a |
| 114 | + # `devito.Dimension` type which is: |
| 115 | + # - derivative 1 if the dimension coincides and the number of derivatives |
| 116 | + # is 1 ie: `Derivative(x, (x, 1)) == 1`. |
| 117 | + # - derivative 0 if the dimension coincides and the total number of |
| 118 | + # derivatives is greater than 1 ie: `Derivative(x, (x, 2)) == 0` and |
| 119 | + # `Derivative(x, x, y) == 0`. |
| 120 | + # - An error otherwise. |
| 121 | + if isinstance(expr, Dimension): |
| 122 | + if expr in dcounter.keys(): |
| 123 | + if dcounter[expr] == 0: |
| 124 | + raise ValueError( |
| 125 | + f'Cannot interpolate a dimension `{expr}` onto itself' |
| 126 | + ) |
| 127 | + elif dcounter.pop(expr) == 1 and not dcounter: |
| 128 | + return 1 |
| 129 | + else: |
| 130 | + return 0 |
| 131 | + else: |
| 132 | + raise ValueError( |
| 133 | + f'Cannot differentiate one dimension `{expr}` with respect to' |
| 134 | + f' another {tuple(dcounter.keys())}' |
| 135 | + ) |
| 136 | + |
112 | 137 | # Validate the finite difference order `fd_order` |
113 | 138 | fd_order = cls._validate_fd_order(kwargs.get('fd_order'), expr, dims, dcounter) |
114 | 139 |
|
@@ -157,12 +182,12 @@ def _validate_expr(expr): |
157 | 182 | convertible to "differentiable" type. |
158 | 183 | """ |
159 | 184 | if type(expr) is sympy.Derivative: |
160 | | - raise ValueError("Cannot nest sympy.Derivative with devito.Derivative") |
| 185 | + raise ValueError('Cannot nest sympy.Derivative with devito.Derivative') |
161 | 186 | if not isinstance(expr, Differentiable): |
162 | 187 | try: |
163 | 188 | expr = diffify(expr) |
164 | 189 | except Exception as e: |
165 | | - raise ValueError("`expr` must be a `Differentiable` type object") from e |
| 190 | + raise ValueError('`expr` must be a `Differentiable` type object') from e |
166 | 191 | return expr |
167 | 192 |
|
168 | 193 | @staticmethod |
|
0 commit comments