Skip to content

Commit 490a627

Browse files
authored
Merge pull request #2520 from devitocodes/wood-run-stable
api: Misc fixes for builtins and harmonic averaging
2 parents fa903e4 + d437b25 commit 490a627

10 files changed

Lines changed: 199 additions & 171 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,23 @@ class Mod(DifferentiableOp, sympy.Mod):
623623
__sympy_class__ = sympy.Mod
624624

625625

626+
class SafeInv(Differentiable, sympy.core.function.Application):
627+
_fd_priority = 0
628+
629+
@property
630+
def base(self):
631+
return self.args[1]
632+
633+
@property
634+
def val(self):
635+
return self.args[0]
636+
637+
def __str__(self):
638+
return Pow(self.args[0], -1).__str__()
639+
640+
__repr__ = __str__
641+
642+
626643
class IndexSum(sympy.Expr, Evaluable):
627644

628645
"""
@@ -675,6 +692,8 @@ def __repr__(self):
675692
def _sympystr(self, printer):
676693
return str(self)
677694

695+
_latex = _sympystr
696+
678697
def _hashable_content(self):
679698
return super()._hashable_content() + (self.dimensions,)
680699

devito/operator/operator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -894,14 +894,17 @@ def apply(self, **kwargs):
894894
>>> op = Operator(Eq(u3.forward, u3 + 1))
895895
>>> summary = op.apply(time_M=10)
896896
"""
897+
# Compile the operator before building the arguments list
898+
# to avoid out of memory with greedy compilers
899+
cfunction = self.cfunction
900+
897901
# Build the arguments list to invoke the kernel function
898902
with self._profiler.timer_on('arguments'):
899903
args = self.arguments(**kwargs)
900904

901905
# Invoke kernel function with args
902906
arg_values = [args[p.name] for p in self.parameters]
903907
try:
904-
cfunction = self.cfunction
905908
with self._profiler.timer_on('apply', comm=args.comm):
906909
retval = cfunction(*arg_values)
907910
except ctypes.ArgumentError as e:

devito/passes/iet/misc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sympy
66

77
from devito.finite_differences import Max, Min
8+
from devito.finite_differences.differentiable import SafeInv
89
from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder,
910
FindApplications, FindNodes, FindSymbols, Transformer,
1011
Uxreplace, filter_iterations, retrieve_iteration_tree,
@@ -225,6 +226,13 @@ def _(expr):
225226
return ()
226227

227228

