Skip to content

Commit 9f68bb2

Browse files
committed
compiler: Progress with petscsection constrain bc callback
1 parent 0b6b282 commit 9f68bb2

9 files changed

Lines changed: 203 additions & 40 deletions

File tree

devito/operator/operator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ def _lower_exprs(cls, expressions, **kwargs):
371371

372372
# rename etc
373373
expressions = lower_exprs_petsc(expressions, **kwargs)
374+
# from IPython import embed; embed()
374375

375376
processed = [LoweredEq(i) for i in expressions]
376377

devito/petsc/equations.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,30 @@
11
from sympy import Eq
22
from devito.symbolics import retrieve_indexed, retrieve_dimensions
3-
from devito.petsc import EssentialBC
3+
from devito.petsc.types.equation import ConstrainEssentialBC
44
from devito.types.dimension import CustomBoundSubDimension
55
from devito import Min, Max
66

77

88
def lower_exprs_petsc(expressions, **kwargs):
9-
mapper = {}
9+
# Constrain EssentialBCs using PetscSection if specified to do so
10+
expressions = constrain_essential_bcs(expressions, **kwargs)
11+
12+
return expressions
13+
1014

11-
additional_exprs = []
15+
16+
def constrain_essential_bcs(expressions, **kwargs):
17+
"""TODO: improve docs ..Modify the subdims used in ConstrainEssentialBC equations ... to locally
18+
constrain nodes (including non owned halo nodes) ....."""
19+
20+
mapper = {}
21+
new_exprs = []
1222

1323
# build mapper
1424
for e in expressions:
15-
if not isinstance(e, EssentialBC):
25+
if not isinstance(e, ConstrainEssentialBC):
1626
continue
27+
1728
indexeds = retrieve_indexed(e)
1829
dims = retrieve_dimensions([i for j in indexeds for i in j.indices], mode='unique')
1930

@@ -30,8 +41,6 @@ def lower_exprs_petsc(expressions, **kwargs):
3041

3142
from devito.petsc.types.dimension import SubDimMax, SubDimMin
3243

33-
34-
3544
# TODO: change name..
3645

3746
# in theory this class shoulod just take in d
@@ -51,20 +60,13 @@ def lower_exprs_petsc(expressions, **kwargs):
5160
)
5261
mapper[d] = new_dim
5362

54-
# from IPython import embed; embed()
55-
5663
# build new expressions
5764
for e in expressions:
58-
if not isinstance(e, EssentialBC):
59-
continue
60-
61-
# build new expression
62-
new_e = e.subs(mapper)
63-
64-
additional_exprs.append(new_e)
65-
66-
# return expressions + additional_exprs
67-
return expressions + additional_exprs
68-
69-
70-
65+
if isinstance(e, ConstrainEssentialBC):
66+
new_e = e.subs(mapper)
67+
new_exprs.append(new_e)
68+
69+
else:
70+
new_exprs.append(e)
71+
72+
return new_exprs

devito/petsc/iet/builder.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
from functools import cached_property
23

