Skip to content

Commit 21579c2

Browse files
committed
compiler: Fix Weights reconstruction
1 parent 34cdc53 commit 21579c2

5 files changed

Lines changed: 46 additions & 33 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,8 +746,8 @@ def __init_finalize__(self, *args, **kwargs):
746746
assert isinstance(weights, (list, tuple, np.ndarray))
747747

748748
# Normalize `weights`
749-
from devito.symbolics import pow_to_mul # noqa, sigh
750-
weights = tuple(pow_to_mul(sympy.sympify(i)) for i in weights)
749+
from devito.symbolics import pow_to_mul, unevaluate # noqa, sigh
750+
weights = tuple(unevaluate(pow_to_mul(sympy.sympify(i))) for i in weights)
751751

752752
kwargs['scope'] = kwargs.get('scope', 'stack')
753753
kwargs['initvalue'] = weights
Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
import sympy
2-
31
from devito.ir import cluster_pass
4-
from devito.symbolics import reuse_if_untouched, q_leaf
5-
from devito.symbolics.unevaluation import Add, Mul, Pow
2+
from devito.symbolics import unevaluate as _unevaluate
63

74
__all__ = ['unevaluate']
85

@@ -12,22 +9,3 @@ def unevaluate(cluster):
129
exprs = [_unevaluate(e) for e in cluster.exprs]
1310

1411
return cluster.rebuild(exprs=exprs)
15-
16-
17-
mapper = {
18-
sympy.Add: Add,
19-
sympy.Mul: Mul,
20-
sympy.Pow: Pow
21-
}
22-
23-
24-
def _unevaluate(expr):
25-
if q_leaf(expr):
26-
return expr
27-
28-
args = [_unevaluate(a) for a in expr.args]
29-
30-
try:
31-
return mapper[expr.func](*args)
32-
except KeyError:
33-
return reuse_if_untouched(expr, args)

devito/symbolics/manipulation.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from devito.symbolics.extended_sympy import DefFunction, rfunc
1414
from devito.symbolics.queries import q_leaf
1515
from devito.symbolics.search import retrieve_indexed, retrieve_functions
16-
from devito.symbolics.unevaluation import Mul as UMul
16+
from devito.symbolics.unevaluation import UnevalAdd, UnevalMul, UnevalPow
1717
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
1818
from devito.types.basic import Basic, Indexed
1919
from devito.types.array import ComponentAccess
@@ -22,7 +22,7 @@
2222

2323
__all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args',
2424
'normalize_args', 'uxreplace', 'Uxmapper', 'subs_if_composite',
25-
'reuse_if_untouched', 'evalrel', 'flatten_args']
25+
'reuse_if_untouched', 'evalrel', 'flatten_args', 'unevaluate']
2626

2727

2828
def uxreplace(expr, rule):
@@ -338,7 +338,7 @@ def pow_to_mul(expr):
338338
# but at least we traverse the base looking for other Pows
339339
return expr.func(pow_to_mul(base), exp, evaluate=False)
340340
elif exp > 0:
341-
return UMul(*[pow_to_mul(base)]*int(exp), evaluate=False)
341+
return UnevalMul(*[pow_to_mul(base)]*int(exp), evaluate=False)
342342
elif exp < 0:
343343
# Reciprocal powers become inverse of the negative power
344344
# for example Pow(expr, -2) becomes Pow(expr * expr, -1)
@@ -502,3 +502,18 @@ def evalrel(func=min, input=None, assumptions=None):
502502
except TypeError:
503503
pass
504504
return rfunc(func, *input)
505+
506+
507+
uneval_mapper = {Add: UnevalAdd, Mul: UnevalMul, Pow: UnevalPow}
508+
509+
510+
def unevaluate(expr):
511+
if q_leaf(expr):
512+
return expr
513+
514+
args = [unevaluate(a) for a in expr.args]
515+
516+
try:
517+
return uneval_mapper[expr.func](*args)
518+
except KeyError:
519+
return reuse_if_untouched(expr, args)

devito/symbolics/unevaluation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sympy
22

3-
__all__ = ['Add', 'Mul', 'Pow']
3+
__all__ = ['UnevalAdd', 'UnevalMul', 'UnevalPow']
44

55

66
class UnevaluableMixin:
@@ -9,15 +9,15 @@ def __new__(cls, *args, evaluate=None, **kwargs):
99
return cls.__base__.__new__(cls, *args, evaluate=False, **kwargs)
1010

1111

12-
class Add(sympy.Add, UnevaluableMixin):
12+
class UnevalAdd(sympy.Add, UnevaluableMixin):
1313
__new__ = UnevaluableMixin.__new__
1414

1515

16-
class Mul(sympy.Mul, UnevaluableMixin):
16+
class UnevalMul(sympy.Mul, UnevaluableMixin):
1717
__new__ = UnevaluableMixin.__new__
1818

1919

20-
class Pow(sympy.Pow, UnevaluableMixin):
20+
class UnevalPow(sympy.Pow, UnevaluableMixin):
2121

2222
def __new__(cls, base, exp, evaluate=None, **kwargs):
2323
if base == 1:

tests/test_pickle.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
PrecomputedSparseTimeFunction, SubDomain)
1313
from devito.ir import Backward, Forward, GuardFactor, GuardBound, GuardBoundNext
1414
from devito.data import LEFT, OWNED
15+
from devito.finite_differences.differentiable import Weights
1516
from devito.finite_differences.tools import direct, transpose, left, right, centered
1617
from devito.mpi.halo_scheme import Halo
1718
from devito.mpi.routines import (MPIStatusObject, MPIMsgEnriched, MPIRequestObject,
1819
MPIRegion)
1920
from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar,
2021
PointerArray, Lock, PThreadArray, SharedData, Timer,
2122
DeviceID, NPThreads, ThreadID, TempFunction, Indirection,
22-
FIndexed)
23+
FIndexed, StencilDimension)
2324
from devito.types.basic import BoundSymbol, AbstractSymbol
2425
from devito.tools import EnrichedTuple
2526
from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer,
@@ -416,6 +417,25 @@ def test_findexed(self, pickle):
416417
assert new_fi.indices == (x+1, y, z-2)
417418
assert new_fi.strides_map == fi.strides_map
418419

420+
def test_weights_to_array(self, pickle):
421+
grid = Grid(shape=(3, 3, 3))
422+
x, y, z = grid.dimensions
423+
h_x = x.spacing
424+
425+
i = StencilDimension('i0', 0, 2)
426+
w = Weights(name='w0', dimensions=i,
427+
initvalue=[1/(h_x**2), 2/(h_x**2), 3/(h_x**2)])
428+
a = Array(name='w0', dimensions=w.dimensions, initvalue=w.initvalue,
429+
scope='stack')
430+
431+
pkl_a = pickle.dumps(a)
432+
new_a = pickle.loads(pkl_a)
433+
434+
# Weights optimizes `initvalue` by turning pows into muls. This test checks
435+
# that the optimization is correctly carried over to the pickled object
436+
# (in practice, the optimized expressions must have been frozen)
437+
assert a.initvalue == new_a.initvalue
438+
419439
def test_symbolics(self, pickle):
420440
a = Symbol('a')
421441

0 commit comments

Comments
 (0)