229+
@_lower_macro_math.register(SafeInv)
230+
def _(expr):
231+
eps = np.finfo(np.float32).resolution**2
232+
return (('SAFEINV(a, b)',
233+
f'(((a) < {eps} || (b) < {eps}) ? (0.0F) : (1.0F / (a)))'),)
234+
235+
228236
@iet_pass
229237
def minimize_symbols(iet):
230238
"""

devito/symbolics/inspection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,11 @@ def estimate_cost(exprs, estimate=False):
114114
estimate_values = {
115115
'elementary': 100,
116116
'pow': 50,
117+
'SafeInv': 10,
117118
'div': 5,
118119
'Abs': 5,
119120
'floor': 1,
120-
'ceil': 1
121+
'ceil': 1,
121122
}
122123

123124

devito/symbolics/printer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ def _print_Pow(self, expr):
141141
else:
142142
return f'pow{suffix}({self._print(expr.base)}, {self._print(expr.exp)})'
143143

144+
def _print_SafeInv(self, expr):
145+
"""Print a SafeInv as a C-like division with a check for zero."""
146+
base = self._print(expr.base)
147+
val = self._print(expr.val)
148+
return f'SAFEINV({val}, {base})'
149+
144150
def _print_Mod(self, expr):
145151
"""Print a Mod as a C-like %-based operation."""
146152
args = ['(%s)' % self._print(a) for a in expr.args]

devito/types/basic.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,12 +1167,16 @@ def _evaluate(self, **kwargs):
11671167
# Apply interpolation from inner most dim
11681168
for d, i in self._grid_map.items():
11691169
retval = retval.diff(d, deriv_order=0, fd_order=2, x0={d: i})
1170-
if self._avg_mode == 'harmonic':
1171-
retval = 1 / retval
11721170

11731171
# Evaluate. Since we used `self.function` it will be on the grid when evaluate
11741172
# is called again within FD
1175-
return retval.evaluate.expand()
1173+
if self._avg_mode == 'harmonic':
1174+
from devito.finite_differences.differentiable import SafeInv
1175+
retval = SafeInv(retval.evaluate, self.function)
1176+
else:
1177+
retval = retval.evaluate
1178+
1179+
return retval
11761180

11771181
@property
11781182
def shape(self):
@@ -1450,12 +1454,19 @@ def indexify(self, indices=None, subs=None):
14501454
# Indices after substitutions
14511455
indices = []
14521456
for a, d, o, s in zip(self.args, self.dimensions, self.origin, subs):
1453-
if d in a.free_symbols:
1457+
if a.is_Function and len(a.args) == 1:
1458+
# E.g. Abs(expr)
1459+
arg = a.args[0]
1460+
func = a.func
1461+
else:
1462+
arg = a
1463+
func = lambda x: x
1464+
if d in arg.free_symbols:
14541465
# Shift by origin d -> d - o.
1455-
indices.append(sympy.sympify(a.subs(d, d - o).xreplace(s)))
1466+
indices.append(func(sympy.sympify(arg.subs(d, d - o).xreplace(s))))
14561467
else:
14571468
# Dimension has been removed, e.g. u[10], plain shift by origin
1458-
indices.append(sympy.sympify(a - o).xreplace(s))
1469+
indices.append(func(sympy.sympify(arg - o).xreplace(s)))
14591470

14601471
indices = [i.xreplace({k: sympy.Integer(k) for k in i.atoms(sympy.Float)})
14611472
for i in indices]

examples/seismic/model.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,6 @@ def _initialize_physics(self, vp, space_order, **kwargs):
308308
vs = kwargs.pop('vs')
309309
self.lam = self._gen_phys_param((vp**2 - 2. * vs**2)/b, 'lam', space_order,
310310
is_param=True)
311-
# Need to add small value to avoid division by zero
312-
if isinstance(vs, np.ndarray):
313-
vs = vs + 1e-12
314311
self.mu = self._gen_phys_param(vs**2 / b, 'mu', space_order, is_param=True,
315312
avg_mode='harmonic')
316313
else:

examples/seismic/tutorials/06_elastic_varying_parameters.ipynb

Lines changed: 122 additions & 158 deletions
Large diffs are not rendered by default.

tests/test_differentiable.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import pytest
55

66
from devito import Function, Grid, Differentiable, NODE
7-
from devito.finite_differences.differentiable import Add, Mul, Pow, diffify, interp_for_fd
7+
from devito.finite_differences.differentiable import (Add, Mul, Pow, diffify,
8+
interp_for_fd, SafeInv)
89

910

1011
def test_differentiable():
@@ -113,4 +114,7 @@ def test_avg_mode(ndim):
113114
assert sympy.simplify(a_avg - 0.5**ndim * sum(a.subs(arg) for arg in args)) == 0
114115

115116
# Harmonic average, h(a[.5]) = 1/(.5/a[0] + .5/a[1])
116-
assert sympy.simplify(b_avg - 1/(0.5**ndim * sum(1/b.subs(arg) for arg in args))) == 0
117+
expected = 1/(0.5**ndim * sum(1/b.subs(arg) for arg in args))
118+
assert sympy.simplify(1/b_avg.args[0] - expected) == 0
119+
assert isinstance(b_avg, SafeInv)
120+
assert b_avg.base == b

tests/test_symbolics.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
1010
Min, Max)
11+
from devito.finite_differences.differentiable import SafeInv
1112
from devito.ir import Expression, FindNodes
1213
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
1314
CallFromPointer, Cast, DefFunction, FieldFromPointer,
@@ -345,6 +346,20 @@ def test_intdiv():
345346
assert ccode(v) == 'b*((a + b) / 2) + 3'
346347

347348

349+
def test_safeinv():
350+
grid = Grid(shape=(11, 11))
351+
x, y = grid.dimensions
352+
353+
u1 = Function(name='u', grid=grid)
354+
u2 = Function(name='u', grid=grid, dtype=np.float64)
355+
356+
op1 = Operator(Eq(u1, SafeInv(u1, u1)))
357+
op2 = Operator(Eq(u2, SafeInv(u2, u2)))
358+
359+
assert 'SAFEINV' in str(op1)
360+
assert 'SAFEINV' in str(op2)
361+
362+
348363
def test_def_function():
349364
foo0 = DefFunction('foo', arguments=['a', 'b'], template=['int'])
350365
foo1 = DefFunction('foo', arguments=['a', 'b'], template=['int'])

0 commit comments

Comments
 (0)