Skip to content

Commit 9fde3d8

Browse files
committed
misc: PR comments part 1
1 parent b1b0846 commit 9fde3d8

4 files changed

Lines changed: 51 additions & 42 deletions

File tree

devito/timestepping/superstep.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
from devito.types import Eq, Function, TimeFunction
22

33

4-
def superstep_generator_iterative(field, stencil, k, tn=0):
5-
''' Generate superstep iteratively:
4+
def superstep_generator_iterative(field, stencil, k, nt=0):
5+
"""
6+
Generate superstep iteratively:
67
Aʲ⁺¹ = A·Aʲ
7-
'''
8+
"""
89
# New fields, for vector formulation both current and previous timestep are needed
910
name = field.name
1011
grid = field.grid
1112
u = TimeFunction(name=f'{name}_ss', grid=grid, time_order=2, space_order=2*k)
1213
u_prev = TimeFunction(name=f'{name}_ss_p', grid=grid, time_order=2, space_order=2*k)
1314

14-
superstep_solution_transfer(field, u, u_prev, tn)
15+
superstep_solution_transfer(field, u, u_prev, nt)
1516

1617
# Substitute new fields into stencil
1718
ss_stencil = stencil.subs({field: u, field.backward: u_prev}, postprocess=False)
@@ -40,25 +41,27 @@ def superstep_generator_iterative(field, stencil, k, tn=0):
4041
return u, u_prev, Eq(u.forward, stencil_next), Eq(u_prev.forward, current)
4142

4243

43-
def superstep_generator(field, stencil, k, tn=0):
44-
''' Generate superstep using a binary decomposition:
44+
def superstep_generator(field, stencil, k, nt=0):
45+
"""
46+
Generate superstep using a binary decomposition:
4547
A^k = aⱼ A^2ʲ × ... × a₂ A^2² × a₁ A² × a₀ A
4648
where k = aⱼ·2ʲ + ... + a₂·2² + a₁·2¹ + a₀·2⁰
47-
'''
49+
"""
4850
# New fields, for vector formulation both current and previous timestep are needed
4951
name = field.name
50-
grid = field.grid
5152
# time_order of `field` needs to be 2
52-
u = TimeFunction(name=f'{name}_ss', grid=grid, time_order=1, space_order=2*k)
53-
u_prev = TimeFunction(name=f'{name}_ss_p', grid=grid, time_order=1, space_order=2*k)
53+
u = field._rebuild(name=f'{name}_ss', time_order=1, space_order=2*k)
54+
u_prev = field._rebuild(name=f'{name}_ss', time_order=1, space_order=2*k)
5455

55-
superstep_solution_transfer(field, u, u_prev, tn)
56+
superstep_solution_transfer(field, u, u_prev, nt)
5657

5758
# Substitute new fields into stencil
5859
ss_stencil = stencil.subs({field: u, field.backward: u_prev}, postprocess=False)
5960
ss_stencil = ss_stencil.expand().expand(add=True, nest=True)
6061

61-
# Binary decomposition algorithm
62+
# Binary decomposition algorithm (see docstring):
63+
# Calculate the binary decomposition of the exponent (k) and accumulate the
64+
# resultant operator
6265
current = (ss_stencil, u)
6366
q, r = divmod(k, 2)
6467
accumulate = current if r else (1, 1)
@@ -71,23 +74,26 @@ def superstep_generator(field, stencil, k, tn=0):
7174
return u, u_prev, Eq(u.forward, accumulate[0]), Eq(u_prev.forward, accumulate[1])
7275

7376

74-
def superstep_solution_transfer(old, new, new_p, tn):
75-
''' Transfer state from a previous TimeFunction to a 2 field superstep
77+
def superstep_solution_transfer(old, new, new_p, nt):
78+
"""
79+
Transfer state from a previous TimeFunction to a 2 field superstep
7680
Used after injecting source using standard timestepping.
77-
'''
78-
# 3 should be replaced with `old.time_order + 1` although this needs some thought
79-
idx = tn % 3 if old.save is None else -1
80-
new.data[0, :] = old.data[idx - 1]
81-
new.data[1, :] = old.data[idx]
82-
new_p.data[0, :] = old.data[idx - 2]
83-
new_p.data[1, :] = old.data[idx - 1]
81+
"""
82+
# This method is completely generic for future development, but currently
83+
# only time_order == 2 is implemented!
84+
idx = nt % (old.time_order + 1) if old.save is None else -1
85+
for ii in range(old.time_order + 1):
86+
new.data[ii, :] = old.data[idx - ii - 1]
87+
new_p.data[ii, :] = old.data[idx - ii - 2]
8488

8589

8690
def _combine_superstep(stencil_a, stencil_b, u, u_prev, k):
87-
''' Combine two arbitrary order supersteps
88-
'''
91+
"""
92+
Combine two arbitrary order supersteps
93+
"""
8994
# Placeholder fields for forming the superstep
9095
grid = u.grid
96+
# Can I use a TempFunction here?
9197
a_tmp = Function(name="a_tmp", grid=grid, space_order=2*k)
9298
b_tmp = Function(name="b_tmp", grid=grid, space_order=2*k)
9399

