Skip to content

Commit 774af10

Browse files
committed
misc: Clean up and docstrings
1 parent 9a182da commit 774af10

5 files changed

Lines changed: 54 additions & 55 deletions

File tree

devito/petsc/iet/routines.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,12 @@ def __init__(self, **kwargs):
4242
self._struct_params = []
4343

4444
self._main_matvec_callback = None
45-
self._main_formfunc_callback = None
4645
self._user_struct_callback = None
47-
# TODO: Test pickling. The mutability of these lists
48-
# could cause issues when pickling?
46+
self._F_efunc = None
47+
self._b_efunc = None
48+
4949
self._J_efuncs = []
50-
self._F_efuncs = []
51-
self._b_efuncs = []
52-
self._initialguesses = []
50+
self._initial_guesses = []
5351

5452
self._make_core()
5553
self._efuncs = self._uxreplace_efuncs()
@@ -69,31 +67,26 @@ def filtered_struct_params(self):
6967
@property
7068
def main_matvec_callback(self):
7169
"""
72-
This is the matvec callback associated with the whole Jacobian i.e
73-
is set in the main kernel via
74-
`PetscCall(MatShellSetOperation(J,MATOP_MULT,(void (*)(void))...));`
70+
The matrix-vector callback for the full Jacobian.
71+
This is the function set in the main Kernel via:
72+
PetscCall(MatShellSetOperation(J, MATOP_MULT, (void (*)(void))...));
73+
The callback has the signature `(Mat, Vec, Vec)`.
7574
"""
7675
return self._J_efuncs[0]
7776

78-
@property
79-
def main_formfunc_callback(self):
80-
return self._F_efuncs[0]
81-
8277
@property
8378
def J_efuncs(self):
79+
"""
80+
List of matrix-vector callbacks.
81+
Each callback has the signature `(Mat, Vec, Vec)`. Typically, this list
82+
contains a single element, but in mixed systems it can include multiple
83+
callbacks, one for each subblock.
84+
"""
8485
return self._J_efuncs
8586

8687
@property
87-
def F_efuncs(self):
88-
return self._F_efuncs
89-
90-
@property
91-
def b_efuncs(self):
92-
return self._b_efuncs
93-
94-
@property
95-
def initialguesses(self):
96-
return self._initialguesses
88+
def initial_guesses(self):
89+
return self._initial_guesses
9790

9891
@property
9992
def user_struct_callback(self):
@@ -284,7 +277,7 @@ def _make_formfunc(self):
284277
retval=objs['err'],
285278
parameters=(objs['snes'], objs['X'], objs['F'], objs['dummyptr'])
286279
)
287-
self._F_efuncs.append(cb)
280+
self._F_efunc = cb
288281
self._efuncs[cb.name] = cb
289282

290283
def _create_formfunc_body(self, body):
@@ -422,7 +415,7 @@ def _make_formrhs(self):
422415
retval=objs['err'],
423416
parameters=(sobjs['callbackdm'], objs['B'])
424417
)
425-
self._b_efuncs.append(cb)
418+
self._b_efunc = cb
426419
self._efuncs[cb.name] = cb
427420

428421
def _create_form_rhs_body(self, body):
@@ -534,7 +527,7 @@ def _make_initialguess(self):
534527
retval=objs['err'],
535528
parameters=(sobjs['callbackdm'], objs['xloc'])
536529
)
537-
self._initialguesses.append(cb)
530+
self._initial_guesses.append(cb)
538531
self._efuncs[cb.name] = cb
539532

540533
def _create_initial_guess_body(self, body):
@@ -660,16 +653,12 @@ def jacobian(self):
660653
@property
661654
def main_matvec_callback(self):
662655
"""
663-
This is the matvec callback associated with the whole Jacobian i.e
656+
This is the matrix-vector callback associated with the whole Jacobian i.e
664657
is set in the main kernel via
665658
`PetscCall(MatShellSetOperation(J,MATOP_MULT,(void (*)(void))MyMatShellMult));`
666659
"""
667660
return self._main_matvec_callback
668661