34
from devito.ir.iet import DummyExpr, BlankLine
45
from devito.symbolics import (Byref, FieldFromPointer, VOID,
@@ -22,7 +23,11 @@ def __init__(self, **kwargs):
2223
self.callback_builder = kwargs.get('callback_builder')
2324
self.field_data = self.inject_solve.expr.rhs.field_data
2425
self.formatted_prefix = self.inject_solve.expr.rhs.formatted_prefix
25-
self.calls = self._setup()
26+
# self.calls = self._setup()
27+
28+
@cached_property
29+
def calls(self):
30+
return self._setup()
2631

2732
@property
2833
def snes_ctx(self):
@@ -142,9 +147,11 @@ def _extend_setup(self):
142147

143148
def _create_dmda_calls(self, dmda):
144149
dmda_create = self._create_dmda(dmda)
150+
# TODO: probs need to set the dm options prefix the same as snes?
151+
dm_set_from_opts = petsc_call('DMSetFromOptions', [dmda])
145152
dm_setup = petsc_call('DMSetUp', [dmda])
146153
dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL'])
147-
return dmda_create, dm_setup, dm_mat_type
154+
return dmda_create, dm_set_from_opts, dm_setup, dm_mat_type
148155

149156
def _create_dmda(self, dmda):
150157
sobjs = self.solver_objs
@@ -184,7 +191,7 @@ def _create_dmda(self, dmda):
184191
dmda = petsc_call(f'DMDACreate{nspace_dims}d', args)
185192

186193
return dmda
187-
194+
188195

189196
class CoupledBuilder(BuilderBase):
190197
def _setup(self):
@@ -332,6 +339,35 @@ def _setup(self):
332339
create_submats) + \
333340
tuple(deref_dms) + tuple(xglobals) + tuple(xlocals) + (BlankLine,)
334341
return coupled_setup
342+
343+
344+
class ConstrainedBCMixin:
345+
"""
346+
"""
347+
def _create_dmda_calls(self, dmda):
348+
# TODO: CLEAN UP
349+
dmda_create = self._create_dmda(dmda)
350+
# TODO: probs need to set the dm options prefix the same as snes?
351+
# don't hardcode this probs? - the dm needs to be specific to the solver as well
352+
da_create_section = petsc_call('PetscOptionsSetValue', [Null, '-da_use_section', Null])
353+
dm_set_from_opts = petsc_call('DMSetFromOptions', [dmda])
354+
dm_setup = petsc_call('DMSetUp', [dmda])
355+
dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL'])
356+
357+
set_constraints = petsc_call(
358+
self.callback_builder._constrain_bc_efunc.name, []
359+
)
360+
# OBVS CLEANUP
361+
return dmda_create, da_create_section, dm_set_from_opts, dm_setup, dm_mat_type, set_constraints
362+
363+
364+
class ConstrainedBCBuilder(ConstrainedBCMixin, BuilderBase):
365+
pass
366+
367+
368+
# TODO: Implement this properly
369+
class CoupledConstrainedBCBuilder(ConstrainedBCMixin, CoupledBuilder):
370+
pass
335371

336372

337373
def petsc_call_mpi(specific_call, call_args):

devito/petsc/iet/callbacks.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,18 @@ def __init__(self, **kwargs):
3737
self._efuncs = OrderedDict()
3838
self._struct_params = []
3939

40+
# TODO: use either efunc or callback lingo here
4041
self._set_options_efunc = None
4142
self._clear_options_efunc = None
4243
self._main_matvec_callback = None
4344
self._user_struct_callback = None
4445
self._F_efunc = None
4546
self._b_efunc = None
47+
self._constrain_bc_efunc = None
4648

4749
self._J_efuncs = []
48-
self._initial_guesses = []
50+
# TODO: isn't there only ever one of these per solver so why is it a list?
51+
self._initial_guess_efuncs = []
4952

