Skip to content

Commit 16bd5d8

Browse files
committed
examples: Refactor 1D and 2D code into one file
1 parent b2fd861 commit 16bd5d8

3 files changed

Lines changed: 223 additions & 220 deletions

File tree

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
''' Script that demonstrates the functionality of the superstep in 1D and 2D
2+
with an initial condition
3+
In 1D: "Wave on a string"
4+
In 2d: "Ripple on a pond"
5+
'''
6+
from argparse import ArgumentParser
7+
from dataclasses import dataclass
8+
9+
import matplotlib.pyplot as plt
10+
import numpy as np
11+
12+
from devito import ConditionalDimension, Eq, Function, Grid, Operator, TimeFunction, solve
13+
from devito.timestepping.superstep import superstep_generator
14+
15+
16+
@dataclass
17+
class Parameters:
18+
# Spatial
19+
shape: tuple[int]
20+
origin: tuple[float]
21+
extent: tuple[float]
22+
# Time
23+
t0: float
24+
t1: float
25+
critical_dt: float
26+
# Initial Condition
27+
mu: float
28+
sigma_sq: float
29+
lim: float
30+
31+
32+
def gaussian_1d(x, mu=0, sigma_sq=1):
33+
"""
34+
Generate a 1D Gaussian initial condition
35+
"""
36+
return np.exp(-((x - mu)**2)/(2*sigma_sq))/(np.sqrt(2*np.pi*sigma_sq))
37+
38+
39+
def gaussian_2d(xx, yy, mu=0, sigma_sq=1):
40+
"""
41+
Generate a 2D Gaussian initial condition
42+
"""
43+
return np.exp(-((xx - mu)**2 + (yy - mu)**2)/(2*sigma_sq))/(np.sqrt(2*np.pi*sigma_sq))
44+
45+
46+
def simulate_ic(parameters, step=1, snapshots=-1):
47+
p = parameters
48+
d = len(p.shape)
49+
# Construct Grid
50+
grid = Grid(shape=p.shape, extent=p.extent)
51+
52+
# Need to ensure that the velocity function supports the largest superstep stencil
53+
velocity = Function(name="velocity", grid=grid, space_order=(2, step, step))
54+
velocity.data[:] = 1500 if d == 2 else 1
55+
56+
u = TimeFunction(name="u", grid=grid, time_order=2, space_order=2)
57+
58+
pde = (1/velocity**2)*u.dt2 - u.laplace
59+
stencil = Eq(u.forward, solve(pde, u.forward))
60+
61+
# Initial condition
62+
x = np.linspace(p.origin[0], p.extent[0], p.shape[0])
63+
if d == 1:
64+
ic = gaussian_1d(x, mu=p.mu, sigma_sq=p.sigma_sq)
65+
elif d == 2:
66+
y = np.linspace(p.origin[1], p.extent[1], p.shape[1])
67+
xx, yy = np.meshgrid(x, y)
68+
ic = gaussian_2d(xx, yy, mu=p.mu, sigma_sq=p.sigma_sq)
69+
70+
# Stencil and operator
71+
if step == 1:
72+
# Non-superstep case
73+
new_u = u
74+
stencil = [stencil]
75+
new_u.data[0, :] = ic
76+
new_u.data[1, :] = ic
77+
else:
78+
new_u, new_u_p, *stencil = superstep_generator(u, stencil.rhs, step)
79+
80+
new_u.data[0, :] = ic
81+
new_u.data[1, :] = ic
82+
new_u_p.data[0, :] = ic
83+
new_u_p.data[1, :] = ic
84+
85+
nt = int(np.ceil((p.t1 - p.t0)/p.critical_dt))
86+
dt = p.t1/nt
87+
88+
# Snapshot the solution
89+
if snapshots > 0:
90+
factor = int(np.ceil(nt/(snapshots + 1)))
91+
t_sub = ConditionalDimension('t_sub', parent=grid.time_dim, factor=factor)
92+
u_save = TimeFunction(
93+
name='usave', grid=grid,
94+
time_order=0, space_order=2,
95+
save=snapshots//step + 1, time_dim=t_sub
96+
)
97+
else:
98+
u_save = TimeFunction(
99+
name='usave', grid=grid,
100+
time_order=0, space_order=2,
101+
save=1, time_dim=new_u.time_dim
102+
)
103+
save = Eq(u_save, new_u)
104+
105+
op = Operator([*stencil, save], opt='noop')
106+
if d == 1:
107+
op(time=nt - 2, dt=dt)
108+
elif d == 2:
109+
op(dt=dt)
110+
111+
if d == 2 and step == 1:
112+
u_save.data[0, :] = ic
113+
114+
return u_save.data
115+
116+
117+
def plot_1d(k, data, parameters):
118+
p = parameters
119+
fig, ax = plt.subplots(1, 1)
120+
fig.set_size_inches(8, 4)
121+
x = np.linspace(p.origin[0], p.extent[0], *p.shape)
122+
ax.plot(
123+
x, gaussian_1d(x, mu=p.mu, sigma_sq=p.sigma_sq),
124+
color='k', ls='--', label='Initial Condition'
125+
)
126+
127+
for step, d in zip(k, data, strict=True):
128+
label = 'Normal timestepping' if step == 1 else f'Superstep size {step}'
129+
ax.plot(x, d[-1], label=label)
130+
131+
ax.set_xlim(p.origin[0], p.extent[0])
132+
ax.set_ylim(-p.lim, p.lim)
133+
ax.legend()
134+
return fig, ax
135+
136+
137+
def plot_2d(k, data, parameters, snapshots=1):
138+
p = parameters
139+
fig, axes = plt.subplots(len(data), snapshots)
140+
fig.set_size_inches(16, 5)
141+
142+
for step, d, ax_row in zip(k, data, axes, strict=True):
143+
idx = 0
144+
for ii, ax in enumerate(ax_row):
145+
if ii % step == 0:
146+
ax.imshow(
147+
d[idx, :, :].T,
148+
extent=[p.origin[0], p.extent[0], p.extent[1], p.origin[1]],
149+
vmin=-p.lim, vmax=p.lim,
150+
cmap='seismic'
151+
)
152+
idx += 1
153+
if step == 1:
154+
ax.set_title(f't = {(ii*p.t1)/(snapshots - 1):0.3f}')
155+
if ii != 0:
156+
ax.set_yticklabels([])
157+
if ii % 2 == 1:
158+
ax.set_xticklabels([])
159+
else:
160+
ax.remove()
161+
fig.subplots_adjust(
162+
left=0.05,
163+
bottom=0.02,
164+
right=0.99,
165+
top=0.96,
166+
wspace=0.23,
167+
hspace=0.0
168+
)
169+
return fig, ax
170+
171+
172+
if __name__ == '__main__':
173+
parser = ArgumentParser()
174+
parser.add_argument('-d', '--dimension', type=int, default=1, choices=[1, 2])
175+
args = parser.parse_args()
176+
177+
d = args.dimension
178+
parameters = {}
179+
# 1D Simulation parameters
180+
parameters[1] = Parameters(
181+
shape=(501, ),
182+
origin=(0, ),
183+
extent=(1, ),
184+
t0=0,
185+
t1=0.15,
186+
critical_dt=0.0014142,
187+
mu=0.5,
188+
sigma_sq=0.005,
189+
lim=np.ceil(1/np.sqrt(2*np.pi*0.005)),
190+
)
191+
# 2D Simulation parameters
192+
parameters[2] = Parameters(
193+
shape=(101, 101),
194+
origin=(0., 0.),
195+
extent=(1000, 1000), # 1kmx1km
196+
# Time Domain
197+
t0=0,
198+
t1=0.5,
199+
critical_dt=0.0047140,
200+
# Initial Condition
201+
mu=500,
202+
sigma_sq=5000,
203+
lim=1/(2*np.sqrt(2*np.pi*5000))
204+
)
205+
206+
# Supersteps
207+
if d == 1:
208+
k = range(1, 6)
209+
m = -1
210+
elif d == 2:
211+
k = [1, 3, 4]
212+
# Snapshots
213+
m = 13
214+
215+
data = [simulate_ic(parameters[d], step, snapshots=m) for step in k]
216+
217+
if d == 1:
218+
fig, ax = plot_1d(k, data, parameters[d])
219+
elif d == 2:
220+
fig, ax = plot_2d(k, data, parameters[d], m)
221+
222+
fig.savefig(f'{d}d_example.png', dpi=300)
223+
plt.show()

examples/timestepping/superstep_1d.py

Lines changed: 0 additions & 96 deletions
This file was deleted.

0 commit comments

Comments
 (0)