11from devito .types .equation import PetscEq
22from devito .tools import filter_ordered , as_tuple
3- from devito .types import Symbol , SteppingDimension , TimeDimension
3+ from devito .types import Symbol , SteppingDimension , TimeDimension , Border
44from devito .operations .solve import eval_time_derivatives
55from devito .symbolics import retrieve_functions , retrieve_dimensions
66
@@ -132,7 +132,15 @@ def linear_solve_args(self):
132132 target , exprs = next (iter (self .target_exprs .items ()))
133133 exprs = as_tuple (exprs )
134134
135+
136+ # TODO: maybe move into a seprate class/method... or at least clean up
137+ stagger_bc = _stagger_constrain_bc (target )
138+ if stagger_bc is not None :
139+ exprs = exprs + (stagger_bc ,)
140+
135141 funcs = get_funcs (exprs )
142+ funcs = tuple (f for f in funcs if f .function is not target .function )
143+ # from IPython import embed; embed()
136144 self .time_mapper = generate_time_mapper (exprs )
137145 arrays = self .generate_arrays (target )
138146
@@ -147,7 +155,7 @@ def linear_solve_args(self):
147155 constrain_exprs = self ._get_constrain_exprs (exprs )
148156 constrain_bc = None
149157 if constrain_exprs :
150- constrain_bc = ConstrainBC (target , constrain_exprs , arrays )
158+ constrain_bc = ConstrainBC (target , constrain_exprs , arrays , self . time_mapper )
151159
152160 field_data = FieldData (
153161 target = target ,
@@ -188,11 +196,25 @@ def generate_arrays(self, *targets):
188196class InjectMixedSolve (InjectSolve ):
189197
190198 def linear_solve_args (self ):
199+
200+ # TODO: again, clean up, or more into separate methods/classes
201+ # add documentation
202+ augmented = {}
203+ for t , e in self .target_exprs .items ():
204+ stagger_bc = _stagger_constrain_bc (t )
205+ te = as_tuple (e )
206+ if stagger_bc is not None :
207+ te = te + (stagger_bc ,)
208+ augmented [t ] = te
209+ self .target_exprs = augmented
210+
191211 exprs = []
192212 for e in self .target_exprs .values ():
193213 exprs .extend (e )
194214
195215 funcs = get_funcs (exprs )
216+ target_set = set (self .target_exprs .keys ())
217+ funcs = tuple (f for f in funcs if f .function not in target_set )
196218 self .time_mapper = generate_time_mapper (exprs )
197219
198220 targets = list (self .target_exprs .keys ())
@@ -211,7 +233,7 @@ def linear_solve_args(self):
211233 for t in targets :
212234 cexprs = self ._get_constrain_exprs (as_tuple (self .target_exprs [t ]))
213235 if cexprs :
214- constrain_bc [t ] = ConstrainBC (t , cexprs , arrays )
236+ constrain_bc [t ] = ConstrainBC (t , cexprs , arrays , self . time_mapper )
215237 constrain_bc = constrain_bc if constrain_bc else None
216238
217239 all_data = MultipleFieldData (
@@ -225,11 +247,37 @@ def linear_solve_args(self):
225247 return targets [0 ], funcs , all_data
226248
227249
250+ def _stagger_constrain_bc (target ):
251+ """
252+ """
253+ if target .staggered .on_node :
254+ return None
255+ grid = target .grid
256+ grid_dim_set = set (grid .dimensions )
257+ staggered_dims = [d for d , s in zip (target .dimensions , target .staggered )
258+ if s != 0 and d in grid_dim_set ]
259+ if not staggered_dims :
260+ return None
261+ dims = {d : 'right' for d in staggered_dims }
262+ border = Border (grid , border = 1 , dims = dims ,
263+ name = '_stagger_border_%s' % target .name ,
264+ corners = 'nooverlap' )
265+ # from IPython import embed; embed()
266+ from devito import Constant
267+ rhs = Constant (name = 'zero' , dtype = target .dtype )
268+ return EssentialBC (target , rhs , subdomain = border , constrain = True )
269+
270+
228271def get_funcs (exprs ):
272+ # funcs = [
273+ # f for e in exprs
274+ # for f in retrieve_functions(eval_time_derivatives(e.lhs - e.rhs))
275+ # ]
229276 funcs = [
230277 f for e in exprs
231- for f in retrieve_functions (eval_time_derivatives ( e . lhs - e . rhs ) )
278+ for f in retrieve_functions (e . evaluate )
232279 ]
280+ # from IPython import embed; embed()
233281 return as_tuple (filter_ordered (funcs ))
234282
235283
0 commit comments