2121from devito .petsc .types import (PETScArray , PetscBundle , DM , Mat , CallbackVec , Vec ,
2222 KSP , PC , SNES , PetscInt , StartPtr , PointerIS , PointerDM ,
2323 VecScatter , DMCast , JacobianStruct , SubMatrixStruct ,
24- CallbackDM )
24+ CallbackDM , PetscBool )
2525
2626
2727class CBBuilder :
@@ -37,10 +37,12 @@ def __init__(self, **kwargs):
3737 self .objs = kwargs .get ('objs' )
3838 self .solver_objs = kwargs .get ('solver_objs' )
3939 self .inject_solve = kwargs .get ('inject_solve' )
40+ self .solver_parameters = self .inject_solve .expr .rhs .solver_parameters
4041
4142 self ._efuncs = OrderedDict ()
4243 self ._struct_params = []
4344
45+ self ._options_efunc = None
4446 self ._main_matvec_callback = None
4547 self ._user_struct_callback = None
4648 self ._F_efunc = None
@@ -105,13 +107,44 @@ def target(self):
105107 return self .field_data .target
106108
107109 def _make_core (self ):
110+ self ._petsc_options_callback ()
108111 self ._make_matvec (self .field_data .jacobian )
109112 self ._make_formfunc ()
110113 self ._make_formrhs ()
111114 if self .field_data .initial_guess .exprs :
112115 self ._make_initial_guess ()
113116 self ._make_user_struct_callback ()
114117
118+ def _petsc_options_callback (self ):
119+ objs = self .objs
120+ params = self .solver_parameters
121+ Null = objs ['Null' ]
122+
123+ has_names = ()
124+
125+ # TODO: improve
126+ for k , v in params .items ():
127+ is_set = PetscBool (self .sregistry .make_name (prefix = 'set' ))
128+ has_name = petsc_call ('PetscOptionsHasName' , [
129+ Null , Null , String (k ), Byref (is_set )])
130+ has_names += (has_name ,)
131+
132+ body = CallableBody (
133+ List (body = has_names ),
134+ init = (objs ['begin_user' ],),
135+ retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
136+ )
137+
138+ objs = self .objs
139+ cb = PETScCallable (
140+ self .sregistry .make_name (prefix = 'SetPetscOptions' ),
141+ body ,
142+ retval = objs ['err' ],
143+ parameters = ()
144+ )
145+ self ._options_efunc = cb
146+ self ._efuncs [cb .name ] = cb
147+
115148 def _make_matvec (self , jacobian , prefix = 'MatMult' ):
116149 # Compile `matvecs` into an IET via recursive compilation
117150 matvecs = jacobian .matvecs
@@ -1237,6 +1270,10 @@ def _setup(self):
12371270 'SNESSetOptionsPrefix' , [sobjs ['snes' ], sobjs ['snesprefix' ]]
12381271 ) if sobjs ['options_prefix' ] else None
12391272
1273+ set_options = petsc_call (
1274+ self .cbbuilder ._options_efunc .name , []
1275+ )
1276+
12401277 snes_set_dm = petsc_call ('SNESSetDM' , [sobjs ['snes' ], dmda ])
12411278
12421279 create_matrix = petsc_call ('DMCreateMatrix' , [dmda , Byref (sobjs ['Jac' ])])
@@ -1281,9 +1318,9 @@ def _setup(self):
12811318 solver_params ['ksp_max_it' ]]
12821319 )
12831320
1284- ksp_set_type = petsc_call (
1285- 'KSPSetType' , [sobjs ['ksp' ], solver_mapper [solver_params ['ksp_type' ]]]
1286- )
1321+ # ksp_set_type = petsc_call(
1322+ # 'KSPSetType', [sobjs['ksp'], solver_mapper[solver_params['ksp_type']]]
1323+ # )
12871324
12881325 ksp_get_pc = petsc_call (
12891326 'KSPGetPC' , [sobjs ['ksp' ], Byref (sobjs ['pc' ])]
@@ -1326,6 +1363,7 @@ def _setup(self):
13261363 base_setup = dmda_calls + (
13271364 snes_create ,
13281365 snes_options_prefix ,
1366+ set_options ,
13291367 snes_set_dm ,
13301368 create_matrix ,
13311369 snes_set_jac ,
@@ -1336,7 +1374,6 @@ def _setup(self):
13361374 global_b ,
13371375 snes_get_ksp ,
13381376 ksp_set_tols ,
1339- ksp_set_type ,
13401377 ksp_get_pc ,
13411378 pc_set_type ,
13421379 ksp_set_from_ops ,
0 commit comments