|
1 | 1 | from collections import OrderedDict |
2 | 2 | from collections.abc import Iterable |
3 | 3 | from functools import cached_property |
| 4 | +from itertools import chain |
4 | 5 |
|
5 | 6 | import sympy |
6 | 7 |
|
7 | 8 | from .finite_difference import generic_derivative, cross_derivative |
8 | | -from .differentiable import Differentiable, diffify, interp_for_fd |
| 9 | +from .differentiable import Differentiable, diffify, interp_for_fd, Add |
9 | 10 | from .tools import direct, transpose |
10 | 11 | from .rsfd import d45 |
11 | 12 | from devito.tools import (as_mapper, as_tuple, filter_ordered, frozendict, is_integer, |
@@ -510,3 +511,33 @@ def _eval_fd(self, expr, **kwargs): |
510 | 511 | res = res.xreplace(e) |
511 | 512 |
|
512 | 513 | return res |
| 514 | + |
| 515 | + def _eval_expand_nest(self, **hints): |
| 516 | + expr = self.args[0] |
| 517 | + if isinstance(expr, self.__class__): |
| 518 | + return self.func(expr.args[0], *[(d, ii) |
| 519 | + for d, ii in zip( |
| 520 | + chain(self.dims, expr.dims), |
| 521 | + chain(self.deriv_order, expr.deriv_order) |
| 522 | + )]) |
| 523 | + else: |
| 524 | + return self |
| 525 | + |
| 526 | + def _eval_expand_mul(self, **hints): |
| 527 | + expr = self.args[0] |
| 528 | + if isinstance(expr, sympy.Mul): |
| 529 | + ind, dep = expr.as_independent(*self.dims, as_Mul=True) |
| 530 | + return ind*self.func(dep, *self.args[1:]) |
| 531 | + else: |
| 532 | + return self |
| 533 | + |
| 534 | + def _eval_expand_add(self, **hints): |
| 535 | + expr = self.args[0] |
| 536 | + if isinstance(expr, sympy.Add): |
| 537 | + ind, dep = expr.as_independent(*self.dims, as_Add=True) |
| 538 | + if isinstance(dep, sympy.Add): |
| 539 | + return Add(*[self.func(s, *self.args[1:]) for s in dep.args]) |
| 540 | + else: |
| 541 | + return self.func(dep, *self.args[1:]) |
| 542 | + else: |
| 543 | + return self |
0 commit comments