2222 KSP , PC , SNES , PetscInt , StartPtr , PointerIS , PointerDM ,
2323 VecScatter , DMCast , JacobianStruct , SubMatrixStruct ,
2424 CallbackDM , PetscBool )
25+ from devito .petsc .types .macros import petsc_func_begin_user , Null
2526
2627
2728class CBBuilder :
@@ -118,20 +119,18 @@ def _make_core(self):
118119 def _petsc_options_callback (self ):
119120 objs = self .objs
120121 params = self .solver_parameters
121- Null = objs [ 'Null' ]
122+ options_prefix = self . inject_solve . expr . rhs . options_prefix
122123
123- has_names = ()
124+ body = []
124125
126+ # from IPython import embed; embed()
125127 # TODO: improve
126128 for k , v in params .items ():
127- is_set = PetscBool (self .sregistry .make_name (prefix = 'set' ))
128- has_name = petsc_call ('PetscOptionsHasName' , [
129- Null , Null , String (k ), Byref (is_set )])
130- has_names += (has_name ,)
129+ body .append (petsc_call ('SetPetscOption' , [String ("-" + options_prefix + k ), String (v )]))
131130
132131 body = CallableBody (
133- List (body = has_names ),
134- init = (objs [ 'begin_user' ] ,),
132+ List (body = body ),
133+ init = (petsc_func_begin_user ,),
135134 retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
136135 )
137136
@@ -281,7 +280,7 @@ def _create_matvec_body(self, body, jacobian):
281280
282281 body = CallableBody (
283282 List (body = body ),
284- init = (objs [ 'begin_user' ] ,),
283+ init = (petsc_func_begin_user ,),
285284 stacks = stacks + derefs ,
286285 retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
287286 )
@@ -420,7 +419,7 @@ def _create_formfunc_body(self, body):
420419
421420 body = CallableBody (
422421 List (body = body ),
423- init = (objs [ 'begin_user' ] ,),
422+ init = (petsc_func_begin_user ,),
424423 stacks = stacks + derefs ,
425424 retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),))
426425
@@ -530,7 +529,7 @@ def _create_form_rhs_body(self, body):
530529
531530 body = CallableBody (
532531 List (body = [body ]),
533- init = (objs [ 'begin_user' ] ,),
532+ init = (petsc_func_begin_user ,),
534533 stacks = stacks + derefs ,
535534 retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
536535 )
@@ -608,7 +607,7 @@ def _create_initial_guess_body(self, body):
608607
609608 body = CallableBody (
610609 List (body = [body ]),
611- init = (objs [ 'begin_user' ] ,),
610+ init = (petsc_func_begin_user ,),
612611 stacks = stacks + derefs ,
613612 retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
614613 )
@@ -635,7 +634,7 @@ def _make_user_struct_callback(self):
635634 for i in mainctx .callback_fields
636635 ]
637636 struct_callback_body = CallableBody (
638- List (body = body ), init = (self . objs [ 'begin_user' ] ,),
637+ List (body = body ), init = (petsc_func_begin_user ,),
639638 retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
640639 )
641640 cb = Callable (
@@ -759,7 +758,7 @@ def _whole_matvec_body(self):
759758 )
760759 return CallableBody (
761760 List (body = (ctx_main , zero_y_memory , BlankLine ) + calls ),
762- init = (objs [ 'begin_user' ] ,),
761+ init = (petsc_func_begin_user ,),
763762 retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
764763 )
765764
@@ -894,7 +893,7 @@ def _whole_formfunc_body(self, body):
894893
895894 formfunc_body = CallableBody (
896895 List (body = body ),
897- init = (objs [ 'begin_user' ] ,),
896+ init = (petsc_func_begin_user ,),
898897 stacks = stacks + derefs ,
899898 casts = (f_soa , x_soa ),
900899 retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
@@ -943,7 +942,6 @@ def _submat_callback_body(self):
943942
944943 get_ctx = petsc_call ('MatShellGetContext' , [objs ['J' ], Byref (objs ['ljacctx' ])])
945944
946- Null = objs ['Null' ]
947945 dm_get_info = petsc_call (
948946 'DMDAGetInfo' , [
949947 sobjs ['callbackdm' ], Null , Byref (sobjs ['M' ]), Byref (sobjs ['N' ]),
@@ -1046,7 +1044,7 @@ def _submat_callback_body(self):
10461044
10471045 return CallableBody (
10481046 List (body = tuple (body )),
1049- init = (objs [ 'begin_user' ] ,),
1047+ init = (petsc_func_begin_user ,),
10501048 stacks = (get_ctx , deref_subdm ),
10511049 retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
10521050 )
@@ -1105,7 +1103,7 @@ def _build(self):
11051103 'localsize' : PetscInt (sreg .make_name (prefix = 'localsize' )),
11061104 'dmda' : DM (sreg .make_name (prefix = 'da' ), dofs = len (targets )),
11071105 'callbackdm' : CallbackDM (sreg .make_name (prefix = 'dm' )),
1108- 'snesprefix' : String (( options_prefix or '' ) + '_ ' ),
1106+ 'snesprefix' : String (options_prefix or '' ),
11091107 'options_prefix' : options_prefix ,
11101108 }
11111109 base_dict ['comm' ] = self .comm
@@ -1262,7 +1260,7 @@ def _setup(self):
12621260
12631261 dmda = sobjs ['dmda' ]
12641262
1265- solver_params = self .inject_solve .expr .rhs .solver_parameters
1263+ # solver_params = self.inject_solve.expr.rhs.solver_parameters
12661264
12671265 snes_create = petsc_call ('SNESCreate' , [sobjs ['comm' ], Byref (sobjs ['snes' ])])
12681266
@@ -1283,7 +1281,7 @@ def _setup(self):
12831281
12841282 snes_set_jac = petsc_call (
12851283 'SNESSetJacobian' , [sobjs ['snes' ], sobjs ['Jac' ],
1286- sobjs ['Jac' ], 'MatMFFDComputeJacobian' , objs [ ' Null' ] ]
1284+ sobjs ['Jac' ], 'MatMFFDComputeJacobian' , Null ]
12871285 )
12881286
12891287 global_x = petsc_call ('DMCreateGlobalVector' ,
@@ -1312,16 +1310,17 @@ def _setup(self):
13121310 snes_get_ksp = petsc_call ('SNESGetKSP' ,
13131311 [sobjs ['snes' ], Byref (sobjs ['ksp' ])])
13141312
1315- ksp_set_tols = petsc_call (
1316- 'KSPSetTolerances' , [sobjs ['ksp' ], solver_params ['ksp_rtol' ],
1317- solver_params ['ksp_atol' ], solver_params ['ksp_divtol' ],
1318- solver_params ['ksp_max_it' ]]
1319- )
1313+ # ksp_set_tols = petsc_call(
1314+ # 'KSPSetTolerances', [sobjs['ksp'], solver_params['ksp_rtol'],
1315+ # solver_params['ksp_atol'], solver_params['ksp_divtol'],
1316+ # solver_params['ksp_max_it']]
1317+ # )
13201318
13211319 # ksp_set_type = petsc_call(
13221320 # 'KSPSetType', [sobjs['ksp'], solver_mapper[solver_params['ksp_type']]]
13231321 # )
13241322
1323+ # TODO: can drop this
13251324 ksp_get_pc = petsc_call (
13261325 'KSPGetPC' , [sobjs ['ksp' ], Byref (sobjs ['pc' ])]
13271326 )
@@ -1339,7 +1338,7 @@ def _setup(self):
13391338 formfunc = self .cbbuilder ._F_efunc
13401339 formfunc_operation = petsc_call (
13411340 'SNESSetFunction' ,
1342- [sobjs ['snes' ], objs [ ' Null' ] , FormFunctionCallback (formfunc .name , void , void ),
1341+ [sobjs ['snes' ], Null , FormFunctionCallback (formfunc .name , void , void ),
13431342 self .snes_ctx ]
13441343 )
13451344
@@ -1373,7 +1372,7 @@ def _setup(self):
13731372 get_local_size ,
13741373 global_b ,
13751374 snes_get_ksp ,
1376- ksp_set_tols ,
1375+ # ksp_set_tols,
13771376 ksp_get_pc ,
13781377 pc_set_type ,
13791378 ksp_set_from_ops ,
@@ -1430,7 +1429,7 @@ def _create_dmda(self, dmda):
14301429 stencil_width = self .field_data .space_order
14311430
14321431 args .append (stencil_width )
1433- args .extend ([objs [ ' Null' ] ]* nspace_dims )
1432+ args .extend ([Null ]* nspace_dims )
14341433
14351434 # The distributed array object
14361435 args .append (Byref (dmda ))
@@ -1465,7 +1464,7 @@ def _setup(self):
14651464
14661465 snes_set_jac = petsc_call (
14671466 'SNESSetJacobian' , [sobjs ['snes' ], sobjs ['Jac' ],
1468- sobjs ['Jac' ], 'MatMFFDComputeJacobian' , objs [ ' Null' ] ]
1467+ sobjs ['Jac' ], 'MatMFFDComputeJacobian' , Null ]
14691468 )
14701469
14711470 global_x = petsc_call ('DMCreateGlobalVector' ,
@@ -1506,7 +1505,7 @@ def _setup(self):
15061505 formfunc = self .cbbuilder ._F_efunc
15071506 formfunc_operation = petsc_call (
15081507 'SNESSetFunction' ,
1509- [sobjs ['snes' ], objs [ ' Null' ] , FormFunctionCallback (formfunc .name , void , void ),
1508+ [sobjs ['snes' ], Null , FormFunctionCallback (formfunc .name , void , void ),
15101509 self .snes_ctx ]
15111510 )
15121511
@@ -1529,7 +1528,7 @@ def _setup(self):
15291528
15301529 create_field_decomp = petsc_call (
15311530 'DMCreateFieldDecomposition' ,
1532- [dmda , Byref (sobjs ['nfields' ]), objs [ ' Null' ] , Byref (sobjs ['fields' ]),
1531+ [dmda , Byref (sobjs ['nfields' ]), Null , Byref (sobjs ['fields' ]),
15331532 Byref (sobjs ['subdms' ])]
15341533 )
15351534 submat_cb = self .cbbuilder .submatrices_callback
@@ -1710,7 +1709,7 @@ def _execute_solve(self):
17101709 ),
17111710 petsc_call (
17121711 'VecScatterCreate' ,
1713- [xglob , field , target_xglob , self . objs [ ' Null' ] , Byref (s )]
1712+ [xglob , field , target_xglob , Null , Byref (s )]
17141713 ),
17151714 petsc_call (
17161715 'VecScatterBegin' ,
@@ -1738,7 +1737,7 @@ def _execute_solve(self):
17381737 )
17391738 )
17401739
1741- snes_solve = (petsc_call ('SNESSolve' , [sobjs ['snes' ], objs [ ' Null' ] , xglob ]),)
1740+ snes_solve = (petsc_call ('SNESSolve' , [sobjs ['snes' ], Null , xglob ]),)
17421741
17431742 return (struct_assignment ,) + pre_solve + snes_solve + post_solve + (BlankLine ,)
17441743
0 commit comments