@@ -117,18 +117,40 @@ def _make_core(self):
117117 self ._make_initial_guess ()
118118 self ._make_user_struct_callback ()
119119
120+ def _make_petsc_callable (self , prefix , body , parameters = ()):
121+ return PETScCallable (
122+ self .sregistry .make_name (prefix = prefix ),
123+ body ,
124+ retval = self .objs ['err' ],
125+ parameters = parameters
126+ )
127+
128+ def _make_callable_body (self , body , stacks = (), casts = ()):
129+ return CallableBody (
130+ List (body = body ),
131+ init = (petsc_func_begin_user ,),
132+ stacks = stacks ,
133+ casts = casts ,
134+ retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
135+ )
136+
120137 def _make_options_callback (self ):
121- objs = self .objs
138+ """
139+ Create two callbacks: one to set PETSc options and one for
140+ to clear them.
141+
142+ Options are only set/cleared if they were not specifed via
143+ command line arguments.
144+ """
122145 params = self .solver_parameters
123146 prefix = self .inject_solve .expr .rhs .formatted_prefix
124147
125- set_body = []
126- clear_body = []
148+ set_body , clear_body = [], []
127149
128150 for k , v in params .items ():
129151 option = f'-{ prefix } { k } '
130152 if option in sys .argv :
131- # Ensures that the command line options take priority
153+ # Ensures that the command line args take priority
132154 continue
133155 option_name = String (option )
134156 option_value = Null if v is None else String (str (v ))
@@ -139,31 +161,12 @@ def _make_options_callback(self):
139161 petsc_call ('PetscOptionsClearValue' , [Null , option_name ])
140162 )
141163
142- set_body = CallableBody (
143- List (body = set_body ),
144- init = (petsc_func_begin_user ,),
145- retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
146- )
164+ set_body = self ._make_callable_body (set_body )
165+ clear_body = self ._make_callable_body (clear_body )
147166
148- clear_body = CallableBody (
149- List (body = clear_body ),
150- init = (petsc_func_begin_user ,),
151- retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
152- )
167+ set_callback = self ._make_petsc_callable ('SetPetscOptions' , set_body )
168+ clear_callback = self ._make_petsc_callable ('ClearPetscOptions' , clear_body )
153169
154- set_callback = PETScCallable (
155- self .sregistry .make_name (prefix = 'SetPetscOptions' ),
156- set_body ,
157- retval = objs ['err' ],
158- parameters = ()
159- )
160-
161- clear_callback = PETScCallable (
162- self .sregistry .make_name (prefix = 'ClearPetscOptions' ),
163- clear_body ,
164- retval = objs ['err' ],
165- parameters = ()
166- )
167170 self ._set_options_efunc = set_callback
168171 self ._efuncs [set_callback .name ] = set_callback
169172 self ._clear_options_efunc = clear_callback
@@ -179,13 +182,9 @@ def _make_matvec(self, jacobian, prefix='MatMult'):
179182 body = self ._create_matvec_body (
180183 List (body = irs .uiet .body ), jacobian
181184 )
182-
183185 objs = self .objs
184- cb = PETScCallable (
185- self .sregistry .make_name (prefix = prefix ),
186- body ,
187- retval = objs ['err' ],
188- parameters = (objs ['J' ], objs ['X' ], objs ['Y' ])
186+ cb = self ._make_petsc_callable (
187+ prefix , body , parameters = (objs ['J' ], objs ['X' ], objs ['Y' ])
189188 )
190189 self ._J_efuncs .append (cb )
191190 self ._efuncs [cb .name ] = cb
@@ -303,12 +302,7 @@ def _create_matvec_body(self, body, jacobian):
303302 # Dereference function data in struct
304303 derefs = dereference_funcs (ctx , fields )
305304
306- body = CallableBody (
307- List (body = body ),
308- init = (petsc_func_begin_user ,),
309- stacks = stacks + derefs ,
310- retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
311- )
305+ body = self ._make_callable_body (body , stacks = stacks + derefs )
312306
313307 # Replace non-function data with pointer to data in struct
314308 subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for i in fields }
@@ -318,6 +312,7 @@ def _create_matvec_body(self, body, jacobian):
318312 return body
319313
320314 def _make_formfunc (self ):
315+ objs = self .objs
321316 F_exprs = self .field_data .residual .F_exprs
322317 # Compile `F_exprs` into an IET via recursive compilation
323318 irs , _ = self .rcompile (
@@ -327,13 +322,9 @@ def _make_formfunc(self):
327322 body_formfunc = self ._create_formfunc_body (
328323 List (body = irs .uiet .body )
329324 )
330- objs = self .objs
331- cb = PETScCallable (
332- self .sregistry .make_name (prefix = 'FormFunction' ),
333- body_formfunc ,
334- retval = objs ['err' ],
335- parameters = (objs ['snes' ], objs ['X' ], objs ['F' ], objs ['dummyptr' ])
336- )
325+ parameters = (objs ['snes' ], objs ['X' ], objs ['F' ], objs ['dummyptr' ])
326+ cb = self ._make_petsc_callable ('FormFunction' , body_formfunc , parameters )
327+
337328 self ._F_efunc = cb
338329 self ._efuncs [cb .name ] = cb
339330
@@ -442,12 +433,7 @@ def _create_formfunc_body(self, body):
442433 # Dereference function data in struct
443434 derefs = dereference_funcs (ctx , fields )
444435
445- body = CallableBody (
446- List (body = body ),
447- init = (petsc_func_begin_user ,),
448- stacks = stacks + derefs ,
449- retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),))
450-
436+ body = self ._make_callable_body (body , stacks = stacks + derefs )
451437 # Replace non-function data with pointer to data in struct
452438 subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for i in fields }
453439
@@ -466,11 +452,8 @@ def _make_formrhs(self):
466452 List (body = irs .uiet .body )
467453 )
468454 objs = self .objs
469- cb = PETScCallable (
470- self .sregistry .make_name (prefix = 'FormRHS' ),
471- body ,
472- retval = objs ['err' ],
473- parameters = (sobjs ['callbackdm' ], objs ['B' ])
455+ cb = self ._make_petsc_callable (
456+ 'FormRHS' , body , parameters = (sobjs ['callbackdm' ], objs ['B' ])
474457 )
475458 self ._b_efunc = cb
476459 self ._efuncs [cb .name ] = cb
@@ -552,12 +535,7 @@ def _create_form_rhs_body(self, body):
552535 # Dereference function data in struct
553536 derefs = dereference_funcs (ctx , fields )
554537
555- body = CallableBody (
556- List (body = [body ]),
557- init = (petsc_func_begin_user ,),
558- stacks = stacks + derefs ,
559- retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
560- )
538+ body = self ._make_callable_body ([body ], stacks = stacks + derefs )
561539
562540 # Replace non-function data with pointer to data in struct
563541 subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for
@@ -568,6 +546,7 @@ def _create_form_rhs_body(self, body):
568546 def _make_initial_guess (self ):
569547 exprs = self .field_data .initial_guess .exprs
570548 sobjs = self .solver_objs
549+ objs = self .objs
571550
572551 # Compile initital guess `eqns` into an IET via recursive compilation
573552 irs , _ = self .rcompile (
@@ -577,12 +556,8 @@ def _make_initial_guess(self):
577556 body = self ._create_initial_guess_body (
578557 List (body = irs .uiet .body )
579558 )
580- objs = self .objs
581- cb = PETScCallable (
582- self .sregistry .make_name (prefix = 'FormInitialGuess' ),
583- body ,
584- retval = objs ['err' ],
585- parameters = (sobjs ['callbackdm' ], objs ['xloc' ])
559+ cb = self ._make_petsc_callable (
560+ 'FormInitialGuess' , body , parameters = (sobjs ['callbackdm' ], objs ['xloc' ])
586561 )
587562 self ._initial_guesses .append (cb )
588563 self ._efuncs [cb .name ] = cb
@@ -629,13 +604,7 @@ def _create_initial_guess_body(self, body):
629604
630605 # Dereference function data in struct
631606 derefs = dereference_funcs (ctx , fields )
632-
633- body = CallableBody (
634- List (body = [body ]),
635- init = (petsc_func_begin_user ,),
636- stacks = stacks + derefs ,
637- retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
638- )
607+ body = self ._make_callable_body (body , stacks = stacks + derefs )
639608
640609 # Replace non-function data with pointer to data in struct
641610 subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for
@@ -658,10 +627,7 @@ def _make_user_struct_callback(self):
658627 DummyExpr (FieldFromPointer (i ._C_symbol , mainctx ), i ._C_symbol )
659628 for i in mainctx .callback_fields
660629 ]
661- struct_callback_body = CallableBody (
662- List (body = body ), init = (petsc_func_begin_user ,),
663- retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
664- )
630+ struct_callback_body = self ._make_callable_body (body )
665631 cb = Callable (
666632 self .sregistry .make_name (prefix = 'PopulateUserContext' ),
667633 struct_callback_body , self .objs ['err' ],
@@ -731,11 +697,9 @@ def _make_whole_matvec(self):
731697 objs = self .objs
732698 body = self ._whole_matvec_body ()
733699
734- cb = PETScCallable (
735- self .sregistry .make_name (prefix = 'WholeMatMult' ),
736- List (body = body ),
737- retval = objs ['err' ],
738- parameters = (objs ['J' ], objs ['X' ], objs ['Y' ])
700+ parameters = (objs ['J' ], objs ['X' ], objs ['Y' ])
701+ cb = self ._make_petsc_callable (
702+ 'WholeMatMult' , List (body = body ), parameters = parameters
739703 )
740704 self ._main_matvec_callback = cb
741705 self ._efuncs [cb .name ] = cb
@@ -782,13 +746,11 @@ def _whole_matvec_body(self):
782746 [objs ['Y' ], Deref (FieldFromPointer (rows , ctx )), Byref (Y )]
783747 ),
784748 )
785- return CallableBody (
786- List (body = (ctx_main , zero_y_memory , BlankLine ) + calls ),
787- init = (petsc_func_begin_user ,),
788- retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
789- )
749+ body = (ctx_main , zero_y_memory , BlankLine ) + calls
750+ return self ._make_callable_body (body )
790751
791752 def _make_whole_formfunc (self ):
753+ objs = self .objs
792754 F_exprs = self .field_data .residual .F_exprs
793755 # Compile `F_exprs` into an IET via recursive compilation
794756 irs , _ = self .rcompile (
@@ -797,13 +759,11 @@ def _make_whole_formfunc(self):
797759 )
798760 body = self ._whole_formfunc_body (List (body = irs .uiet .body ))
799761
800- objs = self .objs
801- cb = PETScCallable (
802- self .sregistry .make_name (prefix = 'WholeFormFunc' ),
803- body ,
804- retval = objs ['err' ],
805- parameters = (objs ['snes' ], objs ['X' ], objs ['F' ], objs ['dummyptr' ])
762+ parameters = (objs ['snes' ], objs ['X' ], objs ['F' ], objs ['dummyptr' ])
763+ cb = self ._make_petsc_callable (
764+ 'WholeFormFunc' , body , parameters = parameters
806765 )
766+
807767 self ._F_efunc = cb
808768 self ._efuncs [cb .name ] = cb
809769
@@ -917,14 +877,10 @@ def _whole_formfunc_body(self, body):
917877 f_soa = PointerCast (fbundle )
918878 x_soa = PointerCast (xbundle )
919879
920- formfunc_body = CallableBody (
921- List (body = body ),
922- init = (petsc_func_begin_user ,),
923- stacks = stacks + derefs ,
880+ formfunc_body = self ._make_callable_body (
881+ body , stacks = stacks + derefs ,
924882 casts = (f_soa , x_soa ),
925- retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
926883 )
927-
928884 # Replace non-function data with pointer to data in struct
929885 subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for i in fields }
930886
@@ -941,12 +897,9 @@ def _create_submatrices(self):
941897 objs ['matreuse' ],
942898 objs ['Submats' ],
943899 )
944- cb = PETScCallable (
945- self .sregistry .make_name (prefix = 'MatCreateSubMatrices' ),
946- body ,
947- retval = objs ['err' ],
948- parameters = params
949- )
900+ cb = self ._make_petsc_callable (
901+ 'MatCreateSubMatrices' , body , parameters = params )
902+
950903 self ._submatrices_callback = cb
951904 self ._efuncs [cb .name ] = cb
952905
@@ -1068,12 +1021,7 @@ def _submat_callback_body(self):
10681021 iteration ,
10691022 ] + matmult_op
10701023
1071- return CallableBody (
1072- List (body = tuple (body )),
1073- init = (petsc_func_begin_user ,),
1074- stacks = (get_ctx , deref_subdm ),
1075- retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
1076- )
1024+ return self ._make_callable_body (tuple (body ), stacks = (get_ctx , deref_subdm ))
10771025
10781026
10791027class BaseObjectBuilder :
0 commit comments