1616from devito .petsc .types import PETScArray , PetscBundle
1717from devito .petsc .iet .nodes import (PETScCallable , FormFunctionCallback ,
1818 MatShellSetOp , PetscMetaData )
19- from devito .petsc .iet .utils import petsc_call , petsc_struct , zero_vector
19+ from devito .petsc .iet .utils import (petsc_call , petsc_struct , zero_vector ,
20+ dereference_funcs , residual_bundle )
2021from devito .petsc .utils import solver_mapper
2122from devito .petsc .types import (DM , Mat , CallbackVec , Vec , KSP , PC , SNES ,
2223 PetscInt , StartPtr , PointerIS , PointerDM , VecScatter ,
@@ -120,7 +121,7 @@ def _make_core(self):
120121 self ._make_user_struct_callback ()
121122
122123 def _make_matvec (self , jacobian , prefix = 'MatMult' ):
123- # Compile matvec `eqns ` into an IET via recursive compilation
124+ # Compile `matvecs ` into an IET via recursive compilation
124125 matvecs = jacobian .matvecs
125126 irs , _ = self .rcompile (
126127 matvecs , options = {'mpi' : False }, sregistry = self .sregistry ,
@@ -251,7 +252,7 @@ def _create_matvec_body(self, body, jacobian):
251252 )
252253
253254 # Dereference function data in struct
254- derefs = self . dereference_funcs (ctx , fields )
255+ derefs = dereference_funcs (ctx , fields )
255256
256257 body = CallableBody (
257258 List (body = body ),
@@ -390,7 +391,7 @@ def _create_formfunc_body(self, body):
390391 )
391392
392393 # Dereference function data in struct
393- derefs = self . dereference_funcs (ctx , fields )
394+ derefs = dereference_funcs (ctx , fields )
394395
395396 body = CallableBody (
396397 List (body = body ),
@@ -500,7 +501,7 @@ def _create_form_rhs_body(self, body):
500501 )
501502
502503 # Dereference function data in struct
503- derefs = self . dereference_funcs (ctx , fields )
504+ derefs = dereference_funcs (ctx , fields )
504505
505506 body = CallableBody (
506507 List (body = [body ]),
@@ -578,7 +579,7 @@ def _create_initial_guess_body(self, body):
578579 )
579580
580581 # Dereference function data in struct
581- derefs = self . dereference_funcs (ctx , fields )
582+ derefs = dereference_funcs (ctx , fields )
582583
583584 body = CallableBody (
584585 List (body = [body ]),
@@ -643,12 +644,6 @@ def _uxreplace_efuncs(self):
643644 mapper .update ({k : visitor .visit (v )})
644645 return mapper
645646
646- def dereference_funcs (self , struct , fields ):
647- return tuple (
648- [Dereference (i , struct ) for i in
649- fields if isinstance (i .function , AbstractFunction )]
650- )
651-
652647
653648class CCBBuilder (CBBuilder ):
654649 def __init__ (self , ** kwargs ):
@@ -749,17 +744,17 @@ def _whole_matvec_body(self):
749744
750745 def _make_whole_formfunc (self ):
751746 F_exprs = self .fielddata .residual .F_exprs
752- # Compile formfunc `eqns ` into an IET via recursive compilation
753- irs_formfunc , _ = self .rcompile (
747+ # Compile `F_exprs ` into an IET via recursive compilation
748+ irs , _ = self .rcompile (
754749 F_exprs , options = {'mpi' : False }, sregistry = self .sregistry ,
755750 concretize_mapper = self .concretize_mapper
756751 )
757- body_formfunc = self ._whole_formfunc_body (List (body = irs_formfunc .uiet .body ))
752+ body = self ._whole_formfunc_body (List (body = irs .uiet .body ))
758753
759754 objs = self .objs
760755 cb = PETScCallable (
761756 self .sregistry .make_name (prefix = 'WholeFormFunc' ),
762- body_formfunc ,
757+ body ,
763758 retval = objs ['err' ],
764759 parameters = (objs ['snes' ], objs ['X' ], objs ['F' ], objs ['dummyptr' ])
765760 )
@@ -783,7 +778,8 @@ def _whole_formfunc_body(self, body):
783778 bundles = sobjs ['bundles' ]
784779 fbundle = bundles ['f' ]
785780 xbundle = bundles ['x' ]
786- body = self .residual_bundle (body , bundles )
781+
782+ body = residual_bundle (body , bundles )
787783
788784 dm_cast = DummyExpr (dmda , DMCast (objs ['dummyptr' ]), init = True )
789785
@@ -870,7 +866,7 @@ def _whole_formfunc_body(self, body):
870866 )
871867
872868 # Dereference function data in struct
873- derefs = self . dereference_funcs (ctx , fields )
869+ derefs = dereference_funcs (ctx , fields )
874870
875871 f_soa = PointerCast (fbundle )
876872 x_soa = PointerCast (xbundle )
@@ -1034,21 +1030,6 @@ def _submat_callback_body(self):
10341030 retstmt = (Call ('PetscFunctionReturn' , arguments = [0 ]),)
10351031 )
10361032
1037- def residual_bundle (self , body , bundles ):
1038- mapper = bundles ['bundle_mapper' ]
1039- indexeds = FindSymbols ('indexeds' ).visit (body )
1040- subs = {}
1041-
1042- for i in indexeds :
1043- if i .base in mapper :
1044- bundle = mapper [i .base ]
1045- index = bundles ['target_indices' ][i .function .target ]
1046- index = (index ,) + i .indices
1047- subs [i ] = bundle .__getitem__ (index )
1048-
1049- body = Uxreplace (subs ).visit (body )
1050- return body
1051-
10521033
10531034class BaseObjectBuilder :
10541035 """
0 commit comments