examples/timestepping/acoustic_superstep_2d.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
import matplotlib.pyplot as plt
1212
import numpy as np
13+
from scipy.interpolate import interpn
14+
1315
from devito import (
1416
ConditionalDimension,
1517
Eq,
@@ -21,7 +23,6 @@
2123
solve,
2224
)
2325
from devito.timestepping.superstep import superstep_generator
24-
from scipy.interpolate import interpn
2526

2627

2728
@dataclass
@@ -99,24 +100,24 @@ def acoustic_model(model, step=1, snapshots=1):
99100
pde = (1/velocity**2)*u.dt2 - u.laplace
100101
stencil = Eq(u.forward, solve(pde, u.forward))
101102

102-
tn1 = int(np.ceil((t1 - t0)/model.critical_dt))
103-
dt = (t1 - t0)/tn1
103+
nt1 = int(np.ceil((t1 - t0)/model.critical_dt))
104+
dt = (t1 - t0)/nt1
104105

105106
# Source
106-
t = np.linspace(t0, t1, tn1)
107+
t = np.linspace(t0, t1, nt1)
107108
rick = ricker(t)
108109
source = SparseTimeFunction(
109-
name="ricker", npoint=1, coordinates=[model.source], nt=tn1, grid=grid,
110+
name="ricker", npoint=1, coordinates=[model.source], nt=nt1, grid=grid,
110111
time_order=2, space_order=4
111112
)
112113
source.data[:, 0] = rick
113114
src_term = source.inject(field=u.forward, expr=source*velocity*velocity*dt*dt)
114115

115116
op1 = Operator([stencil] + src_term)
116-
op1(time=tn1 - 1, dt=dt)
117+
op1(time=nt1 - 1, dt=dt)
117118

118119
# Stencil and operator
119-
idx = tn1 % 3
120+
idx = nt1 % 3
120121
if step == 1:
121122
# Non-superstep case
122123
new_u = TimeFunction(name="new_u", grid=grid, time_order=2, space_order=2)
@@ -127,13 +128,13 @@ def acoustic_model(model, step=1, snapshots=1):
127128
new_u.data[1, :] = u.data[idx - 1]
128129
new_u.data[2, :] = u.data[idx]
129130
else:
130-
new_u, new_u_p, *stencil = superstep_generator(u, stencil.rhs, step, tn=tn1)
131+
new_u, new_u_p, *stencil = superstep_generator(u, stencil.rhs, step, nt=nt1)
131132

132-
tn2 = int(np.ceil((t2 - t1)/model.critical_dt))
133-
dt = (t2 - t1)/tn2
133+
nt2 = int(np.ceil((t2 - t1)/model.critical_dt))
134+
dt = (t2 - t1)/nt2
134135

135136
# Snapshot the solution
136-
factor = int(np.ceil(tn2/(snapshots + 1)))
137+
factor = int(np.ceil(nt2/(snapshots + 1)))
137138
t_sub = ConditionalDimension('t_sub', parent=grid.time_dim, factor=factor)
138139
u_save = TimeFunction(
139140
name='usave', grid=grid,

examples/timestepping/superstep_1d.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
'''
44
import matplotlib.pyplot as plt
55
import numpy as np
6+
67
from devito import Eq, Function, Grid, Operator, TimeFunction, solve
78
from devito.timestepping.superstep import superstep_generator
89

@@ -64,12 +65,12 @@ def wave_on_string(step=1):
6465
stencil2,
6566
], opt='noop')
6667

67-
tn = int(np.ceil(t1/critical_dt))
68-
dt = t1/tn
68+
nt = int(np.ceil(t1/critical_dt))
69+
dt = t1/nt
6970

70-
op(time=tn, dt=dt)
71+
op(time=nt, dt=dt)
7172

72-
idx = tn % 3
73+
idx = nt % 3
7374
return newu.data[idx]
7475

7576

examples/timestepping/superstep_2d.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
'''
44
import matplotlib.pyplot as plt
55
import numpy as np
6+
67
from devito import ConditionalDimension, Eq, Function, Grid, Operator, TimeFunction, solve
78
from devito.timestepping.superstep import superstep_generator
89

@@ -60,11 +61,11 @@ def ripple_on_pond(step=1, snapshots=1):
6061
new_u_p.data[0, :] = ic
6162
new_u_p.data[1, :] = ic
6263

63-
tn = int(np.ceil((t1 - t0)/critical_dt))
64-
dt = t1/tn
64+
nt = int(np.ceil((t1 - t0)/critical_dt))
65+
dt = t1/nt
6566

6667
# Snapshot the solution
67-
factor = int(np.ceil(tn/(snapshots + 1)))
68+
factor = int(np.ceil(nt/(snapshots + 1)))
6869
t_sub = ConditionalDimension('t_sub', parent=grid.time_dim, factor=factor)
6970
u_save = TimeFunction(
7071
name='usave', grid=grid,
@@ -103,7 +104,7 @@ def ripple_on_pond(step=1, snapshots=1):
103104
)
104105
idx += 1
105106
if step == 1:
106-
ax.set_title(f't = {(ii*t1)/(m - 1) :0.3f}')
107+
ax.set_title(f't = {(ii*t1)/(m - 1):0.3f}')
107108
if ii != 0:
108109
ax.set_yticklabels([])
109110
if ii % 2 == 1:

0 commit comments

Comments
 (0)