Skip to content

Commit a8b8b1f

Browse files
committed
tests: Fix tests due to new destroys and callback for coupled solves to fix memory leaks
1 parent f592204 commit a8b8b1f

5 files changed

Lines changed: 126 additions & 105 deletions

File tree

devito/petsc/iet/callbacks.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,7 @@ def _whole_formfunc_body(self, body):
10661066
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields}
10671067

10681068
return Uxreplace(subs).visit(formfunc_body)
1069-
1069+
10701070
def _create_destroy_submatrix(self):
10711071
# Need a special destroy because each submatrix has a manually
10721072
# PetscMalloc'ed context attached via MatShellSetContext
@@ -1177,12 +1177,14 @@ def _submat_callback_body(self):
11771177

11781178
set_ctx = petsc_call('MatShellSetContext', [objs['block'], objs['subctx']])
11791179

1180+
destroy_cb = self._destroy_submat_callback.name
1181+
11801182
set_destroy_mat_op = petsc_call(
11811183
'MatShellSetOperation',
11821184
[
11831185
objs['block'],
11841186
'MATOP_DESTROY',
1185-
MatShellSetOp(self._destroy_submat_callback.name, VOID._dtype, VOID._dtype),
1187+
MatShellSetOp(destroy_cb, VOID._dtype, VOID._dtype),
11861188
],
11871189
)
11881190

devito/petsc/iet/logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, level, **kwargs):
2222
self.section_mapper = kwargs.get('section_mapper', {})
2323
self.inject_solve = kwargs.get('inject_solve', None)
2424

