Skip to content

Commit af0bd23

Browse files
committed
api: add fallback from wrong sized custom coeffs
1 parent 636f70e commit af0bd23

7 files changed

Lines changed: 34 additions & 7 deletions

File tree

.github/workflows/asv.yml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,6 @@ on:
1313
inputs:
1414
tags:
1515
description: 'Run ASV'
16-
# Trigger the workflow on push to the main branch
17-
push:
18-
branches:
19-
- main
2016

2117
jobs:
2218

devito/finite_differences/finite_difference.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from collections.abc import Iterable
2+
13
from sympy import sympify
24

5+
from devito.logger import warning
36
from .differentiable import EvalDerivative, DiffDerivative, Weights
47
from .tools import (left, right, generate_indices, centered, direct, transpose,
58
check_input, fd_weights_registry, process_weights)
@@ -158,6 +161,12 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici
158161
# `coefficients` method (`taylor` or `symbolic`)
159162
if weights is None:
160163
weights = fd_weights_registry[coefficients](expr, deriv_order, indices, x0)
164+
if isinstance(weights, Iterable) and len(weights) != len(indices):
165+
warning(f"Number of weights ({len(weights)}) does not match "
166+
f"number of indices ({len(indices)}), reverting to Taylor")
167+
scale = False
168+
weights = fd_weights_registry['taylor'](expr, deriv_order, indices, x0)
169+
161170
# Did fd_weights_registry return a new Function/Expression instead of a values?
162171
_, wdim, _ = process_weights(weights, expr, dim)
163172
if wdim is not None:

requirements-optional.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ pillow>11,<11.3.1
33
pyrevolve==2.2.6
44
scipy<1.15.4
55
distributed<2025.7.1
6-
click<9.0
6+
click<9.0
7+
cloudpickle<3.1.2

requirements-testing.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ nbval<0.11.1
66
scipy<1.15.4
77
pooch<1.8.3
88
click<9.0
9+
cloudpickle<3.1.2

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,4 @@ cgen>=2020.1,<2026
77
codepy>=2019.1,<2025
88
multidict<6.3
99
anytree>=2.4.3,<=2.13.0
10-
cloudpickle<3.1.2
1110
packaging<25.1

tests/test_derivatives.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,3 +1229,24 @@ def test_expand_product_rule(self):
12291229
+ 10*self.f*Derivative(self.v, self.x)*Derivative(self.u, self.x) \
12301230
+ 5*self.f*self.u*Derivative(self.v, (self.x, 2))
12311231
assert diffify(expr.expand(product_rule=True)) == expanded
1232+
1233+
def test_fallback_wrong_custom_size(self):
1234+
"""
1235+
Check an exception is raised when a custom size is not compatible
1236+
with the derivative order
1237+
"""
1238+
grid = Grid((10,))
1239+
x, = grid.dimensions
1240+
u = Function(name="u", grid=grid, space_order=2, staggered=x)
1241+
v = Function(name="v", grid=grid, space_order=2, staggered=NODE)
1242+
1243+
w = [-2, 2] # Should 2 coeff since this is staggered
1244+
1245+
eq0 = Eq(u, v.dx(w=w)).evaluate
1246+
exp0 = -2 * v / x.spacing + 2 * v._subs(x, x + x.spacing)/x.spacing
1247+
# This one should fallback to taylor coeffs since w is too short
1248+
# for a centered derivative
1249+
eq1 = Eq(v, v.dx(w=w)).evaluate
1250+
exp1 = - .5 * (v._subs(x, x - x.spacing) - v._subs(x, x + x.spacing))/x.spacing
1251+
assert simplify(eq0.rhs - exp0) == 0
1252+
assert simplify(eq1.rhs - exp1) == 0

tests/test_tensors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def test_custom_coeffs_tensor_basic(func):
449449
f = func(name="t", grid=grid, space_order=2)
450450

451451
# Custom coefficients
452-
c = [10, 10, 10]
452+
c = [10, 20, 30]
453453

454454
df = f.dx(w=c)
455455
for (fi, dfi) in zip(f.values(), df.values()):

0 commit comments

Comments
 (0)