669-
@property
670-
def main_formfunc_callback(self):
671-
return self._main_formfunc_callback
672-
673662
def _make_core(self):
674663
for sm in self.fielddata.jacobian.nonzero_submatrices:
675664
self._make_matvec(sm, prefix=f'{sm.name}_MatMult')
@@ -757,7 +746,7 @@ def _make_whole_formfunc(self):
757746
retval=objs['err'],
758747
parameters=(objs['snes'], objs['X'], objs['F'], objs['dummyptr'])
759748
)
760-
self._main_formfunc_callback = cb
749+
self._F_efunc = cb
761750
self._efuncs[cb.name] = cb
762751

763752
def _whole_formfunc_body(self, body):
@@ -1310,7 +1299,7 @@ def _setup(self):
13101299
'MatShellSetOperation',
13111300
[sobjs['Jac'], 'MATOP_MULT', MatShellSetOp(matvec.name, void, void)]
13121301
)
1313-
formfunc = self.cbbuilder.main_formfunc_callback
1302+
formfunc = self.cbbuilder._F_efunc
13141303
formfunc_operation = petsc_call(
13151304
'SNESSetFunction',
13161305
[sobjs['snes'], objs['Null'], FormFunctionCallback(formfunc.name, void, void),
@@ -1477,7 +1466,7 @@ def _setup(self):
14771466
'MatShellSetOperation',
14781467
[sobjs['Jac'], 'MATOP_MULT', MatShellSetOp(matvec.name, void, void)]
14791468
)
1480-
formfunc = self.cbbuilder.main_formfunc_callback
1469+
formfunc = self.cbbuilder._F_efunc
14811470
formfunc_operation = petsc_call(
14821471
'SNESSetFunction',
14831472
[sobjs['snes'], objs['Null'], FormFunctionCallback(formfunc.name, void, void),
@@ -1593,16 +1582,16 @@ def _execute_solve(self):
15931582

15941583
struct_assignment = self.timedep.assign_time_iters(sobjs['userctx'])
15951584

1596-
b_efunc = self.cbbuilder.b_efuncs[0]
1585+
b_efunc = self.cbbuilder._b_efunc
15971586

15981587
dmda = sobjs['dmda']
15991588

16001589
rhs_call = petsc_call(b_efunc.name, [sobjs['dmda'], sobjs['bglobal']])
16011590

16021591
vec_place_array = self.timedep.place_array(target)
16031592

1604-
if self.cbbuilder.initialguesses:
1605-
initguess = self.cbbuilder.initialguesses[0]
1593+
if self.cbbuilder.initial_guesses:
1594+
initguess = self.cbbuilder.initial_guesses[0]
16061595
initguess_call = petsc_call(initguess.name, [dmda, sobjs['xlocal']])
16071596
else:
16081597
initguess_call = None

devito/petsc/types/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def _C_ctype(self):
169169
fields = [(i.target.name, dtype_to_ctype(i.dtype)) for i in self.components]
170170
return POINTER(type(self.pname, (Structure,), {'_fields_': fields}))
171171

172-
@cached_property
172+
@property
173173
def symbolic_shape(self):
174174
return self.c0.symbolic_shape
175175

devito/petsc/types/types.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sympy
22

33
from itertools import chain
4+
from functools import cached_property
45

56
from devito.tools import Reconstructable, sympy_mutex, as_tuple, frozendict
67
from devito.tools.dtypes_lowering import dtype_mapper
@@ -241,7 +242,7 @@ def __init__(self, targets, arrays, jacobian=None, residual=None):
241242
self._jacobian = jacobian
242243
self._residual = residual
243244

244-
@property
245+
@cached_property
245246
def space_dimensions(self):
246247
space_dims = {t.space_dimensions for t in self.targets}
247248
if len(space_dims) > 1:
@@ -251,7 +252,7 @@ def space_dimensions(self):
251252
)
252253
return space_dims.pop()
253254

254-
@property
255+
@cached_property
255256
def grid(self):
256257
"""The unique `Grid` associated with all targets."""
257258
grids = [t.grid for t in self.targets]
@@ -261,7 +262,7 @@ def grid(self):
261262
)
262263
return grids.pop()
263264

264-
@property
265+
@cached_property
265266
def space_order(self):
266267
# NOTE: since we use DMDA to create vecs for the coupled solves,
267268
# all fields must have the same space order
@@ -398,7 +399,8 @@ def _build_matvecs(self):
398399
matvecs.extend(
399400
e for e in self._build_matvec_expr(eq) if e is not None
400401
)
401-
matvecs = tuple(sorted(matvecs, key=lambda e: not isinstance(e, EssentialBC)))
402+
key = lambda e: not isinstance(e, EssentialBC)
403+
matvecs = tuple(sorted(matvecs, key=key))
402404

403405
matvecs = self._scale_non_bcs(matvecs)
404406
scdiag = self._compute_scdiag(matvecs)
@@ -425,7 +427,7 @@ class MixedJacobian(BaseJacobian):
425427
"""
426428
def __init__(self, target_exprs, arrays, time_mapper):
427429
super().__init__(arrays=arrays, target=None)
428-
self.targets = tuple(target_exprs.keys())
430+
self.targets = tuple(target_exprs)
429431
self.time_mapper = time_mapper
430432
self._submatrices = []
431433
self._build_blocks(target_exprs)
@@ -443,12 +445,12 @@ def n_submatrices(self):
443445
"""Return the number of submatrix blocks."""
444446
return len(self._submatrices)
445447

446-
@property
448+
@cached_property
447449
def nonzero_submatrices(self):
448450
"""Return SubMatrixBlock objects that have non-empty matvecs."""
449451
return [m for m in self.submatrices if m.matvecs]
450452

451-
@property
453+
@cached_property
452454
def target_scaler_mapper(self):
453455
"""
454456
Map each row target to the scdiag of its corresponding
@@ -467,13 +469,8 @@ def _build_blocks(self, target_exprs):
467469
for i, row_target in enumerate(self.targets):
468470
exprs = target_exprs[row_target]
469471
for j, col_target in enumerate(self.targets):
470-
matvecs = []
471-
for expr in exprs:
472-
matvecs.extend(
473-
e for e in self._build_matvec_expr(
474-
expr, col_target=col_target, row_target=row_target
475-
)
476-
)
472+
473+
matvecs = self._build_submatrix_matvecs(exprs, row_target, col_target)
477474
matvecs = [m for m in matvecs if m is not None]
478475

479476
# Sort to put EssentialBC first if any
@@ -497,6 +494,16 @@ def _build_blocks(self, target_exprs):
497494
)
498495
self._submatrices.append(block)
499496

