Skip to content

Commit fe87e06

Browse files
committed
sympy: Enable add, mul and nest expansion hints
1 parent c89d3cb commit fe87e06

1 file changed

Lines changed: 32 additions & 1 deletion

File tree

devito/finite_differences/derivative.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from collections import OrderedDict
22
from collections.abc import Iterable
33
from functools import cached_property
4+
from itertools import chain
45

56
import sympy
67

78
from .finite_difference import generic_derivative, cross_derivative
8-
from .differentiable import Differentiable, diffify, interp_for_fd
9+
from .differentiable import Differentiable, diffify, interp_for_fd, Add
910
from .tools import direct, transpose
1011
from .rsfd import d45
1112
from devito.tools import (as_mapper, as_tuple, filter_ordered, frozendict, is_integer,
@@ -510,3 +511,33 @@ def _eval_fd(self, expr, **kwargs):
510511
res = res.xreplace(e)
511512

512513
return res
514+
515+
def _eval_expand_nest(self, **hints):
516+
expr = self.args[0]
517+
if isinstance(expr, self.__class__):
518+
return self.func(expr.args[0], *[(d, ii)
519+
for d, ii in zip(
520+
chain(self.dims, expr.dims),
521+
chain(self.deriv_order, expr.deriv_order)
522+
)])
523+
else:
524+
return self
525+
526+
def _eval_expand_mul(self, **hints):
527+
expr = self.args[0]
528+
if isinstance(expr, sympy.Mul):
529+
ind, dep = expr.as_independent(*self.dims, as_Mul=True)
530+
return ind*self.func(dep, *self.args[1:])
531+
else:
532+
return self
533+
534+
def _eval_expand_add(self, **hints):
535+
expr = self.args[0]
536+
if isinstance(expr, sympy.Add):
537+
ind, dep = expr.as_independent(*self.dims, as_Add=True)
538+
if isinstance(dep, sympy.Add):
539+
return Add(*[self.func(s, *self.args[1:]) for s in dep.args])
540+
else:
541+
return self.func(dep, *self.args[1:])
542+
else:
543+
return self

0 commit comments

Comments
 (0)