Skip to content

Commit eaa4891

Browse files
authored
Merge pull request #2723 from devitocodes/tens-deriv-kw
api: fix derivatives kw for tensors
2 parents b4e9995 + af0bd23 commit eaa4891

9 files changed

Lines changed: 62 additions & 6 deletions

File tree

.github/workflows/asv.yml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ on:
1313
inputs:
1414
tags:
1515
description: 'Run ASV'
16-
# Trigger the workflow on push to the main branch
17-
push:
18-
branches:
19-
- main
2016

2117
jobs:
2218

devito/finite_differences/finite_difference.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from collections.abc import Iterable
2+
13
from sympy import sympify
24

5+
from devito.logger import warning
36
from .differentiable import EvalDerivative, DiffDerivative, Weights
47
from .tools import (left, right, generate_indices, centered, direct, transpose,
58
check_input, fd_weights_registry, process_weights)
@@ -158,6 +161,12 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici
158161
# `coefficients` method (`taylor` or `symbolic`)
159162
if weights is None:
160163
weights = fd_weights_registry[coefficients](expr, deriv_order, indices, x0)
164+
if isinstance(weights, Iterable) and len(weights) != len(indices):
165+
warning(f"Number of weights ({len(weights)}) does not match "
166+
f"number of indices ({len(indices)}), reverting to Taylor")
167+
scale = False
168+
weights = fd_weights_registry['taylor'](expr, deriv_order, indices, x0)
169+
161170
# Did fd_weights_registry return a new Function/Expression instead of a values?
162171
_, wdim, _ = process_weights(weights, expr, dim)
163172
if wdim is not None:

devito/operations/interpolators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ def __repr__(self):
125125
return (f"Interpolation({repr(self.expr)} into "
126126
f"{repr(self.interpolator.sfunction)})")
127127

128+
__str__ = __repr__
129+
128130

129131
class Injection(UnevaluatedSparseOperation):
130132

@@ -152,6 +154,8 @@ def operation(self, **kwargs):
152154
def __repr__(self):
153155
return f"Injection({repr(self.expr)} into {repr(self.field)})"
154156

157+
__str__ = __repr__
158+
155159

156160
class GenericInterpolator(ABC):
157161

devito/types/basic.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1574,6 +1574,16 @@ def adjoint(self, inner=True):
15741574
# Real valued adjoint is transpose
15751575
return self.transpose(inner=inner)
15761576

1577+
def __call__(self, **kwargs):
1578+
"""
1579+
Derivative custom inputs (weights/x0/...) is done through call
1580+
and needs to be applied to each component through applyfunc
1581+
"""
1582+
try:
1583+
return self.applyfunc(lambda x: x(**kwargs))
1584+
except TypeError as e:
1585+
raise f"{self.name} not callable with {kwargs}" from e
1586+
15771587
@call_highest_priority('__radd__')
15781588
def __add__(self, other):
15791589
try:

requirements-optional.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ pillow>11,<11.3.1
33
pyrevolve==2.2.6
44
scipy<1.15.4
55
distributed<2025.7.1
6-
click<9.0
6+
click<9.0
7+
cloudpickle<3.1.2

requirements-testing.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ nbval<0.11.1
66
scipy<1.15.4
77
pooch<1.8.3
88
click<9.0
9+
cloudpickle<3.1.2

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,4 @@ cgen>=2020.1,<2026
77
codepy>=2019.1,<2025
88
multidict<6.3
99
anytree>=2.4.3,<=2.13.0
10-
cloudpickle<3.1.2
1110
packaging<25.1

tests/test_derivatives.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,3 +1229,24 @@ def test_expand_product_rule(self):
12291229
+ 10*self.f*Derivative(self.v, self.x)*Derivative(self.u, self.x) \
12301230
+ 5*self.f*self.u*Derivative(self.v, (self.x, 2))
12311231
assert diffify(expr.expand(product_rule=True)) == expanded
1232+
1233+
def test_fallback_wrong_custom_size(self):
1234+
"""
1235+
Check an exception is raised when a custom size is not compatible
1236+
with the derivative order
1237+
"""
1238+
grid = Grid((10,))
1239+
x, = grid.dimensions
1240+
u = Function(name="u", grid=grid, space_order=2, staggered=x)
1241+
v = Function(name="v", grid=grid, space_order=2, staggered=NODE)
1242+
1243+
w = [-2, 2] # Should 2 coeff since this is staggered
1244+
1245+
eq0 = Eq(u, v.dx(w=w)).evaluate
1246+
exp0 = -2 * v / x.spacing + 2 * v._subs(x, x + x.spacing)/x.spacing
1247+
# This one should fallback to taylor coeffs since w is too short
1248+
# for a centered derivative
1249+
eq1 = Eq(v, v.dx(w=w)).evaluate
1250+
exp1 = - .5 * (v._subs(x, x - x.spacing) - v._subs(x, x + x.spacing))/x.spacing
1251+
assert simplify(eq0.rhs - exp0) == 0
1252+
assert simplify(eq1.rhs - exp1) == 0

tests/test_tensors.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,21 @@ def test_custom_coeffs_tensor():
442442
assert list(drv.weights) == c
443443

444444

445+
@pytest.mark.parametrize('func', [TensorFunction, TensorTimeFunction,
446+
VectorFunction, VectorTimeFunction])
447+
def test_custom_coeffs_tensor_basic(func):
448+
grid = Grid(tuple([5]*3))
449+
f = func(name="t", grid=grid, space_order=2)
450+
451+
# Custom coefficients
452+
c = [10, 20, 30]
453+
454+
df = f.dx(w=c)
455+
for (fi, dfi) in zip(f.values(), df.values()):
456+
assert dfi == fi.dx(w=c)
457+
assert list(dfi.weights) == c
458+
459+
445460
@pytest.mark.parametrize('func1', [TensorFunction, TensorTimeFunction,
446461
VectorFunction, VectorTimeFunction])
447462
def test_rebuild(func1):

0 commit comments

Comments
 (0)