Skip to content

Commit ee76817

Browse files
committed
dsl: Automate the staggered grid handling with petsc i.e constrain the region that is not part of the computational domain so it is excluded from the global solver
1 parent 0885857 commit ee76817

1 file changed

Lines changed: 52 additions & 4 deletions

File tree

devito/petsc/solve.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from devito.types.equation import PetscEq
22
from devito.tools import filter_ordered, as_tuple
3-
from devito.types import Symbol, SteppingDimension, TimeDimension
3+
from devito.types import Symbol, SteppingDimension, TimeDimension, Border
44
from devito.operations.solve import eval_time_derivatives
55
from devito.symbolics import retrieve_functions, retrieve_dimensions
66

@@ -132,7 +132,15 @@ def linear_solve_args(self):
132132
target, exprs = next(iter(self.target_exprs.items()))
133133
exprs = as_tuple(exprs)
134134

135+
136+
# TODO: maybe move into a seprate class/method... or at least clean up
137+
stagger_bc = _stagger_constrain_bc(target)
138+
if stagger_bc is not None:
139+
exprs = exprs + (stagger_bc,)
140+
135141
funcs = get_funcs(exprs)
142+
funcs = tuple(f for f in funcs if f.function is not target.function)
143+
# from IPython import embed; embed()
136144
self.time_mapper = generate_time_mapper(exprs)
137145
arrays = self.generate_arrays(target)
138146

@@ -147,7 +155,7 @@ def linear_solve_args(self):
147155
constrain_exprs = self._get_constrain_exprs(exprs)
148156
constrain_bc = None
149157
if constrain_exprs:
150-
constrain_bc = ConstrainBC(target, constrain_exprs, arrays)
158+
constrain_bc = ConstrainBC(target, constrain_exprs, arrays, self.time_mapper)
151159

152160
field_data = FieldData(
153161
target=target,
@@ -188,11 +196,25 @@ def generate_arrays(self, *targets):
188196
class InjectMixedSolve(InjectSolve):
189197

190198
def linear_solve_args(self):
199+
200+
# TODO: again, clean up, or more into separate methods/classes
201+
# add documentation
202+
augmented = {}
203+
for t, e in self.target_exprs.items():
204+
stagger_bc = _stagger_constrain_bc(t)
205+
te = as_tuple(e)
206+
if stagger_bc is not None:
207+
te = te + (stagger_bc,)
208+
augmented[t] = te
209+
self.target_exprs = augmented
210+
191211
exprs = []
192212
for e in self.target_exprs.values():
193213
exprs.extend(e)
194214

195215
funcs = get_funcs(exprs)
216+
target_set = set(self.target_exprs.keys())
217+
funcs = tuple(f for f in funcs if f.function not in target_set)
196218
self.time_mapper = generate_time_mapper(exprs)
197219

198220
targets = list(self.target_exprs.keys())
@@ -211,7 +233,7 @@ def linear_solve_args(self):
211233
for t in targets:
212234
cexprs = self._get_constrain_exprs(as_tuple(self.target_exprs[t]))
213235
if cexprs:
214-
constrain_bc[t] = ConstrainBC(t, cexprs, arrays)
236+
constrain_bc[t] = ConstrainBC(t, cexprs, arrays, self.time_mapper)
215237
constrain_bc = constrain_bc if constrain_bc else None
216238

217239
all_data = MultipleFieldData(
@@ -225,11 +247,37 @@ def linear_solve_args(self):
225247
return targets[0], funcs, all_data
226248

227249

250+
def _stagger_constrain_bc(target):
251+
"""
252+
"""
253+
if target.staggered.on_node:
254+
return None
255+
grid = target.grid
256+
grid_dim_set = set(grid.dimensions)
257+
staggered_dims = [d for d, s in zip(target.dimensions, target.staggered)
258+
if s != 0 and d in grid_dim_set]
259+
if not staggered_dims:
260+
return None
261+
dims = {d: 'right' for d in staggered_dims}
262+
border = Border(grid, border=1, dims=dims,
263+
name='_stagger_border_%s' % target.name,
264+
corners='nooverlap')
265+
# from IPython import embed; embed()
266+
from devito import Constant
267+
rhs = Constant(name='zero', dtype=target.dtype)
268+
return EssentialBC(target, rhs, subdomain=border, constrain=True)
269+
270+
228271
def get_funcs(exprs):
272+
# funcs = [
273+
# f for e in exprs
274+
# for f in retrieve_functions(eval_time_derivatives(e.lhs - e.rhs))
275+
# ]
229276
funcs = [
230277
f for e in exprs
231-
for f in retrieve_functions(eval_time_derivatives(e.lhs - e.rhs))
278+
for f in retrieve_functions(e.evaluate)
232279
]
280+
# from IPython import embed; embed()
233281
return as_tuple(filter_ordered(funcs))
234282

235283

0 commit comments

Comments
 (0)