Skip to content

Commit 01e478d

Browse files
committed
api: fix backward compatibility of Substitution, missing function check
1 parent 490a627 commit 01e478d

4 files changed

Lines changed: 32 additions & 2 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ def _xreplace(self, subs):
307307
if self in subs:
308308
new = subs.pop(self)
309309
try:
310-
return new._xreplace(subs)
310+
new, flag = new._xreplace(subs)
311+
return new, True
311312
except AttributeError:
312313
return new, True
313314

devito/finite_differences/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def generate_indices(expr, dim, order, side=None, matvec=None, x0=None, nweights
265265
if nweights > 0:
266266
do, dw = order + 1 + order % 2, nweights
267267
if do < dw:
268-
raise ValueError(f"More weights ({nweights}) provided than the maximum"
268+
raise ValueError(f"More weights ({nweights}) provided than the maximum "
269269
f"stencil size ({order + 1}) for order {order} scheme")
270270
elif do > dw:
271271
warning(f"Less weights ({nweights}) provided than the stencil size"

devito/types/equation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def _apply_coeffs(cls, expr, coefficients):
9090
for coeff in coefficients.coefficients:
9191
derivs = [d for d in retrieve_derivatives(expr)
9292
if coeff.dimension in d.dims and
93+
coeff.function in d.expr._functions and
9394
coeff.deriv_order == d.deriv_order.get(coeff.dimension, None)]
9495
if not derivs:
9596
continue

tests/test_symbolic_coefficients.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,31 @@ def test_spacing(self):
344344
df_s = f.dx(weights=coeffs1)
345345

346346
assert sp.simplify(df_s.evaluate - df.evaluate) == 0
347+
348+
def test_backward_compat_mixed(self):
349+
350+
grid = Grid(shape=(11,))
351+
x, = grid.dimensions
352+
353+
f = Function(name='f', grid=grid, space_order=8)
354+
g = Function(name='g', grid=grid, space_order=2)
355+
356+
coeffs0 = np.arange(0, 9)
357+
358+
coeffs = Coefficient(1, f, x, coeffs0)
359+
360+
eq = Eq(f, f.dx * g.dxc, coefficients=Substitutions(coeffs))
361+
362+
derivs = retrieve_derivatives(eq.rhs)
363+
364+
assert len(derivs) == 2
365+
df = [d for d in derivs if d.expr == f][0]
366+
dg = [d for d in derivs if d.expr == g][0]
367+
368+
assert np.all(df.weights == coeffs0)
369+
assert dg.weights is None
370+
371+
eqe = eq.evaluate
372+
assert '7.0*f(x + 3*h_x)' in str(eqe.rhs)
373+
assert '0.5*g(x + h_x)' in str(eqe.rhs)
374+
assert 'g(x + 2*h_x)' not in str(eqe.rhs)

0 commit comments

Comments
 (0)