Skip to content

Commit 34d6ec3

Browse files
committed
api: fix interp/eval of expressions
1 parent 676b8bf commit 34d6ec3

3 files changed

Lines changed: 69 additions & 65 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -475,12 +475,19 @@ def has_free(self, *patterns):
475475

476476

477477
def highest_priority(DiffOp):
478+
if not DiffOp._args_diff:
479+
return DiffOp
478480
# We want to get the object with highest priority
479481
# We also need to make sure that the object with the largest
480482
# set of dimensions is used when multiple ones with the same
481483
# priority appear
482484
prio = lambda x: (getattr(x, '_fd_priority', 0), len(x.dimensions))
483-
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]
485+
prio_func = sorted(DiffOp._args_diff, key=prio, reverse=True)[0]
486+
487+
# The highest priority must be a Function
488+
if not isinstance(prio_func, AbstractFunction):
489+
return highest_priority(prio_func)
490+
return prio_func
484491

485492

486493
class DifferentiableOp(Differentiable):
@@ -548,8 +555,9 @@ class DifferentiableFunction(DifferentiableOp):
548555
def __new__(cls, *args, **kwargs):
549556
return cls.__sympy_class__.__new__(cls, *args, **kwargs)
550557

551-
def _eval_at(self, func):
552-
return self
558+
@property
559+
def _fd_priority(self):
560+
return highest_priority(self)._fd_priority
553561

554562

555563
class Add(DifferentiableOp, sympy.Add):
@@ -633,26 +641,12 @@ def _gather_for_diff(self):
633641
if len(set(f.staggered for f in self._args_diff)) == 1:
634642
return self
635643

636-
func_args = highest_priority(self)
637-
new_args = []
638-
ref_inds = func_args.indices_ref.getters
639-
640-
for f in self.args:
641-
if f not in self._args_diff \
642-
or f is func_args \
643-
or isinstance(f, DifferentiableFunction):
644-
new_args.append(f)
645-
else:
646-
ind_f = f.indices_ref.getters
647-
mapper = {ind_f.get(d, d): ref_inds.get(d, d)
648-
for d in self.dimensions
649-
if ind_f.get(d, d) is not ref_inds.get(d, d)}
650-
if mapper:
651-
new_args.append(f.subs(mapper))
652-
else:
653-
new_args.append(f)
654-
655-
return self.func(*new_args, evaluate=False)
644+
derivs, other = split(self.args, lambda a: isinstance(a, sympy.Derivative))
645+
if len(derivs) == 0:
646+
return self._eval_at(highest_priority(self))
647+
else:
648+
other = self.func(*other)._eval_at(highest_priority(self))
649+
return self.func(other, *derivs)
656650

657651

658652
class Pow(DifferentiableOp, sympy.Pow):
@@ -1034,6 +1028,9 @@ def __new__(cls, *args, base=None, **kwargs):
10341028
obj = super().__new__(cls, *args, **kwargs)
10351029

10361030
try:
1031+
if base is obj:
1032+
# In some rare cases (rebuild?) base may be obj itself
1033+
base = base.base
10371034
obj.base = base
10381035
except AttributeError:
10391036
# This might happen if e.g. one attempts a (re)construction with
@@ -1061,6 +1058,10 @@ def _eval_at(self, func):
10611058
# and should not be re-evaluated at a different location
10621059
return self
10631060

1061+
@property
1062+
def indices_ref(self):
1063+
return self.base.indices_ref
1064+
10641065

10651066
class diffify:
10661067

@@ -1184,6 +1185,14 @@ def _(expr, x0, **kwargs):
11841185
return expr.func(interp_for_fd(expr.expr, x0_expr, **kwargs))
11851186

11861187

1188+
@interp_for_fd.register(DifferentiableOp)
1189+
def _(expr, x0, **kwargs):
1190+
# For a expression (e.g Mul or Add), we interpolate the whole expression
1191+
d_dims = tuple((d, 0) for d in x0)
1192+
fd_order = tuple(expr.interp_order for d in x0)
1193+
return expr.diff(*d_dims, fd_order=fd_order, x0=x0, **kwargs)
1194+
1195+
11871196
@interp_for_fd.register(sympy.Expr)
11881197
def _(expr, x0, **kwargs):
11891198
if expr.args:

