Skip to content

Commit 242d5b5

Browse files
committed
misc: Rename internals to config
1 parent 150dbf5 commit 242d5b5

6 files changed

Lines changed: 98 additions & 93 deletions

File tree

devito/petsc/iet/builder.py

Lines changed: 68 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ def make_core_petsc_calls(objs, comm):
1414
return call_mpi, BlankLine
1515

1616

17-
class BaseSetup:
17+
class BuilderBase:
1818
def __init__(self, **kwargs):
1919
self.inject_solve = kwargs.get('inject_solve')
2020
self.objs = kwargs.get('objs')
2121
self.solver_objs = kwargs.get('solver_objs')
22-
self.cbbuilder = kwargs.get('cbbuilder')
22+
self.callback_builder = kwargs.get('callback_builder')
2323
self.field_data = self.inject_solve.expr.rhs.field_data
2424
self.formatted_prefix = self.inject_solve.expr.rhs.formatted_prefix
2525
self.calls = self._setup()
@@ -31,7 +31,63 @@ def snes_ctx(self):
3131
https://petsc.org/main/manualpages/SNES/SNESSetFunction/
3232
"""
3333
return VOID(self.solver_objs['dmda'], stars='*')
34+
35+
def _setup(self):
36+
return ()
37+
38+
def _extend_setup(self):
39+
"""
40+
Hook for subclasses to add additional setup calls.
41+
"""
42+
return ()
43+
44+
def _create_dmda_calls(self, dmda):
45+
dmda_create = self._create_dmda(dmda)
46+
dm_setup = petsc_call('DMSetUp', [dmda])
47+
dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL'])
48+
return dmda_create, dm_setup, dm_mat_type
49+
50+
def _create_dmda(self, dmda):
51+
sobjs = self.solver_objs
52+
grid = self.field_data.grid
53+
nspace_dims = len(grid.dimensions)
54+
55+
# MPI communicator
56+
args = [sobjs['comm']]
57+
58+
# Type of ghost nodes
59+
args.extend(['DM_BOUNDARY_GHOSTED' for _ in range(nspace_dims)])
60+
61+
# Stencil type
62+
if nspace_dims > 1:
63+
args.append('DMDA_STENCIL_BOX')
64+
65+
# Global dimensions
66+
args.extend(list(grid.shape)[::-1])
67+
# No.of processors in each dimension
68+
if nspace_dims > 1:
69+
args.extend(list(grid.distributor.topology)[::-1])
70+
71+
# Number of degrees of freedom per node
72+
args.append(dmda.dofs)
73+
# "Stencil width" -> size of overlap
74+
# TODO: Instead, this probably should be
75+
# extracted from field_data.target._size_outhalo?
76+
stencil_width = self.field_data.space_order
77+
78+
args.append(stencil_width)
79+
args.extend([Null]*nspace_dims)
80+
81+
# The distributed array object
82+
args.append(Byref(dmda))
83+
84+
# The PETSc call used to create the DMDA
85+
dmda = petsc_call(f'DMDACreate{nspace_dims}d', args)
86+
87+
return dmda
88+
3489

90+
class Builder(BuilderBase):
3591
def _setup(self):
3692
sobjs = self.solver_objs
3793
dmda = sobjs['dmda']
@@ -43,7 +99,7 @@ def _setup(self):
4399
) if self.formatted_prefix else None
44100

45101
set_options = petsc_call(
46-
self.cbbuilder._set_options_efunc.name, []
102+
self.callback_builder._set_options_efunc.name, []
47103
)
48104

49105
snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda])
@@ -82,12 +138,12 @@ def _setup(self):
82138
snes_get_ksp = petsc_call('SNESGetKSP',
83139
[sobjs['snes'], Byref(sobjs['ksp'])])
84140

85-
matvec = self.cbbuilder.main_matvec_callback
141+
matvec = self.callback_builder.main_matvec_callback
86142
matvec_operation = petsc_call(
87143
'MatShellSetOperation',
88144
[sobjs['Jac'], 'MATOP_MULT', MatShellSetOp(matvec.name, void, void)]
89145
)
90-
formfunc = self.cbbuilder._F_efunc
146+
formfunc = self.callback_builder._F_efunc
91147
formfunc_operation = petsc_call(
92148
'SNESSetFunction',
93149
[sobjs['snes'], Null, FormFunctionCallback(formfunc.name, void, void),
@@ -103,7 +159,7 @@ def _setup(self):
103159
mainctx = sobjs['userctx']
104160

105161
call_struct_callback = petsc_call(
106-
self.cbbuilder.user_struct_callback.name, [Byref(mainctx)]
162+
self.callback_builder.user_struct_callback.name, [Byref(mainctx)]
107163
)
108164

109165
# TODO: maybe don't need to explictly set this
@@ -134,59 +190,8 @@ def _setup(self):
134190
extended_setup = self._extend_setup()
135191
return base_setup + extended_setup
136192

137-
def _extend_setup(self):
138-
"""
139-
Hook for subclasses to add additional setup calls.
140-
"""
141-
return ()
142-
143-
def _create_dmda_calls(self, dmda):
144-
dmda_create = self._create_dmda(dmda)
145-
dm_setup = petsc_call('DMSetUp', [dmda])
146-
dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL'])
147-
return dmda_create, dm_setup, dm_mat_type
148-
149-
def _create_dmda(self, dmda):
150-
sobjs = self.solver_objs
151-
grid = self.field_data.grid
152-
nspace_dims = len(grid.dimensions)
153-
154-
# MPI communicator
155-
args = [sobjs['comm']]
156-
157-
# Type of ghost nodes
158-
args.extend(['DM_BOUNDARY_GHOSTED' for _ in range(nspace_dims)])
159-
160-
# Stencil type
161-
if nspace_dims > 1:
162-
args.append('DMDA_STENCIL_BOX')
163-
164-
# Global dimensions
165-
args.extend(list(grid.shape)[::-1])
166-
# No.of processors in each dimension
167-
if nspace_dims > 1:
168-
args.extend(list(grid.distributor.topology)[::-1])
169-
170-
# Number of degrees of freedom per node
171-
args.append(dmda.dofs)
172-
# "Stencil width" -> size of overlap
173-
# TODO: Instead, this probably should be
174-
# extracted from field_data.target._size_outhalo?
175-
stencil_width = self.field_data.space_order
176-
177-
args.append(stencil_width)
178-
args.extend([Null]*nspace_dims)
179-
180-
# The distributed array object
181-
args.append(Byref(dmda))
182-
183-
# The PETSc call used to create the DMDA
184-
dmda = petsc_call(f'DMDACreate{nspace_dims}d', args)
185-
186-
return dmda
187-
188193

189-
class CoupledSetup(BaseSetup):
194+
class CoupledBuilder(BuilderBase):
190195
def _setup(self):
191196
# TODO: minimise code duplication with superclass
192197
objs = self.objs
@@ -200,7 +205,7 @@ def _setup(self):
200205
) if self.formatted_prefix else None
201206

202207
set_options = petsc_call(
203-
self.cbbuilder._set_options_efunc.name, []
208+
self.callback_builder._set_options_efunc.name, []
204209
)
205210

206211
snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda])
@@ -223,12 +228,12 @@ def _setup(self):
223228
snes_get_ksp = petsc_call('SNESGetKSP',
224229
[sobjs['snes'], Byref(sobjs['ksp'])])
225230

226-
matvec = self.cbbuilder.main_matvec_callback
231+
matvec = self.callback_builder.main_matvec_callback
227232
matvec_operation = petsc_call(
228233
'MatShellSetOperation',
229234
[sobjs['Jac'], 'MATOP_MULT', MatShellSetOp(matvec.name, void, void)]
230235
)
231-
formfunc = self.cbbuilder._F_efunc
236+
formfunc = self.callback_builder._F_efunc
232237
formfunc_operation = petsc_call(
233238
'SNESSetFunction',
234239
[sobjs['snes'], Null, FormFunctionCallback(formfunc.name, void, void),
@@ -244,7 +249,7 @@ def _setup(self):
244249
mainctx = sobjs['userctx']
245250

246251
call_struct_callback = petsc_call(
247-
self.cbbuilder.user_struct_callback.name, [Byref(mainctx)]
252+
self.callback_builder.user_struct_callback.name, [Byref(mainctx)]
248253
)
249254

250255
# TODO: maybe don't need to explictly set this
@@ -257,7 +262,7 @@ def _setup(self):
257262
[dmda, Byref(sobjs['nfields']), Null, Byref(sobjs['fields']),
258263
Byref(sobjs['subdms'])]
259264
)
260-
submat_cb = self.cbbuilder.submatrices_callback
265+
submat_cb = self.callback_builder.submatrices_callback
261266
matop_create_submats_op = petsc_call(
262267
'MatShellSetOperation',
263268
[sobjs['Jac'], 'MATOP_CREATE_SUBMATRICES',

devito/petsc/iet/callbacks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from devito.petsc.iet.nodes import PETScCallable, MatShellSetOp, petsc_call
1616
from devito.petsc.types import DMCast, MainUserStruct, CallbackUserStruct
17-
from devito.petsc.iet.objects import objs
17+
from devito.petsc.iet.type_builder import objs
1818
from devito.petsc.types.macros import petsc_func_begin_user
1919
from devito.petsc.types.modes import InsertMode
2020

devito/petsc/iet/passes.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
from devito.petsc.iet.callbacks import (
2222
BaseCallback, CoupledCallback, populate_matrix_context, get_user_struct_fields
2323
)
24-
from devito.petsc.iet.objects import BaseObjectBuilder, CoupledObjectBuilder, objs
25-
from devito.petsc.iet.builder import BaseSetup, CoupledSetup, make_core_petsc_calls
26-
from devito.petsc.iet.solver import Solver, CoupledSolver
24+
from devito.petsc.iet.type_builder import BaseTypeBuilder, CoupledTypeBuilder, objs
25+
from devito.petsc.iet.builder import Builder, CoupledBuilder, make_core_petsc_calls
26+
from devito.petsc.iet.solve import Solve, CoupledSolve
2727
from devito.petsc.iet.time_dependence import TimeDependent, TimeIndependent
2828
from devito.petsc.iet.logging import PetscLogger
2929

@@ -86,17 +86,17 @@ def lower_petsc(iet, **kwargs):
8686

8787
for iters, (inject_solve,) in inject_solve_mapper.items():
8888

89-
builder = Builder(inject_solve, iters, comm, section_mapper, **kwargs)
89+
solver = BuildSolver(inject_solve, iters, comm, section_mapper, **kwargs)
9090

91-
setup.extend(builder.solver_setup.calls)
91+
setup.extend(solver.builder.calls)
9292

9393
# Transform the spatial iteration loop with the calls to execute the solver
94-
subs.update({builder.solve.spatial_body: builder.calls})
94+
subs.update({solver.solve.spatial_body: solver.calls})
9595

96-
efuncs.update(builder.cbbuilder.efuncs)
96+
efuncs.update(solver.callback_builder.efuncs)
9797

9898
clear_options.extend((petsc_call(
99-
builder.cbbuilder._clear_options_efunc.name, []
99+
solver.callback_builder._clear_options_efunc.name, []
100100
),))
101101

102102
populate_matrix_context(efuncs)
@@ -226,7 +226,7 @@ def finalize(iet):
226226
return iet._rebuild(body=finalize_body)
227227

228228

229-
class Builder:
229+
class BuildSolver:
230230
"""
231231
This class is designed to support future extensions, enabling
232232
different combinations of solver types, preconditioning methods,
@@ -252,17 +252,17 @@ def __init__(self, inject_solve, iters, comm, section_mapper, **kwargs):
252252
'section_mapper': self.section_mapper,
253253
**self.kwargs
254254
}
255-
self.common_kwargs['solver_objs'] = self.object_builder.solver_objs
255+
self.common_kwargs['solver_objs'] = self.type_builder.solver_objs
256256
self.common_kwargs['time_dependence'] = self.time_dependence
257-
self.common_kwargs['cbbuilder'] = self.cbbuilder
257+
self.common_kwargs['callback_builder'] = self.callback_builder
258258
self.common_kwargs['logger'] = self.logger
259259