25-
# TODO: fix the segfault with kspgettype
25+
# TODO: fix the segfault with kspgettype
2626
if level <= PERF:
2727
funcs = [
2828
# KSP specific

devito/petsc/iet/type_builder.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from devito.tools import frozendict
77

88
from devito.petsc.types import (
9-
PetscBundle, DM, Mat, CallbackVec, Vec, KSP, PC, SNES, PetscInt, StartPtr,
10-
PointerIS, PointerDM, VecScatter, JacobianStruct, SubMatrixStruct, CallbackDM,
11-
PetscMPIInt, PetscErrorCode, PointerMat, MatReuse, CallbackPointerDM,
12-
CallbackPointerIS, CallbackMat, DummyArg, NofSubMats, PetscSectionGlobal,
13-
PetscSectionLocal, PetscSF, CallbackPetscInt, CallbackPointerPetscInt, SingleIS
9+
PetscBundle, DM, Mat, Vec, KSP, PC, SNES, PetscInt, StartPtr,
10+
PointerIS, PointerDM, VecScatter, JacobianStruct, SubMatrixStruct,
11+
PetscMPIInt, PetscErrorCode, PointerMat, MatReuse,
12+
DummyArg, NofSubMats, PetscSectionGlobal,
13+
PetscSectionLocal, PetscSF, CallbackPetscInt, PointerPetscInt, SingleIS
1414
)
1515

1616

@@ -46,7 +46,7 @@ def _build(self):
4646
- 'localsize' (PetscInt): The local length of the solution vector.
4747
- 'dmda' (DM): The DMDA object associated with this solve, linked to
4848
the SNES object via `SNESSetDM`.
49-
- 'callbackdm' (CallbackDM): The DM object accessed within callback
49+
- 'callbackdm' (DM): The DM object accessed within callback
5050
functions via `SNESGetDM`.
5151
"""
5252
sreg = self.sregistry
@@ -60,13 +60,13 @@ def _build(self):
6060
'xglobal': Vec(sreg.make_name(prefix='xglobal')),
6161
'xlocal': Vec(sreg.make_name(prefix='xlocal')),
6262
'bglobal': Vec(sreg.make_name(prefix='bglobal')),
63-
'blocal': CallbackVec(sreg.make_name(prefix='blocal')),
63+
'blocal': Vec(sreg.make_name(prefix='blocal'), destroy=False),
6464
'ksp': KSP(sreg.make_name(prefix='ksp')),
6565
'pc': PC(sreg.make_name(prefix='pc')),
6666
'snes': SNES(snes_name),
6767
'localsize': PetscInt(sreg.make_name(prefix='localsize')),
6868
'dmda': DM(sreg.make_name(prefix='da'), dofs=len(targets)),
69-
'callbackdm': CallbackDM(sreg.make_name(prefix='dm')),
69+
'callbackdm': DM(sreg.make_name(prefix='dm'), destroy=False),
7070
'snes_prefix': String(formatted_prefix),
7171
}
7272

@@ -129,9 +129,9 @@ def _extend_build(self, base_dict):
129129
name=f'{name}ctx',
130130
fields=objs['subctx'].fields,
131131
)
132-
base_dict[f'{name}X'] = CallbackVec(f'{name}X')
133-
base_dict[f'{name}Y'] = CallbackVec(f'{name}Y')
134-
base_dict[f'{name}F'] = CallbackVec(f'{name}F')
132+
base_dict[f'{name}X'] = Vec(f'{name}X', destroy=False)
133+
base_dict[f'{name}Y'] = Vec(f'{name}Y', destroy=False)
134+
base_dict[f'{name}F'] = Vec(f'{name}F', destroy=False)
135135

136136
# Bundle objects/metadata required by the coupled residual callback
137137
f_components, x_components = [], []
@@ -178,17 +178,17 @@ def _target_dependent(self, base_dict):
178178
base_dict[f'xlocal{name}'] = Vec(
179179
sreg.make_name(prefix=f'xlocal{name}'), liveness='eager'
180180
)
181-
base_dict[f'Fglobal{name}'] = CallbackVec(
182-
sreg.make_name(prefix=f'Fglobal{name}'), liveness='eager'
181+
base_dict[f'Fglobal{name}'] = Vec(
182+
sreg.make_name(prefix=f'Fglobal{name}'), liveness='eager', destroy=False
183183
)
184-
base_dict[f'Xglobal{name}'] = CallbackVec(
185-
sreg.make_name(prefix=f'Xglobal{name}')
184+
base_dict[f'Xglobal{name}'] = Vec(
185+
sreg.make_name(prefix=f'Xglobal{name}'), destroy=False
186186
)
187187
base_dict[f'xglobal{name}'] = Vec(
188188
sreg.make_name(prefix=f'xglobal{name}')
189189
)
190-
base_dict[f'blocal{name}'] = CallbackVec(
191-
sreg.make_name(prefix=f'blocal{name}'), liveness='eager'
190+
base_dict[f'blocal{name}'] = Vec(
191+
sreg.make_name(prefix=f'blocal{name}'), liveness='eager', destroy=False
192192
)
193193
base_dict[f'bglobal{name}'] = Vec(
194194
sreg.make_name(prefix=f'bglobal{name}')
@@ -220,7 +220,7 @@ def _extend_build(self, base_dict):
220220
base_dict['numBCPtr'] = CallbackPetscInt(
221221
name=sreg.make_name(prefix='numBCPtr'), initvalue=0
222222
)
223-
base_dict['bcPointsArr'] = CallbackPointerPetscInt(
223+
base_dict['bcPointsArr'] = PointerPetscInt(
224224
name=sreg.make_name(prefix='bcPointsArr')
225225
)
226226
base_dict['k_iter'] = PostIncrementIndex(
@@ -246,19 +246,19 @@ def _extend_build(self, base_dict):
246246
objs = frozendict({
247247
'size': PetscMPIInt(name='size'),
248248
'err': PetscErrorCode(name='err'),
249-
'block': CallbackMat('block'),
249+
'block': Mat('block', destroy=False),
250250
'submat_arr': PointerMat(name='submat_arr'),
251251
'subblockrows': PetscInt('subblockrows'),
252252
'subblockcols': PetscInt('subblockcols'),
253253
'rowidx': PetscInt('rowidx'),
254254
'colidx': PetscInt('colidx'),
255255
'J': Mat('J'),
256256
'X': Vec('X'),
257-
'xloc': CallbackVec('xloc'),
257+
'xloc': Vec('xloc', destroy=False),
258258
'Y': Vec('Y'),
259-
'yloc': CallbackVec('yloc'),
259+
'yloc': Vec('yloc', destroy=False),
260260
'F': Vec('F'),
261-
'floc': CallbackVec('floc'),
261+
'floc': Vec('floc', destroy=False),
262262
'B': Vec('B'),
263263
'nfields': PetscInt('nfields'),
264264
'irow': PointerIS(name='irow'),
@@ -270,12 +270,12 @@ def _extend_build(self, base_dict):
270270
'rows': rows,
271271
'cols': cols,
272272
'Subdms': subdms,
273-
'LocalSubdms': CallbackPointerDM(name='subdms'),
273+
'LocalSubdms': PointerDM(name='subdms', destroy=False),
274274
'Fields': fields,
275-
'LocalFields': CallbackPointerIS(name='fields'),
275+
'LocalFields': PointerIS(name='fields', destroy=False),
276276
'Submats': submats,
277277
'ljacctx': JacobianStruct(
278-
fields=[subdms, fields, submats], modifier=' *'
278+
fields=[subdms, fields, submats], modifier=' *', destroy=False
279279
),
280280
'subctx': SubMatrixStruct(fields=[rows, cols]),
281281
'dummyctx': Symbol('lctx'),

0 commit comments

Comments
 (0)