|
12 | 12 | PrecomputedSparseTimeFunction, SubDomain) |
13 | 13 | from devito.ir import Backward, Forward, GuardFactor, GuardBound, GuardBoundNext |
14 | 14 | from devito.data import LEFT, OWNED |
| 15 | +from devito.finite_differences.differentiable import Weights |
15 | 16 | from devito.finite_differences.tools import direct, transpose, left, right, centered |
16 | 17 | from devito.mpi.halo_scheme import Halo |
17 | 18 | from devito.mpi.routines import (MPIStatusObject, MPIMsgEnriched, MPIRequestObject, |
18 | 19 | MPIRegion) |
19 | 20 | from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar, |
20 | 21 | PointerArray, Lock, PThreadArray, SharedData, Timer, |
21 | 22 | DeviceID, NPThreads, ThreadID, TempFunction, Indirection, |
22 | | - FIndexed) |
| 23 | + FIndexed, StencilDimension) |
23 | 24 | from devito.types.basic import BoundSymbol, AbstractSymbol |
24 | 25 | from devito.tools import EnrichedTuple |
25 | 26 | from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer, |
@@ -416,6 +417,25 @@ def test_findexed(self, pickle): |
416 | 417 | assert new_fi.indices == (x+1, y, z-2) |
417 | 418 | assert new_fi.strides_map == fi.strides_map |
418 | 419 |
|
| 420 | + def test_weights_to_array(self, pickle): |
| 421 | + grid = Grid(shape=(3, 3, 3)) |
| 422 | + x, y, z = grid.dimensions |
| 423 | + h_x = x.spacing |
| 424 | + |
| 425 | + i = StencilDimension('i0', 0, 2) |
| 426 | + w = Weights(name='w0', dimensions=i, |
| 427 | + initvalue=[1/(h_x**2), 2/(h_x**2), 3/(h_x**2)]) |
| 428 | + a = Array(name='w0', dimensions=w.dimensions, initvalue=w.initvalue, |
| 429 | + scope='stack') |
| 430 | + |
| 431 | + pkl_a = pickle.dumps(a) |
| 432 | + new_a = pickle.loads(pkl_a) |
| 433 | + |
| 434 | + # Weights optimizes `initvalue` by turning pows into muls. This test checks |
| 435 | + # that the optimization is correctly carried over to the pickled object |
| 436 | + # (in practice, the optimized expressions must have been frozen) |
| 437 | + assert a.initvalue == new_a.initvalue |
| 438 | + |
419 | 439 | def test_symbolics(self, pickle): |
420 | 440 | a = Symbol('a') |
421 | 441 |
|
|
0 commit comments