Skip to content

Commit f71764a

Browse files
authored
Merge pull request #2524 from devitocodes/fix-derivs-7865
api: Fix custom coefficients inlining
2 parents 8cec681 + 0698917 commit f71764a

5 files changed

Lines changed: 35 additions & 10 deletions

File tree

devito/finite_differences/finite_difference.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,11 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici
162162
weights = [weights._subs(wdim, i) for i in range(len(indices))]
163163

164164
# Enforce fixed precision FD coefficients to avoid variations in results
165-
weights = [sympify(w).evalf(_PRECISION) for w in weights]
165+
if scale:
166+
scale = dim.spacing**(-deriv_order)
167+
else:
168+
scale = 1
169+
weights = [sympify(scale * w).evalf(_PRECISION) for w in weights]
166170

167171
# Transpose the FD, if necessary
168172
if matvec == transpose:
@@ -208,7 +212,4 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici
208212

209213
deriv = EvalDerivative(*terms, base=expr)
210214

211-
if scale:
212-
deriv = dim.spacing**(-deriv_order) * deriv
213-
214215
return deriv

devito/types/equation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _apply_coeffs(cls, expr, coefficients):
9898
if not mapper:
9999
return expr
100100

101-
return expr.xreplace(mapper)
101+
return expr.subs(mapper)
102102

103103
def _evaluate(self, **kwargs):
104104
"""

examples/seismic/tutorials/07_DRP_schemes.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
"name": "stdout",
8484
"output_type": "stream",
8585
"text": [
86-
"Eq(-u(t, x, y)/dt + u(t + dt, x, y)/dt + (0.1*u(t, x, y) - 0.6*u(t, x - h_x, y) + 0.6*u(t, x + h_x, y))/h_x, 0)\n"
86+
"Eq(-u(t, x, y)/dt + u(t + dt, x, y)/dt + 0.1*u(t, x, y)/h_x - 0.6*u(t, x - h_x, y)/h_x + 0.6*u(t, x + h_x, y)/h_x, 0)\n"
8787
]
8888
}
8989
],

tests/test_symbolic_coefficients.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,8 @@ def test_staggered_equation(self):
202202

203203
eq_f = Eq(f, f.dx2(weights=weights))
204204

205-
expected = 'Eq(f(x + h_x/2), (1.0*f(x - h_x/2) - 2.0*f(x + h_x/2)'\
206-
' + 1.0*f(x + 3*h_x/2))/h_x**2)'
205+
expected = 'Eq(f(x + h_x/2), 1.0*f(x - h_x/2)/h_x**2 - 2.0*f(x + h_x/2)/h_x**2 '\
206+
'+ 1.0*f(x + 3*h_x/2)/h_x**2)'
207207
assert(str(eq_f.evaluate) == expected)
208208

209209
@pytest.mark.parametrize('stagger', [True, False])

tests/test_unexpansion.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_numeric_coeffs(self):
5151
Operator(Eq(u, (v*u.dx).dy(weights=w)), opt=opt).cfunction
5252

5353
@pytest.mark.parametrize('coeffs,expected', [
54-
((7, 7, 7), 1), # We've had a bug triggered by identical coeffs
54+
((7, 7, 7), 3), # We've had a bug triggered by identical coeffs
5555
((5, 7, 9), 3),
5656
])
5757
def test_multiple_cross_derivs(self, coeffs, expected):
@@ -89,7 +89,8 @@ def test_legacy_api(self, order, nweight):
8989
coefficients='symbolic')
9090

9191
w0 = np.arange(so + 1 + nweight) + 1
92-
wstr = '{' + ', '.join([f"{w:1.1f}F" for w in w0]) + '}'
92+
s = f'({x.spacing}*{x.spacing})' if order == 2 else f'{x.spacing}'
93+
wstr = f'{{{w0[0]:1.1f}F/{s},'
9394
wdef = f'[{so + 1 + nweight}] __attribute__ ((aligned (64)))'
9495

9596
coeffs_x_p1 = Coefficient(order, u, x, w0)
@@ -105,6 +106,29 @@ def test_legacy_api(self, order, nweight):
105106
op = Operator(eqn, opt=('advanced', {'expand': False}))
106107
assert f'{wdef} = {wstr}' in str(op)
107108

109+
def test_legacy_api_v2(self):
110+
grid = Grid(shape=(10, 10, 10))
111+
x, y, z = grid.dimensions
112+
113+
u = TimeFunction(name='u', grid=grid, space_order=4)
114+
115+
cc = np.array([2, 2, 2, 2, 2])
116+
coeffs = [Coefficient(1, u, d, cc) for d in grid.dimensions]
117+
coeffs = Substitutions(*coeffs)
118+
119+
eq0 = Eq(u.forward, u.dx.dz + 1.0)
120+
eq1 = Eq(u.forward, u.dx.dz + 1.0, coefficients=coeffs)
121+
122+
op0 = Operator(eq0, opt=('advanced', {'expand': False}))
123+
op1 = Operator(eq1, opt=('advanced', {'expand': False}))
124+
125+
assert (op0._profiler._sections['section0'].sops ==
126+
op1._profiler._sections['section0'].sops)
127+
weights = [i for i in FindSymbols().visit(op1) if isinstance(i, Weights)]
128+
w0, w1 = sorted(weights, key=lambda i: i.name)
129+
assert all(i.args[1] == 1/x.spacing for i in w0.weights)
130+
assert all(i.args[1] == 1/z.spacing for i in w1.weights)
131+
108132

109133
class Test1Pass:
110134

0 commit comments

Comments
 (0)