@@ -14,12 +14,12 @@ def make_core_petsc_calls(objs, comm):
1414 return call_mpi , BlankLine
1515
1616
17- class BaseSetup :
17+ class BuilderBase :
1818 def __init__ (self , ** kwargs ):
1919 self .inject_solve = kwargs .get ('inject_solve' )
2020 self .objs = kwargs .get ('objs' )
2121 self .solver_objs = kwargs .get ('solver_objs' )
22- self .cbbuilder = kwargs .get ('cbbuilder ' )
22+ self .callback_builder = kwargs .get ('callback_builder ' )
2323 self .field_data = self .inject_solve .expr .rhs .field_data
2424 self .formatted_prefix = self .inject_solve .expr .rhs .formatted_prefix
2525 self .calls = self ._setup ()
@@ -31,7 +31,63 @@ def snes_ctx(self):
3131 https://petsc.org/main/manualpages/SNES/SNESSetFunction/
3232 """
3333 return VOID (self .solver_objs ['dmda' ], stars = '*' )
34+
35+ def _setup (self ):
36+ return ()
37+
38+ def _extend_setup (self ):
39+ """
40+ Hook for subclasses to add additional setup calls.
41+ """
42+ return ()
43+
44+ def _create_dmda_calls (self , dmda ):
45+ dmda_create = self ._create_dmda (dmda )
46+ dm_setup = petsc_call ('DMSetUp' , [dmda ])
47+ dm_mat_type = petsc_call ('DMSetMatType' , [dmda , 'MATSHELL' ])
48+ return dmda_create , dm_setup , dm_mat_type
49+
50+ def _create_dmda (self , dmda ):
51+ sobjs = self .solver_objs
52+ grid = self .field_data .grid
53+ nspace_dims = len (grid .dimensions )
54+
55+ # MPI communicator
56+ args = [sobjs ['comm' ]]
57+
58+ # Type of ghost nodes
59+ args .extend (['DM_BOUNDARY_GHOSTED' for _ in range (nspace_dims )])
60+
61+ # Stencil type
62+ if nspace_dims > 1 :
63+ args .append ('DMDA_STENCIL_BOX' )
64+
65+ # Global dimensions
66+ args .extend (list (grid .shape )[::- 1 ])
67+ # No.of processors in each dimension
68+ if nspace_dims > 1 :
69+ args .extend (list (grid .distributor .topology )[::- 1 ])
70+
71+ # Number of degrees of freedom per node
72+ args .append (dmda .dofs )
73+ # "Stencil width" -> size of overlap
74+ # TODO: Instead, this probably should be
75+ # extracted from field_data.target._size_outhalo?
76+ stencil_width = self .field_data .space_order
77+
78+ args .append (stencil_width )
79+ args .extend ([Null ]* nspace_dims )
80+
81+ # The distributed array object
82+ args .append (Byref (dmda ))
83+
84+ # The PETSc call used to create the DMDA
85+ dmda = petsc_call (f'DMDACreate{ nspace_dims } d' , args )
86+
87+ return dmda
88+
3489
90+ class Builder (BuilderBase ):
3591 def _setup (self ):
3692 sobjs = self .solver_objs
3793 dmda = sobjs ['dmda' ]
@@ -43,7 +99,7 @@ def _setup(self):
4399 ) if self .formatted_prefix else None
44100
45101 set_options = petsc_call (
46- self .cbbuilder ._set_options_efunc .name , []
102+ self .callback_builder ._set_options_efunc .name , []
47103 )
48104
49105 snes_set_dm = petsc_call ('SNESSetDM' , [sobjs ['snes' ], dmda ])
@@ -82,12 +138,12 @@ def _setup(self):
82138 snes_get_ksp = petsc_call ('SNESGetKSP' ,
83139 [sobjs ['snes' ], Byref (sobjs ['ksp' ])])
84140
85- matvec = self .cbbuilder .main_matvec_callback
141+ matvec = self .callback_builder .main_matvec_callback
86142 matvec_operation = petsc_call (
87143 'MatShellSetOperation' ,
88144 [sobjs ['Jac' ], 'MATOP_MULT' , MatShellSetOp (matvec .name , void , void )]
89145 )
90- formfunc = self .cbbuilder ._F_efunc
146+ formfunc = self .callback_builder ._F_efunc
91147 formfunc_operation = petsc_call (
92148 'SNESSetFunction' ,
93149 [sobjs ['snes' ], Null , FormFunctionCallback (formfunc .name , void , void ),
@@ -103,7 +159,7 @@ def _setup(self):
103159 mainctx = sobjs ['userctx' ]
104160
105161 call_struct_callback = petsc_call (
106- self .cbbuilder .user_struct_callback .name , [Byref (mainctx )]
162+ self .callback_builder .user_struct_callback .name , [Byref (mainctx )]
107163 )
108164
109165 # TODO: maybe don't need to explictly set this
@@ -134,59 +190,8 @@ def _setup(self):
134190 extended_setup = self ._extend_setup ()
135191 return base_setup + extended_setup
136192
137- def _extend_setup (self ):
138- """
139- Hook for subclasses to add additional setup calls.
140- """
141- return ()
142-
143- def _create_dmda_calls (self , dmda ):
144- dmda_create = self ._create_dmda (dmda )
145- dm_setup = petsc_call ('DMSetUp' , [dmda ])
146- dm_mat_type = petsc_call ('DMSetMatType' , [dmda , 'MATSHELL' ])
147- return dmda_create , dm_setup , dm_mat_type
148-
149- def _create_dmda (self , dmda ):
150- sobjs = self .solver_objs
151- grid = self .field_data .grid
152- nspace_dims = len (grid .dimensions )
153-
154- # MPI communicator
155- args = [sobjs ['comm' ]]
156-
157- # Type of ghost nodes
158- args .extend (['DM_BOUNDARY_GHOSTED' for _ in range (nspace_dims )])
159-
160- # Stencil type
161- if nspace_dims > 1 :
162- args .append ('DMDA_STENCIL_BOX' )
163-
164- # Global dimensions
165- args .extend (list (grid .shape )[::- 1 ])
166- # No.of processors in each dimension
167- if nspace_dims > 1 :
168- args .extend (list (grid .distributor .topology )[::- 1 ])
169-
170- # Number of degrees of freedom per node
171- args .append (dmda .dofs )
172- # "Stencil width" -> size of overlap
173- # TODO: Instead, this probably should be
174- # extracted from field_data.target._size_outhalo?
175- stencil_width = self .field_data .space_order
176-
177- args .append (stencil_width )
178- args .extend ([Null ]* nspace_dims )
179-
180- # The distributed array object
181- args .append (Byref (dmda ))
182-
183- # The PETSc call used to create the DMDA
184- dmda = petsc_call (f'DMDACreate{ nspace_dims } d' , args )
185-
186- return dmda
187-
188193
189- class CoupledSetup ( BaseSetup ):
194+ class CoupledBuilder ( BuilderBase ):
190195 def _setup (self ):
191196 # TODO: minimise code duplication with superclass
192197 objs = self .objs
@@ -200,7 +205,7 @@ def _setup(self):
200205 ) if self .formatted_prefix else None
201206
202207 set_options = petsc_call (
203- self .cbbuilder ._set_options_efunc .name , []
208+ self .callback_builder ._set_options_efunc .name , []
204209 )
205210
206211 snes_set_dm = petsc_call ('SNESSetDM' , [sobjs ['snes' ], dmda ])
@@ -223,12 +228,12 @@ def _setup(self):
223228 snes_get_ksp = petsc_call ('SNESGetKSP' ,
224229 [sobjs ['snes' ], Byref (sobjs ['ksp' ])])
225230
226- matvec = self .cbbuilder .main_matvec_callback
231+ matvec = self .callback_builder .main_matvec_callback
227232 matvec_operation = petsc_call (
228233 'MatShellSetOperation' ,
229234 [sobjs ['Jac' ], 'MATOP_MULT' , MatShellSetOp (matvec .name , void , void )]
230235 )
231- formfunc = self .cbbuilder ._F_efunc
236+ formfunc = self .callback_builder ._F_efunc
232237 formfunc_operation = petsc_call (
233238 'SNESSetFunction' ,
234239 [sobjs ['snes' ], Null , FormFunctionCallback (formfunc .name , void , void ),
@@ -244,7 +249,7 @@ def _setup(self):
244249 mainctx = sobjs ['userctx' ]
245250
246251 call_struct_callback = petsc_call (
247- self .cbbuilder .user_struct_callback .name , [Byref (mainctx )]
252+ self .callback_builder .user_struct_callback .name , [Byref (mainctx )]
248253 )
249254
250255 # TODO: maybe don't need to explictly set this
@@ -257,7 +262,7 @@ def _setup(self):
257262 [dmda , Byref (sobjs ['nfields' ]), Null , Byref (sobjs ['fields' ]),
258263 Byref (sobjs ['subdms' ])]
259264 )
260- submat_cb = self .cbbuilder .submatrices_callback
265+ submat_cb = self .callback_builder .submatrices_callback
261266 matop_create_submats_op = petsc_call (
262267 'MatShellSetOperation' ,
263268 [sobjs ['Jac' ], 'MATOP_CREATE_SUBMATRICES' ,
0 commit comments