260260
@cached_property
261-
def object_builder(self):
261+
def type_builder(self):
262262
return (
263-
CoupledObjectBuilder(**self.common_kwargs)
263+
CoupledTypeBuilder(**self.common_kwargs)
264264
if self.coupled else
265-
BaseObjectBuilder(**self.common_kwargs)
265+
BaseTypeBuilder(**self.common_kwargs)
266266
)
267267

268268
@cached_property
@@ -272,19 +272,19 @@ def time_dependence(self):
272272
return time_class(**self.common_kwargs)
273273

274274
@cached_property
275-
def cbbuilder(self):
275+
def callback_builder(self):
276276
return CoupledCallback(**self.common_kwargs) \
277277
if self.coupled else BaseCallback(**self.common_kwargs)
278278

279279
@cached_property
280-
def solver_setup(self):
281-
return CoupledSetup(**self.common_kwargs) \
282-
if self.coupled else BaseSetup(**self.common_kwargs)
280+
def builder(self):
281+
return CoupledBuilder(**self.common_kwargs) \
282+
if self.coupled else Builder(**self.common_kwargs)
283283

284284
@cached_property
285285
def solve(self):
286-
return CoupledSolver(**self.common_kwargs) \
287-
if self.coupled else Solver(**self.common_kwargs)
286+
return CoupledSolve(**self.common_kwargs) \
287+
if self.coupled else Solve(**self.common_kwargs)
288288