5053
self._make_core()
5154
self._efuncs = self._uxreplace_efuncs()
@@ -82,9 +85,10 @@ def J_efuncs(self):
8285
"""
8386
return self._J_efuncs
8487

88+
# TODO: do i really need a property for this - probs not?
8589
@property
86-
def initial_guesses(self):
87-
return self._initial_guesses
90+
def initial_guess_efuncs(self):
91+
return self._initial_guess_efuncs
8892

8993
@property
9094
def user_struct_callback(self):
@@ -112,11 +116,18 @@ def target(self):
112116

113117
def _make_core(self):
114118
self._make_options_callback()
119+
# Make the mat-vec callback to form the matfree Jacobian
115120
self._make_matvec(self.field_data.jacobian)
121+
# Make the residual callback
116122
self._make_formfunc()
123+
# Make the RHS callback
117124
self._make_formrhs()
125+
# Make the initial guess callback
118126
if self.field_data.initial_guess.exprs:
119127
self._make_initial_guess()
128+
# Make the callback used to constrain boundary nodes
129+
if self.field_data.constrain_bc.exprs:
130+
self._make_constrain_bc()
120131
self._make_user_struct_callback()
121132

122133
def _make_petsc_callable(self, prefix, body, parameters=()):
@@ -578,7 +589,7 @@ def _make_initial_guess(self):
578589
cb = self._make_petsc_callable(
579590
'FormInitialGuess', body, parameters=(sobjs['callbackdm'], objs['xloc'])
580591
)
581-
self._initial_guesses.append(cb)
592+
self._initial_guess_efuncs.append(cb)
582593
self._efuncs[cb.name] = cb
583594

584595
def _create_initial_guess_body(self, body):
@@ -636,6 +647,29 @@ def _create_initial_guess_body(self, body):
636647
i in fields if not isinstance(i.function, AbstractFunction)}
637648

638649
return Uxreplace(subs).visit(body)
650+
651+
def _make_constrain_bc(self):
652+
exprs = self.field_data.constrain_bc.exprs
653+
sobjs = self.solver_objs
654+
objs = self.objs
655+
656+
# Compile constrain `eqns` into an IET via recursive compilation
657+
irs, _ = self.rcompile(
658+
exprs, options={'mpi': False}, sregistry=self.sregistry,
659+
concretize_mapper=self.concretize_mapper
660+
)
661+
# from IPython import embed; embed()
662+
body = self._create_constrain_bc_body(
663+
List(body=irs.uiet.body)
664+
)
665+
cb = self._make_petsc_callable(
666+
'ConstrainBCs', body, parameters=(sobjs['callbackdm'],)
667+
)
668+
self._constrain_bc_efunc = cb
669+
self._efuncs[cb.name] = cb
670+
671+
def _create_constrain_bc_body(self, body):
672+
return body
639673

640674
def _make_user_struct_callback(self):
641675
"""

devito/petsc/iet/passes.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
get_user_struct_fields
2424
)
2525
from devito.petsc.iet.type_builder import BaseTypeBuilder, CoupledTypeBuilder, objs
26-
from devito.petsc.iet.builder import BuilderBase, CoupledBuilder, make_core_petsc_calls
26+
from devito.petsc.iet.builder import BuilderBase, CoupledBuilder, ConstrainedBCBuilder, make_core_petsc_calls
2727
from devito.petsc.iet.solve import Solve, CoupledSolve
2828
from devito.petsc.iet.time_dependence import TimeDependent, TimeIndependent
2929
from devito.petsc.iet.logging import PetscLogger
@@ -248,6 +248,7 @@ def __init__(self, inject_solve, iters, comm, section_mapper, **kwargs):
248248
self.get_info = inject_solve.expr.rhs.get_info
249249
self.kwargs = kwargs
250250
self.coupled = isinstance(inject_solve.expr.rhs.field_data, MultipleFieldData)
251+
self.constrain_bc = inject_solve.expr.rhs.field_data.constrain_bc
251252
self.common_kwargs = {
252253
'inject_solve': self.inject_solve,
253254
'objs': self.objs,
@@ -280,10 +281,22 @@ def callback_builder(self):
280281
return CoupledCallbackBuilder(**self.common_kwargs) \
281282
if self.coupled else BaseCallbackBuilder(**self.common_kwargs)
282283

284+
# @cached_property
285+
# def builder(self):
286+
# return CoupledBuilder(**self.common_kwargs) \
287+
# if self.coupled else BuilderBase(**self.common_kwargs)
288+
283289
@cached_property
284290
def builder(self):
285-
return CoupledBuilder(**self.common_kwargs) \
286-
if self.coupled else BuilderBase(**self.common_kwargs)
291+
if self.coupled and self.constrain_bc:
292+
# TODO: implement CoupledConstrainedBCBuilder
293+
return CoupledBuilder(**self.common_kwargs)
294+
elif self.coupled:
295+
return CoupledBuilder(**self.common_kwargs)
296+
elif self.constrain_bc:
297+
return ConstrainedBCBuilder(**self.common_kwargs)
298+
else:
299+
return BuilderBase(**self.common_kwargs)
287300

288301
@cached_property
289302
def solve(self):

devito/petsc/iet/solve.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def _execute_solve(self):
3737

3838
vec_place_array = self.time_dependence.place_array(target)
3939

40-
if self.callback_builder.initial_guesses:
41-
initguess = self.callback_builder.initial_guesses[0]
40+
if self.callback_builder.initial_guess_efuncs:
41+
initguess = self.callback_builder.initial_guess_efuncs[0]
4242
initguess_call = petsc_call(initguess.name, [dmda, sobjs['xlocal']])
4343
else:
4444
initguess_call = None

devito/petsc/solve.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from devito.petsc.types import (
88
LinearSolverMetaData, PETScArray, DMDALocalInfo, FieldData, MultipleFieldData,
9-
Jacobian, Residual, MixedResidual, MixedJacobian, InitialGuess
9+
Jacobian, Residual, MixedResidual, MixedJacobian, InitialGuess, ConstrainBC
1010
)
1111
from devito.petsc.types.equation import EssentialBC
1212
from devito.petsc.solver_parameters import (
@@ -18,7 +18,7 @@
1818

1919

2020
def petscsolve(target_exprs, target=None, solver_parameters=None,
21-
options_prefix=None, get_info=[]):
21+
options_prefix=None, get_info=[], constrain_bcs=False):
2222
"""
2323
Returns a symbolic expression representing a linear PETSc solver,
2424
enriched with all the necessary metadata for execution within an `Operator`.
@@ -78,6 +78,12 @@ def petscsolve(target_exprs, target=None, solver_parameters=None,
7878
- ['kspgetiterationnumber', 'kspgettolerances', 'kspgetconvergedreason',
7979
'kspgettype', 'kspgetnormtype', 'snesgetiterationnumber']
8080
81+
constrain_bcs : bool, optional
82+
If `True`, essential boundary conditions specifed by `EssentialBC` equations
83+
are constrained through a `PetscSection`. As a result, the corresponding degrees
84+
of freedom are excluded from the global solver and are not imposed using
85+
trivial equations.
86+
8187
Returns
8288
-------
8389
Eq:
@@ -86,22 +92,24 @@ def petscsolve(target_exprs, target=None, solver_parameters=None,
8692
"""
8793
if target is not None:
8894
return InjectSolve(solver_parameters, {target: target_exprs},
89-
options_prefix, get_info).build_expr()
95+
options_prefix, get_info, constrain_bcs).build_expr()
9096
else:
97+
# TODO: extend mixed case to support constrain_bcs
9198
return InjectMixedSolve(solver_parameters, target_exprs,
9299
options_prefix, get_info).build_expr()
93100

