Skip to content

Commit 636f70e

Browse files
committed
api: fix derivatives kw for tensors
1 parent b4e9995 commit 636f70e

3 files changed

Lines changed: 29 additions & 0 deletions

File tree

devito/operations/interpolators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ def __repr__(self):
125125
return (f"Interpolation({repr(self.expr)} into "
126126
f"{repr(self.interpolator.sfunction)})")
127127

128+
__str__ = __repr__
129+
128130

129131
class Injection(UnevaluatedSparseOperation):
130132

@@ -152,6 +154,8 @@ def operation(self, **kwargs):
152154
def __repr__(self):
153155
return f"Injection({repr(self.expr)} into {repr(self.field)})"
154156

157+
__str__ = __repr__
158+
155159

156160
class GenericInterpolator(ABC):
157161

devito/types/basic.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1574,6 +1574,16 @@ def adjoint(self, inner=True):
15741574
# Real valued adjoint is transpose
15751575
return self.transpose(inner=inner)
15761576

1577+
def __call__(self, **kwargs):
1578+
"""
1579+
Derivative custom inputs (weights/x0/...) is done through call
1580+
and needs to be applied to each component through applyfunc
1581+
"""
1582+
try:
1583+
return self.applyfunc(lambda x: x(**kwargs))
1584+
except TypeError as e:
1585+
raise f"{self.name} not callable with {kwargs}" from e
1586+
15771587
@call_highest_priority('__radd__')
15781588
def __add__(self, other):
15791589
try:

tests/test_tensors.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,21 @@ def test_custom_coeffs_tensor():
442442
assert list(drv.weights) == c
443443

444444

445+
@pytest.mark.parametrize('func', [TensorFunction, TensorTimeFunction,
446+
VectorFunction, VectorTimeFunction])
447+
def test_custom_coeffs_tensor_basic(func):
448+
grid = Grid(tuple([5]*3))
449+
f = func(name="t", grid=grid, space_order=2)
450+
451+
# Custom coefficients
452+
c = [10, 10, 10]
453+
454+
df = f.dx(w=c)
455+
for (fi, dfi) in zip(f.values(), df.values()):
456+
assert dfi == fi.dx(w=c)
457+
assert list(dfi.weights) == c
458+
459+
445460
@pytest.mark.parametrize('func1', [TensorFunction, TensorTimeFunction,
446461
VectorFunction, VectorTimeFunction])
447462
def test_rebuild(func1):

0 commit comments

Comments
 (0)