289289
@cached_property
290290
def logger(self):
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
from devito.petsc.types.modes import InsertMode, ScatterMode
1010

1111

12-
class Solver:
12+
class Solve:
1313
def __init__(self, **kwargs):
1414
self.inject_solve = kwargs.get('inject_solve')
1515
self.objs = kwargs.get('objs')
1616
self.solver_objs = kwargs.get('solver_objs')
1717
self.iters = kwargs.get('iters')
18-
self.cbbuilder = kwargs.get('cbbuilder')
18+
self.callback_builder = kwargs.get('callback_builder')
1919
self.time_dependence = kwargs.get('time_dependence')
2020
self.calls = self._execute_solve()
2121

@@ -29,16 +29,16 @@ def _execute_solve(self):
2929

3030
struct_assignment = self.time_dependence.assign_time_iters(sobjs['userctx'])
3131

32-
b_efunc = self.cbbuilder._b_efunc
32+
b_efunc = self.callback_builder._b_efunc
3333

3434
dmda = sobjs['dmda']
3535

3636
rhs_call = petsc_call(b_efunc.name, [sobjs['dmda'], sobjs['bglobal']])
3737

3838
vec_place_array = self.time_dependence.place_array(target)
3939

40-
if self.cbbuilder.initial_guesses:
41-
initguess = self.cbbuilder.initial_guesses[0]
40+
if self.callback_builder.initial_guesses:
41+
initguess = self.callback_builder.initial_guesses[0]
4242
initguess_call = petsc_call(initguess.name, [dmda, sobjs['xlocal']])
4343
else:
4444
initguess_call = None
@@ -84,7 +84,7 @@ def spatial_body(self):
8484
return spatial_body
8585

8686

87-
class CoupledSolver(Solver):
87+
class CoupledSolve(Solve):
8888
def _execute_solve(self):
8989
"""
9090
Assigns the required time iterators to the struct and executes
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313

1414

15-
class BaseObjectBuilder:
15+
class BaseTypeBuilder:
1616
"""
1717
A base class for constructing objects needed for a PETSc solver.
1818
Designed to be extended by subclasses, which can override the `_extend_build`
@@ -89,9 +89,9 @@ def _extend_build(self, base_dict):
8989
base dictionary of solver objects.
9090
"""
9191
return base_dict
92+
9293

93-
94-
class CoupledObjectBuilder(BaseObjectBuilder):
94+
class CoupledTypeBuilder(BaseTypeBuilder):
9595
def _extend_build(self, base_dict):
9696
sreg = self.sregistry
9797
objs = self.objs

0 commit comments

Comments
 (0)