Skip to content

Commit d70ded4

Browse files
committed
api: Fix custom coefficients inlining
1 parent 8cec681 commit d70ded4

2 files changed

Lines changed: 28 additions & 4 deletions

File tree

devito/finite_differences/finite_difference.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,11 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici
162162
weights = [weights._subs(wdim, i) for i in range(len(indices))]
163163

164164
# Enforce fixed precision FD coefficients to avoid variations in results
165-
weights = [sympify(w).evalf(_PRECISION) for w in weights]
165+
if scale:
166+
scale = dim.spacing**(-deriv_order)
167+
else:
168+
scale = 1
169+
weights = [sympify(scale * w).evalf(_PRECISION) for w in weights]
166170

167171
# Transpose the FD, if necessary
168172
if matvec == transpose:
@@ -208,7 +212,4 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici
208212

209213
deriv = EvalDerivative(*terms, base=expr)
210214

211-
if scale:
212-
deriv = dim.spacing**(-deriv_order) * deriv
213-
214215
return deriv

tests/test_unexpansion.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,29 @@ def test_legacy_api(self, order, nweight):
105105
op = Operator(eqn, opt=('advanced', {'expand': False}))
106106
assert f'{wdef} = {wstr}' in str(op)
107107

108+
def test_legacy_api_v2(self):
109+
grid = Grid(shape=(10, 10, 10))
110+
x, y, z = grid.dimensions
111+
112+
u = TimeFunction(name='u', grid=grid, space_order=4)
113+
114+
cc = np.array([2, 2, 2, 2, 2])
115+
coeffs = [Coefficient(1, u, d, cc) for d in grid.dimensions]
116+
coeffs = Substitutions(*coeffs)
117+
118+
eq0 = Eq(u.forward, u.dx.dz + 1.0)
119+
eq1 = Eq(u.forward, u.dx.dz + 1.0, coefficients=coeffs)
120+
121+
op0 = Operator(eq0, opt=('advanced', {'expand': False}))
122+
op1 = Operator(eq1, opt=('advanced', {'expand': False}))
123+
124+
assert (op0._profiler._sections['section0'].sops ==
125+
op1._profiler._sections['section0'].sops)
126+
weights = [i for i in FindSymbols().visit(op1) if isinstance(i, Weights)]
127+
w0, w1 = sorted(weights, key=lambda i: i.name)
128+
assert all(i.args[1] == 1/x.spacing for i in w0.weights)
129+
assert all(i.args[1] == 1/z.spacing for i in w1.weights)
130+
108131

109132
class Test1Pass:
110133

0 commit comments

Comments
 (0)