|
2 | 2 | from itertools import product |
3 | 3 |
|
4 | 4 | import numpy as np |
5 | | -from sympy import S, finite_diff_weights, cacheit, sympify, Function, Rational |
| 5 | +from sympy import S, finite_diff_weights, cacheit, sympify, Rational |
6 | 6 |
|
7 | 7 | from devito.logger import warning |
8 | 8 | from devito.tools import Tag, as_tuple |
@@ -48,11 +48,11 @@ def adjoint(self, matvec): |
48 | 48 | def check_input(func): |
49 | 49 | @wraps(func) |
50 | 50 | def wrapper(expr, *args, **kwargs): |
51 | | - try: |
52 | | - return S.Zero if expr.is_Number else func(expr, *args, **kwargs) |
53 | | - except AttributeError: |
54 | | - raise ValueError("'%s' must be of type Differentiable, not %s" |
55 | | - % (expr, type(expr))) |
| 51 | + # try: |
| 52 | + return S.Zero if expr.is_Number else func(expr, *args, **kwargs) |
| 53 | + # except AttributeError: |
| 54 | + # raise ValueError("'%s' must be of type Differentiable, not %s" |
| 55 | + # % (expr, type(expr))) |
56 | 56 | return wrapper |
57 | 57 |
|
58 | 58 |
|
@@ -326,9 +326,13 @@ def make_shift_x0(shift, ndim): |
326 | 326 |
|
327 | 327 |
|
328 | 328 | def process_weights(weights, expr, dim): |
| 329 | + from devito.symbolics import retrieve_functions |
| 330 | + w_func = retrieve_functions(weights) |
329 | 331 | if weights is None: |
330 | 332 | return 0, None, False |
331 | | - elif isinstance(weights, Function): |
| 333 | + elif w_func: |
| 334 | + assert len(w_func) == 1, "Only one function expected in weights" |
| 335 | + weights = w_func[0] |
332 | 336 | if len(weights.dimensions) == 1: |
333 | 337 | return weights.shape[0], weights.dimensions[0], False |
334 | 338 | try: |
|
0 commit comments