|
5 | 5 |
|
6 | 6 | import sympy |
7 | 7 |
|
8 | | -from .finite_difference import generic_derivative, cross_derivative |
9 | | -from .differentiable import Differentiable, diffify, interp_for_fd, Add, Mul |
10 | | -from .tools import direct, transpose |
11 | | -from .rsfd import d45 |
12 | | -from devito.tools import (as_mapper, as_tuple, frozendict, is_integer, |
13 | | - Pickable) |
| 8 | +from devito.tools import Pickable, as_mapper, as_tuple, frozendict, is_integer |
| 9 | +from devito.types.dimension import Dimension |
14 | 10 | from devito.types.utils import DimensionTuple |
15 | 11 | from devito.warnings import warn |
16 | 12 |
|
| 13 | +from .differentiable import Add, Differentiable, Mul, diffify, interp_for_fd |
| 14 | +from .finite_difference import cross_derivative, generic_derivative |
| 15 | +from .rsfd import d45 |
| 16 | +from .tools import direct, transpose |
| 17 | + |
17 | 18 | __all__ = ['Derivative'] |
18 | 19 |
|
19 | 20 |
|
@@ -101,14 +102,30 @@ def __new__(cls, expr, *dims, **kwargs): |
101 | 102 | # Count the derivatives w.r.t. each variable |
102 | 103 | dcounter = cls._count_derivatives(deriv_order, dims) |
103 | 104 |
|
104 | | - # It's possible that the expr is a `sympy.Number` at this point, which |
| 105 | + # It is possible that the expr is a `sympy.Number` at this point, which |
105 | 106 | # has derivative 0, unless we're taking a 0th derivative. |
106 | 107 | if isinstance(expr, sympy.Number): |
107 | 108 | if any(dcounter.values()): |
108 | 109 | return 0 |
109 | 110 | else: |
110 | 111 | return expr |
111 | 112 |
|
| 113 | + # It is also possible that the expression itself is just a |
| 114 | + # `devito.Dimension` type which is: |
| 115 | + # - derivative 1 if the Dimension coincides and the number of derivatives |
| 116 | + # is 1 ie: `Derivative(x, (x, 1)) == 1`. |
| 117 | + # - derivative 0 if the Dimension coincides and the total number of |
| 118 | + # derivatives is greater than 1 ie: `Derivative(x, (x, 2)) == 0` and |
| 119 | + # `Derivative(x, x, y) == 0`. |
| 120 | + # - An unevaluated expression otherwise. |
| 121 | + if isinstance(expr, Dimension) and expr in dcounter: |
| 122 | + if dcounter[expr] == 0: |
| 123 | + pass |
| 124 | + elif dcounter.pop(expr) == 1 and not dcounter: |
| 125 | + return 1 |
| 126 | + else: |
| 127 | + return 0 |
| 128 | + |
112 | 129 | # Validate the finite difference order `fd_order` |
113 | 130 | fd_order = cls._validate_fd_order(kwargs.get('fd_order'), expr, dims, dcounter) |
114 | 131 |
|
@@ -232,7 +249,11 @@ def _validate_fd_order(fd_order, expr, dims, dcounter): |
232 | 249 | Required: `expr`, `dims`, and the derivative counter to validate. |
233 | 250 | If not provided, the maximum supported order will be used. |
234 | 251 | """ |
235 | | - if fd_order is not None: |
| 252 | + if isinstance(expr, Dimension): |
| 253 | + # If the expression is just a dimension `expr.time_order` and |
| 254 | + # `expr.space_order` are not defined |
| 255 | + fd_order = (99,)*len(dcounter) |
| 256 | + elif fd_order is not None: |
236 | 257 | # If `fd_order` is specified, then validate |
237 | 258 | fcounter = defaultdict(int) |
238 | 259 | # First create a dictionary mapping variable wrt which to differentiate |
|
0 commit comments