94101

95102
class InjectSolve:
96103
def __init__(self, solver_parameters=None, target_exprs=None, options_prefix=None,
97-
get_info=[]):
104+
get_info=[], constrain_bcs=False):
98105
self.solver_parameters = linear_solver_parameters(solver_parameters)
99106
self.time_mapper = None
100107
self.target_exprs = target_exprs
101108
# The original options prefix provided by the user
102109
self.user_prefix = options_prefix
103110
self.formatted_prefix = format_options_prefix(options_prefix)
104111
self.get_info = [f.lower() for f in get_info]
112+
self.constrain_bcs = constrain_bcs
105113

106114
def build_expr(self):
107115
target, funcs, field_data = self.linear_solve_args()
@@ -129,16 +137,26 @@ def linear_solve_args(self):
129137

130138
exprs = sorted(exprs, key=lambda e: not isinstance(e, EssentialBC))
131139

140+
# TODO: rethink about how essential bcs need to be treated if constrain_bcs is enabled
141+
# likely don't need various bits of functionality inside these classes if constrain_bcs is enabled
132142
jacobian = Jacobian(target, exprs, arrays, self.time_mapper)
133143
residual = Residual(target, exprs, arrays, self.time_mapper, jacobian.scdiag)
134144
initial_guess = InitialGuess(target, exprs, arrays, self.time_mapper)
135145

146+
# TODO: extend this to mixed case
147+
# TODO: clean up
148+
if self.constrain_bcs:
149+
constrain_bc = ConstrainBC(target, exprs, arrays)
150+
else:
151+
constrain_bc = None
152+
136153
field_data = FieldData(
137154
target=target,
138155
jacobian=jacobian,
139156
residual=residual,
140157
initial_guess=initial_guess,
141-
arrays=arrays
158+
arrays=arrays,
159+
constrain_bc=constrain_bc
142160
)
143161

144162
return target, funcs, field_data

0 commit comments

Comments
 (0)