Skip to content

Commit 7051473

Browse files
committed
dsl: Add superstep generators to the dsl
1 parent dd1e6c7 commit 7051473

1 file changed

Lines changed: 87 additions & 0 deletions

File tree

devito/timestepping/superstep.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from devito.types import Eq, Function, TimeFunction
2+
3+
4+
def superstep_generator_iterative(field, stencil, k):
5+
''' Generate superstep iteratively:
6+
A^j+1 = A·A^j
7+
'''
8+
# New fields, for vector formulation both current and previous timestep are needed
9+
name = field.name
10+
grid = field.grid
11+
u = TimeFunction(name=f'{name}_ss', grid=grid, time_order=2, space_order=2*k)
12+
u_prev = TimeFunction(name=f'{name}_ss_p', grid=grid, time_order=2, space_order=2*k)
13+
14+
# Substitute new fields into stencil
15+
ss_stencil = stencil.subs({field: u, field.backward: u_prev}, postprocess=False)
16+
ss_stencil = ss_stencil.expand().expand(add=True, nest=True)
17+
current = ss_stencil
18+
19+
# Placeholder fields for forming the superstep
20+
a_tmp = Function(name="a_tmp", grid=grid, space_order=2*k)
21+
b_tmp = Function(name="b_tmp", grid=grid, space_order=2*k)
22+
23+
if k >= 2:
24+
for _ in range(k - 2):
25+
current = current.subs(
26+
{u: a_tmp, u_prev: b_tmp}, postprocess=False).subs(
27+
{a_tmp: ss_stencil, b_tmp: u}, postprocess=False
28+
)
29+
current = current.expand().expand(add=True, nest=True)
30+
else:
31+
current = u
32+
33+
stencil_next = current.subs(
34+
{u: a_tmp, u_prev: b_tmp}, postprocess=False).subs(
35+
{a_tmp: ss_stencil, b_tmp: u}, postprocess=False
36+
)
37+
stencil_next = stencil_next.expand().expand(add=True, nest=True)
38+
return u, u_prev, Eq(u.forward, stencil_next), Eq(u_prev.forward, current)
39+
40+
41+
def superstep_generator(field, stencil, k):
42+
''' Generate superstep using a binary decomposition:
43+
A^k = a_j A^2^j + ... + a_2 A^2^2 + a_1 A² + a_0 A
44+
'''
45+
# New fields, for vector formulation both current and previous timestep are needed
46+
name = field.name
47+
grid = field.grid
48+
u = TimeFunction(name=f'{name}_ss', grid=grid, time_order=2, space_order=2*k)
49+
u_prev = TimeFunction(name=f'{name}_ss_p', grid=grid, time_order=2, space_order=2*k)
50+
51+
# Substitute new fields into stencil
52+
ss_stencil = stencil.subs({field: u, field.backward: u_prev}, postprocess=False)
53+
ss_stencil = ss_stencil.expand().expand(add=True, nest=True)
54+
55+
# Binary decomposition algorithm
56+
current = (ss_stencil, u)
57+
q, r = divmod(k, 2)
58+
accumulate = current if r else (1, 1)
59+
while q:
60+
q, r = divmod(q, 2)
61+
current = combine_superstep(current, current, u, u_prev, k)
62+
if r:
63+
accumulate = combine_superstep(accumulate, current, u, u_prev, k)
64+
65+
return u, u_prev, Eq(u.forward, accumulate[0]), Eq(u_prev.forward, accumulate[1])
66+
67+
def combine_superstep(stencil_a, stencil_b, u, u_prev, k):
68+
''' Combine two arbitrary order supersteps
69+
'''
70+
# Placeholder fields for forming the superstep
71+
grid = u.grid
72+
a_tmp = Function(name="a_tmp", grid=grid, space_order=2*k)
73+
b_tmp = Function(name="b_tmp", grid=grid, space_order=2*k)
74+
75+
new = []
76+
if stencil_a == (1, 1):
77+
new = stencil_b
78+
else:
79+
for stencil in stencil_a:
80+
new_stencil = stencil.subs({u: a_tmp, u_prev: b_tmp}, postprocess=False)
81+
new_stencil = new_stencil.subs(
82+
{a_tmp: stencil_b[0], b_tmp: stencil_b[1]}, postprocess=False
83+
)
84+
new_stencil = new_stencil.expand().expand(add=True, nest=True)
85+
new.append(new_stencil)
86+
87+
return new

0 commit comments

Comments
 (0)