497+
def _build_submatrix_matvecs(self, exprs, row_target, col_target):
498+
matvecs = []
499+
for expr in exprs:
500+
matvecs.extend(
501+
e for e in self._build_matvec_expr(
502+
expr, col_target=col_target, row_target=row_target
503+
)
504+
)
505+
return matvecs
506+
500507
def get_submatrix(self, row_idx, col_idx):
501508
"""
502509
Return the SubMatrixBlock at (row_idx, col_idx), or None if not found.

devito/symbolics/extraction.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,13 @@ def centre_stencil(expr, target, as_coeff=False):
7979
Extract the centre stencil from an expression. Its coefficient is what
8080
would appear on the diagonal of the matrix system if the matrix were
8181
formed explicitly.
82+
8283
Parameters
8384
----------
84-
expr : the expression to extract the centre stencil from
85-
target : the target function whose centre stencil we want
85+
expr : expr-like
86+
The expression to extract the centre stencil from
87+
target : Function
88+
The target function whose centre stencil we want
8689
as_coeff : bool, optional
8790
If True, return the coefficient of the centre stencil
8891
"""

devito/types/object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,10 @@ class LocalCompositeObject(CompositeObject, LocalType):
252252
__rkwargs__ = ('modifier', 'liveness')
253253

254254
def __init__(self, name, pname, fields, modifier=None, liveness='lazy'):
255-
self.modifier = modifier
256255
dtype = CustomDtype(f"struct {pname}", modifier=modifier)
257256
Object.__init__(self, name, dtype, None)
258257
self._pname = pname
258+
self.modifier = modifier
259259
assert liveness in ['eager', 'lazy']
260260
self._liveness = liveness
261261
self._fields = fields

0 commit comments

Comments
 (0)