1717 MatShellSetOp , PetscMetaData )
1818from devito .petsc .iet .utils import (petsc_call , petsc_struct , zero_vector ,
1919 dereference_funcs , residual_bundle )
20- from devito .petsc .utils import solver_mapper
2120from devito .petsc .types import (PETScArray , PetscBundle , DM , Mat , CallbackVec , Vec ,
2221 KSP , PC , SNES , PetscInt , StartPtr , PointerIS , PointerDM ,
2322 VecScatter , DMCast , JacobianStruct , SubMatrixStruct ,
24- CallbackDM , PetscBool )
23+ CallbackDM )
2524from devito .petsc .types .macros import petsc_func_begin_user , Null
2625
2726
@@ -119,14 +118,14 @@ def _make_core(self):
119118 def _petsc_options_callback (self ):
120119 objs = self .objs
121120 params = self .solver_parameters
122- options_prefix = self .inject_solve .expr .rhs .options_prefix
121+ prefix = self .inject_solve .expr .rhs .formatted_prefix
123122
124- body = []
125-
126- # from IPython import embed; embed()
127- # TODO: improve
128- for k , v in params .items ():
129- body . append ( petsc_call ( 'SetPetscOption' , [ String ( "-" + options_prefix + k ), String ( v )]))
123+ body = [
124+ petsc_call (
125+ 'SetPetscOption' , [ String ( f"- { prefix } { k } " ), String ( str ( v ))]
126+ )
127+ for k , v in params .items ()
128+ ]
130129
131130 body = CallableBody (
132131 List (body = body ),
@@ -695,6 +694,7 @@ def _make_core(self):
695694 for sm in self .field_data .jacobian .nonzero_submatrices :
696695 self ._make_matvec (sm , prefix = f'{ sm .name } _MatMult' )
697696
697+ self ._petsc_options_callback ()
698698 self ._make_whole_matvec ()
699699 self ._make_whole_formfunc ()
700700 self ._make_user_struct_callback ()
@@ -1089,7 +1089,7 @@ def _build(self):
10891089 targets = self .field_data .targets
10901090
10911091 snes_name = sreg .make_name (prefix = 'snes' )
1092- options_prefix = self .inject_solve .expr .rhs .options_prefix
1092+ formatted_prefix = self .inject_solve .expr .rhs .formatted_prefix
10931093
10941094 base_dict = {
10951095 'Jac' : Mat (sreg .make_name (prefix = 'J' )),
@@ -1103,9 +1103,9 @@ def _build(self):
11031103 'localsize' : PetscInt (sreg .make_name (prefix = 'localsize' )),
11041104 'dmda' : DM (sreg .make_name (prefix = 'da' ), dofs = len (targets )),
11051105 'callbackdm' : CallbackDM (sreg .make_name (prefix = 'dm' )),
1106- 'snesprefix' : String (options_prefix or '' ),
1107- 'options_prefix' : options_prefix ,
1106+ 'snes_prefix' : String (formatted_prefix ),
11081107 }
1108+
11091109 base_dict ['comm' ] = self .comm
11101110 self ._target_dependent (base_dict )
11111111 return self ._extend_build (base_dict )
@@ -1244,6 +1244,7 @@ def __init__(self, **kwargs):
12441244 self .solver_objs = kwargs .get ('solver_objs' )
12451245 self .cbbuilder = kwargs .get ('cbbuilder' )
12461246 self .field_data = self .inject_solve .expr .rhs .field_data
1247+ self .formatted_prefix = self .inject_solve .expr .rhs .formatted_prefix
12471248 self .calls = self ._setup ()
12481249
12491250 @property
@@ -1255,18 +1256,14 @@ def snes_ctx(self):
12551256 return VOID (self .solver_objs ['dmda' ], stars = '*' )
12561257
12571258 def _setup (self ):
1258- objs = self .objs
12591259 sobjs = self .solver_objs
1260-
12611260 dmda = sobjs ['dmda' ]
12621261
1263- # solver_params = self.inject_solve.expr.rhs.solver_parameters
1264-
12651262 snes_create = petsc_call ('SNESCreate' , [sobjs ['comm' ], Byref (sobjs ['snes' ])])
12661263
12671264 snes_options_prefix = petsc_call (
1268- 'SNESSetOptionsPrefix' , [sobjs ['snes' ], sobjs ['snesprefix ' ]]
1269- ) if sobjs [ 'options_prefix' ] else None
1265+ 'SNESSetOptionsPrefix' , [sobjs ['snes' ], sobjs ['snes_prefix ' ]]
1266+ ) if self . formatted_prefix else None
12701267
12711268 set_options = petsc_call (
12721269 self .cbbuilder ._options_efunc .name , []
@@ -1276,9 +1273,6 @@ def _setup(self):
12761273
12771274 create_matrix = petsc_call ('DMCreateMatrix' , [dmda , Byref (sobjs ['Jac' ])])
12781275
1279- # NOTE: Assuming all solves are linear for now
1280- snes_set_type = petsc_call ('SNESSetType' , [sobjs ['snes' ], 'SNESKSPONLY' ])
1281-
12821276 snes_set_jac = petsc_call (
12831277 'SNESSetJacobian' , [sobjs ['snes' ], sobjs ['Jac' ],
12841278 sobjs ['Jac' ], 'MatMFFDComputeJacobian' , Null ]
@@ -1295,6 +1289,7 @@ def _setup(self):
12951289 local_size = math .prod (
12961290 v for v , dim in zip (target .shape_allocated , target .dimensions ) if dim .is_Space
12971291 )
1292+ # TODO: Check, maybe this should be VecCreateSeqWithArray
12981293 local_x = petsc_call ('VecCreateMPIWithArray' ,
12991294 [sobjs ['comm' ], 1 , local_size , 'PETSC_DECIDE' ,
13001295 field_from_ptr , Byref (sobjs ['xlocal' ])])
@@ -1310,26 +1305,6 @@ def _setup(self):
13101305 snes_get_ksp = petsc_call ('SNESGetKSP' ,
13111306 [sobjs ['snes' ], Byref (sobjs ['ksp' ])])
13121307
1313- # ksp_set_tols = petsc_call(
1314- # 'KSPSetTolerances', [sobjs['ksp'], solver_params['ksp_rtol'],
1315- # solver_params['ksp_atol'], solver_params['ksp_divtol'],
1316- # solver_params['ksp_max_it']]
1317- # )
1318-
1319- # ksp_set_type = petsc_call(
1320- # 'KSPSetType', [sobjs['ksp'], solver_mapper[solver_params['ksp_type']]]
1321- # )
1322-
1323- # TODO: can drop this
1324- ksp_get_pc = petsc_call (
1325- 'KSPGetPC' , [sobjs ['ksp' ], Byref (sobjs ['pc' ])]
1326- )
1327-
1328- # Even though the default will be jacobi, set to PCNONE for now
1329- pc_set_type = petsc_call ('PCSetType' , [sobjs ['pc' ], 'PCNONE' ])
1330-
1331- ksp_set_from_ops = petsc_call ('KSPSetFromOptions' , [sobjs ['ksp' ]])
1332-
13331308 matvec = self .cbbuilder .main_matvec_callback
13341309 matvec_operation = petsc_call (
13351310 'MatShellSetOperation' ,
@@ -1366,16 +1341,11 @@ def _setup(self):
13661341 snes_set_dm ,
13671342 create_matrix ,
13681343 snes_set_jac ,
1369- snes_set_type ,
13701344 global_x ,
13711345 local_x ,
13721346 get_local_size ,
13731347 global_b ,
13741348 snes_get_ksp ,
1375- # ksp_set_tols,
1376- ksp_get_pc ,
1377- pc_set_type ,
1378- ksp_set_from_ops ,
13791349 matvec_operation ,
13801350 formfunc_operation ,
13811351 snes_set_options ,
@@ -1400,7 +1370,6 @@ def _create_dmda_calls(self, dmda):
14001370 return dmda_create , dm_setup , dm_mat_type
14011371
14021372 def _create_dmda (self , dmda ):
1403- objs = self .objs
14041373 sobjs = self .solver_objs
14051374 grid = self .field_data .grid
14061375 nspace_dims = len (grid .dimensions )
@@ -1445,23 +1414,22 @@ def _setup(self):
14451414 # TODO: minimise code duplication with superclass
14461415 objs = self .objs
14471416 sobjs = self .solver_objs
1448-
14491417 dmda = sobjs ['dmda' ]
1450- solver_params = self .inject_solve .expr .rhs .solver_parameters
14511418
14521419 snes_create = petsc_call ('SNESCreate' , [sobjs ['comm' ], Byref (sobjs ['snes' ])])
14531420
14541421 snes_options_prefix = petsc_call (
1455- 'SNESSetOptionsPrefix' , [sobjs ['snes' ], sobjs ['snesprefix' ]]
1456- ) if sobjs ['options_prefix' ] else None
1422+ 'SNESSetOptionsPrefix' , [sobjs ['snes' ], sobjs ['snes_prefix' ]]
1423+ ) if self .formatted_prefix else None
1424+
1425+ set_options = petsc_call (
1426+ self .cbbuilder ._options_efunc .name , []
1427+ )
14571428
14581429 snes_set_dm = petsc_call ('SNESSetDM' , [sobjs ['snes' ], dmda ])
14591430
14601431 create_matrix = petsc_call ('DMCreateMatrix' , [dmda , Byref (sobjs ['Jac' ])])
14611432
1462- # NOTE: Assuming all solves are linear for now
1463- snes_set_type = petsc_call ('SNESSetType' , [sobjs ['snes' ], 'SNESKSPONLY' ])
1464-
14651433 snes_set_jac = petsc_call (
14661434 'SNESSetJacobian' , [sobjs ['snes' ], sobjs ['Jac' ],
14671435 sobjs ['Jac' ], 'MatMFFDComputeJacobian' , Null ]
@@ -1478,25 +1446,6 @@ def _setup(self):
14781446 snes_get_ksp = petsc_call ('SNESGetKSP' ,
14791447 [sobjs ['snes' ], Byref (sobjs ['ksp' ])])
14801448
1481- ksp_set_tols = petsc_call (
1482- 'KSPSetTolerances' , [sobjs ['ksp' ], solver_params ['ksp_rtol' ],
1483- solver_params ['ksp_atol' ], solver_params ['ksp_divtol' ],
1484- solver_params ['ksp_max_it' ]]
1485- )
1486-
1487- ksp_set_type = petsc_call (
1488- 'KSPSetType' , [sobjs ['ksp' ], solver_mapper [solver_params ['ksp_type' ]]]
1489- )
1490-
1491- ksp_get_pc = petsc_call (
1492- 'KSPGetPC' , [sobjs ['ksp' ], Byref (sobjs ['pc' ])]
1493- )
1494-
1495- # Even though the default will be jacobi, set to PCNONE for now
1496- pc_set_type = petsc_call ('PCSetType' , [sobjs ['pc' ], 'PCNONE' ])
1497-
1498- ksp_set_from_ops = petsc_call ('KSPSetFromOptions' , [sobjs ['ksp' ]])
1499-
15001449 matvec = self .cbbuilder .main_matvec_callback
15011450 matvec_operation = petsc_call (
15021451 'MatShellSetOperation' ,
@@ -1569,19 +1518,14 @@ def _setup(self):
15691518 coupled_setup = dmda_calls + (
15701519 snes_create ,
15711520 snes_options_prefix ,
1521+ set_options ,
15721522 snes_set_dm ,
15731523 create_matrix ,
15741524 snes_set_jac ,
1575- snes_set_type ,
15761525 global_x ,
15771526 local_x ,
15781527 get_local_size ,
15791528 snes_get_ksp ,
1580- ksp_set_tols ,
1581- ksp_set_type ,
1582- ksp_get_pc ,
1583- pc_set_type ,
1584- ksp_set_from_ops ,
15851529 matvec_operation ,
15861530 formfunc_operation ,
15871531 snes_set_options ,
@@ -1678,7 +1622,6 @@ def _execute_solve(self):
16781622 Assigns the required time iterators to the struct and executes
16791623 the necessary calls to execute the SNES solver.
16801624 """
1681- objs = self .objs
16821625 sobjs = self .solver_objs
16831626 xglob = sobjs ['xglobal' ]
16841627
@@ -1794,10 +1737,10 @@ class TimeDependent(NonTimeDependent):
17941737 for each `SNESSolve` at every time step, don't require the time loop, but
17951738 may still need access to data from other time steps.
17961739 - All `Function` objects are passed through the initial lowering via the
1797- `LinearSolveExpr ` object, ensuring the correct time loop is generated
1740+ `SolveExpr ` object, ensuring the correct time loop is generated
17981741 in the main kernel.
17991742 - Another mapper is created based on the modulo dimensions
1800- generated by the `LinearSolveExpr ` object in the main kernel
1743+ generated by the `SolveExpr ` object in the main kernel
18011744 (e.g., {time: time, t: t0, t + 1: t1}).
18021745 - These two mappers are used to generate a final mapper `symb_to_moddim`
18031746 (e.g. {tau0: t0, tau1: t1}) which is used at the IET level to
0 commit comments