Skip to content

Commit b731340

Browse files
committed
compiler: Add destroy matop ctx for submats but need to fix when things are destroyed or not..
1 parent e3e4eb6 commit b731340

4 files changed

Lines changed: 75 additions & 9 deletions

File tree

devito/petsc/iet/callbacks.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,7 @@ def _uxreplace_efuncs(self):
839839
class CoupledCallbackBuilder(BaseCallbackBuilder):
840840
def __init__(self, **kwargs):
841841
self._submatrices_callback = None
842+
self._destroy_submat_callback = None
842843
super().__init__(**kwargs)
843844

844845
@property
@@ -866,6 +867,7 @@ def _make_core(self):
866867
self._make_whole_matvec()
867868
self._make_whole_formfunc()
868869
self._make_user_struct_efunc()
870+
self._create_destroy_submatrix()
869871
self._create_submatrices()
870872
self._efuncs['PopulateMatContext'] = self.objs['dummyefunc']
871873

@@ -1064,6 +1066,28 @@ def _whole_formfunc_body(self, body):
10641066
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields}
10651067

10661068
return Uxreplace(subs).visit(formfunc_body)
1069+
1070+
def _create_destroy_submatrix(self):
1071+
# Need a special destroy because each submatrix has a manually
1072+
# PetscMalloc'ed context attached via MatShellSetContext
1073+
1074+
objs = self.objs
1075+
1076+
get_ctx = petsc_call(
1077+
'MatShellGetContext', [objs['J'], Byref(objs['subctx'])]
1078+
)
1079+
1080+
free_ctx = petsc_call(
1081+
'PetscFree', [objs['subctx']]
1082+
)
1083+
1084+
body = self._make_callable_body((get_ctx, free_ctx))
1085+
1086+
cb = self._make_petsc_callable(
1087+
'DestroySubMatrixCtx', body, parameters=(objs['J']))
1088+
1089+
self._destroy_submat_callback = cb
1090+
self._efuncs[cb.name] = cb
10671091

10681092
def _create_submatrices(self):
10691093
body = self._submat_callback_body()
@@ -1153,6 +1177,15 @@ def _submat_callback_body(self):
11531177

11541178
set_ctx = petsc_call('MatShellSetContext', [objs['block'], objs['subctx']])
11551179

1180+
set_destroy_mat_op = petsc_call(
1181+
'MatShellSetOperation',
1182+
[
1183+
objs['block'],
1184+
'MATOP_DESTROY',
1185+
MatShellSetOp(self._destroy_submat_callback.name, VOID._dtype, VOID._dtype),
1186+
],
1187+
)
1188+
11561189
mat_setup = petsc_call('MatSetUp', [objs['block']])
11571190

11581191
assign_block = DummyExpr(objs['submat_arr'].indexed[i], objs['block'])
@@ -1169,6 +1202,7 @@ def _submat_callback_body(self):
11691202
dm_set_ctx,
11701203
matset_dm,
11711204
set_ctx,
1205+
set_destroy_mat_op,
11721206
mat_setup,
11731207
assign_block
11741208
)

devito/petsc/iet/logging.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(self, level, **kwargs):
2222
self.section_mapper = kwargs.get('section_mapper', {})
2323
self.inject_solve = kwargs.get('inject_solve', None)
2424

25+
# TODO: fix the segfault with kspgettype
2526
if level <= PERF:
2627
funcs = [
2728
# KSP specific

devito/petsc/iet/type_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _extend_build(self, base_dict):
119119

120120
base_dict['jacctx'] = JacobianStruct(
121121
name=sreg.make_name(prefix=objs['ljacctx'].name),
122-
fields=objs['ljacctx'].fields,
122+
fields=objs['ljacctx'].fields, no_of_submats=2
123123
)
124124

125125
for sm in submatrices:
@@ -175,7 +175,7 @@ def _target_dependent(self, base_dict):
175175
base_dict[f'{name}_ptr'] = StartPtr(
176176
sreg.make_name(prefix=f'{name}_ptr'), t.dtype
177177
)
178-
base_dict[f'xlocal{name}'] = CallbackVec(
178+
base_dict[f'xlocal{name}'] = Vec(
179179
sreg.make_name(prefix=f'xlocal{name}'), liveness='eager'
180180
)
181181
base_dict[f'Fglobal{name}'] = CallbackVec(

devito/petsc/types/object.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ class MatReuse(PetscObject):
201201
class VecScatter(PetscObject):
202202
dtype = CustomDtype('VecScatter')
203203

204+
@property
205+
def _C_free(self):
206+
return petsc_call('VecScatterDestroy', [Byref(self.function)])
207+
204208

205209
class StartPtr(PetscObject):
206210
def __init__(self, name, dtype):
@@ -228,7 +232,7 @@ class PetscSF(PetscObject):
228232
dtype = CustomDtype('PetscSF')
229233

230234

231-
class PETScStruct(LocalCompositeObject):
235+
class PETScStruct(LocalCompositeObject, PetscMixin):
232236

233237
@property
234238
def time_dim_fields(self):
@@ -268,10 +272,35 @@ def parent(self):
268272

269273
class JacobianStruct(PETScStruct):
270274
def __init__(self, name='jctx', pname='JacobianCtx', fields=None,
271-
modifier='', liveness='lazy'):
275+
modifier='', liveness='lazy', no_of_submats=0):
272276
super().__init__(name, pname, fields, modifier, liveness)
277+
self._no_of_submats = no_of_submats
273278
_C_modifier = None
274279

280+
@property
281+
def _C_free(self):
282+
# from IPython import embed; embed()
283+
submats = [i for i in self.fields if isinstance(i, PointerMat)]
284+
submats = submats[0]
285+
# from IPython import embed; embed()
286+
from devito.symbolics import FieldFromComposite
287+
destroy_call = [petsc_call('MatDestroy', [Byref(FieldFromComposite(submats.indexed[i], self.function))]) for i in range(self._no_of_submats)]
288+
destroy_call.append(petsc_call('PetscFree', [Byref(FieldFromComposite(submats.base, self.function))]))
289+
return destroy_call
290+
# return petsc_call('PetscFree', [Byref(self.function)])
291+
292+
293+
# @property
294+
# def _C_free(self):
295+
296+
# submats = [i for i in self.fields if isinstance(i, PointerMat)]
297+
# destroy_calls = [
298+
# petsc_call('MatDestroy', [Byref(self.indexify().subs({self.dim: i}))])
299+
# for i in range(self._no_of_submats)
300+
# ]
301+
# destroy_calls.append(petsc_call('PetscFree', [self.function]))
302+
# return destroy_calls
303+
275304

276305
class SubMatrixStruct(PETScStruct):
277306
def __init__(self, name='subctx', pname='SubMatrixCtx', fields=None,
@@ -385,9 +414,11 @@ class NofSubMats(Scalar, LocalType):
385414

386415
FREE_PRIORITY = {
387416
PETScArrayObject: 0,
388-
Vec: 1,
389-
Mat: 2,
390-
SNES: 3,
391-
PetscSectionGlobal: 4,
392-
DM: 5,
417+
JacobianStruct: 1,
418+
VecScatter: 2,
419+
Vec: 3,
420+
Mat: 4,
421+
SNES: 5,
422+
PetscSectionGlobal: 6,
423+
DM: 7,
393424
}

0 commit comments

Comments
 (0)