devito/finite_differences/tools.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,9 @@ def check_input(func):
5050
def wrapper(expr, *args, **kwargs):
5151
try:
5252
return S.Zero if expr.is_Number else func(expr, *args, **kwargs)
53-
except AttributeError:
54-
raise ValueError(
55-
f"'{expr}' must be of type Differentiable, not {type(expr)}"
56-
) from None
53+
except Exception as e:
54+
raise type(e)(f"Error while computing finite-difference for expr={expr}: "
55+
f"{e}") from e
5756
return wrapper
5857

5958

examples/seismic/tti/operators.py

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,6 @@ def kernel_staggered_2d(model, u, v, **kwargs):
280280
epsilon = 1 + 2 * epsilon
281281
delta = sqrt(1 + 2 * delta)
282282
s = model.grid.stepping_dim.spacing
283-
x, z = model.grid.dimensions
284283

285284
# Get source
286285
qu = kwargs.get('qu', 0)
@@ -291,31 +290,31 @@ def kernel_staggered_2d(model, u, v, **kwargs):
291290

292291
if forward:
293292
# Stencils
294-
phdx = costheta * u.dx - sintheta * u.dyc
293+
phdx = costheta * u.dx - sintheta * u.dy
295294
u_vx = Eq(vx.forward, dampl * vx - dampl * s * phdx)
296295

297-
pvdz = sintheta * v.dxc + costheta * v.dy
296+
pvdz = sintheta * v.dx + costheta * v.dy
298297
u_vz = Eq(vz.forward, dampl * vz - dampl * s * pvdz)
299298

300-
dvx = costheta * vx.forward.dx - sintheta * vx.forward.dyc
301-
dvz = sintheta * vz.forward.dxc + costheta * vz.forward.dy
299+
dvx = costheta * vx.forward.dx - sintheta * vx.forward.dy
300+
dvz = sintheta * vz.forward.dx + costheta * vz.forward.dy
302301

303302
# u and v equations
304303
pv_eq = Eq(v.forward, dampl * (v - s / m * (delta * dvx + dvz)) + s / m * qv)
305304
ph_eq = Eq(u.forward, dampl * (u - s / m * (epsilon * dvx + delta * dvz)) +
306305
s / m * qu)
307306
else:
308307
# Stencils
309-
phdx = ((costheta*epsilon*u).dx - (sintheta*epsilon*u).dyc +
310-
(costheta*delta*v).dx - (sintheta*delta*v).dyc)
308+
a = epsilon * u + delta * v
309+
phdx = (costheta * a).dx - (sintheta * a).dy
311310
u_vx = Eq(vx.backward, dampl * vx + dampl * s * phdx)
312311

313-
pvdz = ((sintheta*delta*u).dxc + (costheta*delta*u).dy +
314-
(sintheta*v).dxc + (costheta*v).dy)
312+
b = delta * u + v
313+
pvdz = (sintheta * b).dx + (costheta * b).dy
315314
u_vz = Eq(vz.backward, dampl * vz + dampl * s * pvdz)
316315

317-
dvx = (costheta * vx.backward).dx - (sintheta * vx.backward).dyc
318-
dvz = (sintheta * vz.backward).dxc + (costheta * vz.backward).dy
316+
dvx = (costheta * vx.backward).dx - (sintheta * vx.backward).dy
317+
dvz = (sintheta * vz.backward).dx + (costheta * vz.backward).dy
319318

320319
# u and v equations
321320
pv_eq = Eq(v.backward, dampl * (v + s / m * dvz))
@@ -356,24 +355,24 @@ def kernel_staggered_3d(model, u, v, **kwargs):
356355
if forward:
357356
# Stencils
358357
phdx = (costheta * cosphi * u.dx +
359-
costheta * sinphi * u.dyc -
360-
sintheta * u.dzc)
358+
costheta * sinphi * u.dy -
359+
sintheta * u.dz)
361360
u_vx = Eq(vx.forward, dampl * vx - dampl * s * phdx)
362361

