Skip to content

Commit 3089dc5

Browse files
committed
compiler: Some progress with setpointbc petsc callback
1 parent 5cce561 commit 3089dc5

6 files changed

Lines changed: 17 additions & 15 deletions

File tree

devito/ir/cgen/printer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from devito.arch.compiler import AOMPCompiler
1818
from devito.symbolics.inspection import has_integer_args, sympy_dtype
1919
from devito.symbolics.queries import q_leaf
20-
from devito.types.basic import AbstractFunction, PostIncrementIndex
20+
from devito.types.basic import AbstractFunction
21+
from devito.types.misc import PostIncrementIndex
2122
from devito.tools import ctypes_to_cstr, dtype_to_ctype, ctypes_vector_mapper
2223

2324
__all__ = ['BasePrinter', 'ccode']

devito/petsc/iet/callbacks.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
)
1010
from devito.symbolics.unevaluation import Mul
1111
from devito.types.basic import AbstractFunction
12+
from devito.types.misc import PostIncrementIndex
1213
from devito.types import Dimension, Temp, TempArray
1314
from devito.tools import filter_ordered
1415
from devito.passes.iet.linearization import linearize_accesses
@@ -694,7 +695,6 @@ def _create_count_bc_body(self, body):
694695
dmda = sobjs['callbackdm']
695696
ctx = objs['dummyctx']
696697

697-
698698
body = self.time_dependence.uxreplace_time(body)
699699

700700
fields = get_user_struct_fields(body)
@@ -785,10 +785,13 @@ def _create_set_point_bc_body(self, body):
785785

786786
comm = sobjs['comm']
787787
is_create_general = petsc_call(
788-
'ISCreateGeneral', [comm, sobjs['numBC'], sobjs['bcPointsArr'], 'PETSC_OWN_POINTER']
788+
'ISCreateGeneral', [comm, sobjs['numBC'], sobjs['bcPointsArr'], 'PETSC_OWN_POINTER', Byref(sobjs['bcPointsIS'])]
789789
)
790790

791-
body = body._rebuild(body=body.body + (is_create_general,))
791+
malloc = petsc_call(
792+
'PetscMalloc1', [1, sobjs['bcPoints']]
793+
)
794+
body = body._rebuild(body=body.body + (is_create_general,malloc))
792795

793796
stacks = (
794797
dm_get_local_info,
@@ -1300,7 +1303,7 @@ def zero_vector(vec):
13001303
def get_user_struct_fields(iet):
13011304
fields = [f.function for f in FindSymbols('basics').visit(iet)]
13021305
from devito.types.basic import LocalType
1303-
avoid = (Temp, TempArray, LocalType)
1306+
avoid = (Temp, TempArray, LocalType, PostIncrementIndex)
13041307
fields = [f for f in fields if not isinstance(f.function, avoid)]
13051308
fields = [
13061309
f for f in fields if not (f.is_Dimension and not (f.is_Time or f.is_Modulo))

devito/petsc/iet/type_builder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
from devito.symbolics import String
44
from devito.types import Symbol
5+
from devito.types.misc import PostIncrementIndex
56
from devito.tools import frozendict
67

78
from devito.petsc.types import (
89
PetscBundle, DM, Mat, CallbackVec, Vec, KSP, PC, SNES, PetscInt, StartPtr,
910
PointerIS, PointerDM, VecScatter, JacobianStruct, SubMatrixStruct, CallbackDM,
1011
PetscMPIInt, PetscErrorCode, PointerMat, MatReuse, CallbackPointerDM,
1112
CallbackPointerIS, CallbackMat, DummyArg, NofSubMats, PetscSectionGlobal, PetscSectionLocal, PetscSF,
12-
PetscIntPtr, CallbackPetscInt, CallbackPointerPetscInt, PostIncrementIndex
13+
PetscIntPtr, CallbackPetscInt, CallbackPointerPetscInt, SingleIS
1314
)
1415

1516

@@ -225,6 +226,9 @@ def _extend_build(self, base_dict):
225226
base_dict['k_iter'] = PostIncrementIndex(
226227
name='k_iter', initvalue=0
227228
)
229+
# change names etc..
230+
base_dict['bcPointsIS'] = SingleIS(name='bcPointsIS')
231+
base_dict['bcPoints'] = PointerIS(name='bcPoints')
228232
return base_dict
229233

230234

devito/petsc/types/object.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
CustomDimension, Scalar
88
)
99
from devito.symbolics import Byref, cast
10-
from devito.types.basic import DataSymbol, LocalType, PostIncrementIndex
10+
from devito.types.basic import DataSymbol, LocalType
1111

1212
from devito.petsc.iet.nodes import petsc_call
1313

@@ -320,10 +320,6 @@ def dtype(self):
320320
return CustomDtype('IS', modifier=' *')
321321

322322

323-
class PetscPostIncrementIndex(PostIncrementIndex):
324-
pass
325-
326-
327323
class CallbackPointerPetscInt(PETScArrayObject):
328324
"""
329325
"""

devito/types/basic.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,7 +1927,3 @@ def _mem_internal_lazy(self):
19271927
to impose pass-by-reference semantics.
19281928
"""
19291929
_C_modifier = None
1930-
1931-
1932-
class PostIncrementIndex(DataSymbol):
1933-
pass

devito/types/misc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def linear_index(self):
172172

173173

174174
# the special postindex type sould live in this file i think
175+
class PostIncrementIndex(Symbol):
176+
pass
175177

176178

177179
class Global(Symbol):

0 commit comments

Comments
 (0)