Skip to content

Commit fa903e4

Browse files
authored
Merge pull request #2519 from devitocodes/maybe-fix-coeffs
compiler: Tweak custom coefficients error handling
2 parents da2c9a4 + 35d0142 commit fa903e4

2 files changed

Lines changed: 46 additions & 12 deletions

File tree

devito/finite_differences/tools.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
from sympy import S, finite_diff_weights, cacheit, sympify, Function, Rational
66

7+
from devito.logger import warning
78
from devito.tools import Tag, as_tuple
89
from devito.types.dimension import StencilDimension
910

@@ -260,6 +261,18 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None, nweights
260261
-------
261262
An IndexSet, representing an ordered list of indices.
262263
"""
264+
# Check size of input weights
265+
if nweights > 0:
266+
do, dw = order + 1 + order % 2, nweights
267+
if do < dw:
268+
raise ValueError(f"More weights ({nweights}) provided than the maximum"
269+
f"stencil size ({order + 1}) for order {order} scheme")
270+
elif do > dw:
271+
warning(f"Less weights ({nweights}) provided than the stencil size"
272+
f"({order + 1}) for order {order} scheme."
273+
" Reducing order to {nweights//2}")
274+
order = nweights - nweights % 2
275+
263276
# Evaluation point
264277
x0 = sympify(((x0 or {}).get(dim) or expr.indices_ref[dim]))
265278

@@ -276,23 +289,15 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None, nweights
276289
side = side or centered
277290

278291
# Indices range
279-
o_min = int(np.ceil(mid - order/2)) + side.val
280-
o_max = int(np.floor(mid + order/2)) + side.val
292+
r = (nweights or order) / 2
293+
o_min = int(np.ceil(mid - r)) + side.val
294+
o_max = int(np.floor(mid + r)) + side.val
281295
if o_max == o_min:
282296
if dim.is_Time or not expr.is_Staggered:
283297
o_max += 1
284298
else:
285299
o_min -= 1
286300

287-
if nweights > 0 and (o_max - o_min + 1) != nweights:
288-
# We cannot infer how the stencil should be centered
289-
# if nweights is more than one extra point.
290-
assert nweights == (o_max - o_min + 1) + 1
291-
# In the "one extra" case we need to pad with one point to symmetrize
292-
if (o_max - mid) > (mid - o_min):
293-
o_min -= 1
294-
else:
295-
o_max += 1
296301
# StencilDimension and expression
297302
d = make_stencil_dimension(expr, o_min, o_max)
298303
iexpr = expr.indices_ref[dim] + d * dim.spacing

tests/test_unexpansion.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from conftest import assert_structure, get_params, get_arrays, check_array
55
from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator,
6-
cos, sin)
6+
Coefficient, Substitutions, cos, sin)
77
from devito.finite_differences import Weights
88
from devito.arch.compiler import OneapiCompiler
99
from devito.ir import Expression, FindNodes, FindSymbols
@@ -76,6 +76,35 @@ def test_multiple_cross_derivs(self, coeffs, expected):
7676
weights = {f for f in functions if isinstance(f, Weights)}
7777
assert len(weights) == expected
7878

79+
@pytest.mark.parametrize('order', [1, 2])
80+
@pytest.mark.parametrize('nweight', [None, +4, -4])
81+
def test_legacy_api(self, order, nweight):
82+
grid = Grid(shape=(51, 51, 51))
83+
x, y, z = grid.dimensions
84+
85+
nweight = 0 if nweight is None else nweight
86+
so = 8
87+
88+
u = TimeFunction(name='u', grid=grid, space_order=so,
89+
coefficients='symbolic')
90+
91+
w0 = np.arange(so + 1 + nweight) + 1
92+
wstr = '{' + ', '.join([f"{w:1.1f}F" for w in w0]) + '}'
93+
wdef = f'[{so + 1 + nweight}] __attribute__ ((aligned (64)))'
94+
95+
coeffs_x_p1 = Coefficient(order, u, x, w0)
96+
97+
coeffs = Substitutions(coeffs_x_p1)
98+
99+
eqn = Eq(u, u.dx.dy + u.dx2 + .37, coefficients=coeffs)
100+
101+
if nweight > 0:
102+
with pytest.raises(ValueError):
103+
op = Operator(eqn, opt=('advanced', {'expand': False}))
104+
else:
105+
op = Operator(eqn, opt=('advanced', {'expand': False}))
106+
assert f'{wdef} = {wstr}' in str(op)
107+
79108

80109
class Test1Pass:
81110

0 commit comments

Comments
 (0)