363-
phdy = -sinphi * u.dxc + cosphi * u.dy
362+
phdy = -sinphi * u.dx + cosphi * u.dy
364363
u_vy = Eq(vy.forward, dampl * vy - dampl * s * phdy)
365364

366-
pvdz = (sintheta * cosphi * v.dxc +
367-
sintheta * sinphi * v.dyc +
365+
pvdz = (sintheta * cosphi * v.dx +
366+
sintheta * sinphi * v.dy +
368367
costheta * v.dz)
369368
u_vz = Eq(vz.forward, dampl * vz - dampl * s * pvdz)
370369

371370
dvx = (costheta * cosphi * vx.forward.dx +
372-
costheta * sinphi * vx.forward.dyc -
373-
sintheta * vx.forward.dzc)
374-
dvy = -sinphi * vy.forward.dxc + cosphi * vy.forward.dy
375-
dvz = (sintheta * cosphi * vz.forward.dxc +
376-
sintheta * sinphi * vz.forward.dyc +
371+
costheta * sinphi * vx.forward.dy -
372+
sintheta * vx.forward.dz)
373+
dvy = -sinphi * vy.forward.dx + cosphi * vy.forward.dy
374+
dvz = (sintheta * cosphi * vz.forward.dx +
375+
sintheta * sinphi * vz.forward.dy +
377376
costheta * vz.forward.dz)
378377
# u and v equations
379378
pv_eq = Eq(v.forward, dampl * (v - s / m * (delta * (dvx + dvy) + dvz)) +
@@ -383,30 +382,27 @@ def kernel_staggered_3d(model, u, v, **kwargs):
383382
delta * dvz)) + s / m * qu)
384383
else:
385384
# Stencils
386-
phdx = ((costheta * cosphi * epsilon*u).dx +
387-
(costheta * sinphi * epsilon*u).dyc -
388-
(sintheta * epsilon*u).dzc + (costheta * cosphi * delta*v).dx +
389-
(costheta * sinphi * delta*v).dyc -
390-
(sintheta * delta*v).dzc)
385+
a = epsilon * u + delta * v
386+
phdx = ((costheta * cosphi * a).dx +
387+
(costheta * sinphi * a).dy -
388+
(sintheta * a).dz)
391389
u_vx = Eq(vx.backward, dampl * vx + dampl * s * phdx)
392390

393-
phdy = (-(sinphi * epsilon*u).dxc + (cosphi * epsilon*u).dy -
394-
(sinphi * delta*v).dxc + (cosphi * delta*v).dy)
391+
phdy = (-(sinphi * a).dx + (cosphi * a).dy)
395392
u_vy = Eq(vy.backward, dampl * vy + dampl * s * phdy)
396393

397-
pvdz = ((sintheta * cosphi * delta*u).dxc +
398-
(sintheta * sinphi * delta*u).dyc +
399-
(costheta * delta*u).dz + (sintheta * cosphi * v).dxc +
400-
(sintheta * sinphi * v).dyc +
401-
(costheta * v).dz)
394+
b = delta * u + v
395+
pvdz = ((sintheta * cosphi * b).dx +
396+
(sintheta * sinphi * b).dy +
397+
(costheta * b).dz)
402398
u_vz = Eq(vz.backward, dampl * vz + dampl * s * pvdz)
403399

404400
dvx = ((costheta * cosphi * vx.backward).dx +
405-
(costheta * sinphi * vx.backward).dyc -
406-
(sintheta * vx.backward).dzc)
407-
dvy = (-sinphi * vy.backward).dxc + (cosphi * vy.backward).dy
408-
dvz = ((sintheta * cosphi * vz.backward).dxc +
409-
(sintheta * sinphi * vz.backward).dyc +
401+
(costheta * sinphi * vx.backward).dy -
402+
(sintheta * vx.backward).dz)
403+
dvy = (-sinphi * vy.backward).dx + (cosphi * vy.backward).dy
404+
dvz = ((sintheta * cosphi * vz.backward).dx +
405+
(sintheta * sinphi * vz.backward).dy +
410406
(costheta * vz.backward).dz)
411407
# u and v equations
412408
pv_eq = Eq(v.backward, dampl * (v + s / m * dvz))

0 commit comments

Comments
 (0)