33from devito .ir .iet import (Call , FindSymbols , List , Uxreplace , CallableBody ,
44 Dereference , DummyExpr , BlankLine , Callable , Iteration ,
55 PointerCast , Definition )
6- from devito .symbolics import (Byref , FieldFromPointer , IntDiv , Deref , Mod , String , Null )
6+ from devito .symbolics import (Byref , FieldFromPointer , IntDiv , Deref , Mod , String , Null , VOID )
77from devito .symbolics .unevaluation import Mul
88from devito .types .basic import AbstractFunction
9- from devito .types import Dimension
9+ from devito .types import Dimension , Temp , TempArray
1010from devito .tools import filter_ordered
1111
12- from devito .petsc .iet .nodes import PETScCallable , MatShellSetOp
13- from devito .petsc .iet .utils import (petsc_call , void , get_user_struct_fields )
12+ from devito .petsc .iet .nodes import PETScCallable , MatShellSetOp , petsc_call
1413from devito .petsc .types import DMCast , MainUserStruct , CallbackUserStruct
1514from devito .petsc .iet .object_builder import objs
1615from devito .petsc .types .macros import petsc_func_begin_user
17- from devito .petsc .types .strings import InsertMode
16+ from devito .petsc .types .modes import InsertMode
1817
1918
2019class BaseCallback :
@@ -226,12 +225,11 @@ def _create_matvec_body(self, body, jacobian):
226225 )
227226
228227 global_to_local_begin = petsc_call (
229- 'DMGlobalToLocalBegin' , [dmda , objs ['X' ],
230- InsertMode .insert_values , xlocal ]
228+ 'DMGlobalToLocalBegin' , [dmda , objs ['X' ], insert_values , xlocal ]
231229 )
232230
233231 global_to_local_end = petsc_call ('DMGlobalToLocalEnd' , [
234- dmda , objs ['X' ], InsertMode . insert_values , xlocal
232+ dmda , objs ['X' ], insert_values , xlocal
235233 ])
236234
237235 dm_get_local_yvec = petsc_call (
@@ -261,11 +259,11 @@ def _create_matvec_body(self, body, jacobian):
261259 )
262260
263261 dm_local_to_global_begin = petsc_call ('DMLocalToGlobalBegin' , [
264- dmda , ylocal , InsertMode . add_values , objs ['Y' ]
262+ dmda , ylocal , add_values , objs ['Y' ]
265263 ])
266264
267265 dm_local_to_global_end = petsc_call ('DMLocalToGlobalEnd' , [
268- dmda , ylocal , InsertMode . add_values , objs ['Y' ]
266+ dmda , ylocal , add_values , objs ['Y' ]
269267 ])
270268
271269 dm_restore_local_xvec = petsc_call (
@@ -373,13 +371,12 @@ def _create_formfunc_body(self, body):
373371 )
374372
375373 global_to_local_begin = petsc_call (
376- 'DMGlobalToLocalBegin' , [dmda , objs ['X' ],
377- InsertMode .insert_values , objs ['xloc' ]]
374+ 'DMGlobalToLocalBegin' , [dmda , objs ['X' ], insert_values , objs ['xloc' ]]
378375 )
379376
380- global_to_local_end = petsc_call ('DMGlobalToLocalEnd' , [
381- dmda , objs ['X' ], InsertMode . insert_values , objs ['xloc' ]
382- ] )
377+ global_to_local_end = petsc_call (
378+ 'DMGlobalToLocalEnd' , [ dmda , objs ['X' ], insert_values , objs ['xloc' ] ]
379+ )
383380
384381 dm_get_local_yvec = petsc_call (
385382 'DMGetLocalVector' , [dmda , Byref (objs ['floc' ])]
@@ -406,11 +403,11 @@ def _create_formfunc_body(self, body):
406403 )
407404
408405 dm_local_to_global_begin = petsc_call ('DMLocalToGlobalBegin' , [
409- dmda , objs ['floc' ], InsertMode . add_values , objs ['F' ]
406+ dmda , objs ['floc' ], add_values , objs ['F' ]
410407 ])
411408
412409 dm_local_to_global_end = petsc_call ('DMLocalToGlobalEnd' , [
413- dmda , objs ['floc' ], InsertMode . add_values , objs ['F' ]
410+ dmda , objs ['floc' ], add_values , objs ['F' ]
414411 ])
415412
416413 dm_restore_local_xvec = petsc_call (
@@ -490,14 +487,12 @@ def _create_form_rhs_body(self, body):
490487 )
491488
492489 dm_global_to_local_begin = petsc_call (
493- 'DMGlobalToLocalBegin' , [dmda , objs ['B' ],
494- InsertMode .insert_values , sobjs ['blocal' ]]
490+ 'DMGlobalToLocalBegin' , [dmda , objs ['B' ], insert_values , sobjs ['blocal' ]]
495491 )
496492
497- dm_global_to_local_end = petsc_call ('DMGlobalToLocalEnd' , [
498- dmda , objs ['B' ], InsertMode .insert_values ,
499- sobjs ['blocal' ]
500- ])
493+ dm_global_to_local_end = petsc_call (
494+ 'DMGlobalToLocalEnd' , [dmda , objs ['B' ], insert_values , sobjs ['blocal' ]]
495+ )
501496
502497 b_arr = self .field_data .arrays [target ]['b' ]
503498
@@ -519,13 +514,11 @@ def _create_form_rhs_body(self, body):
519514 )
520515
521516 dm_local_to_global_begin = petsc_call ('DMLocalToGlobalBegin' , [
522- dmda , sobjs ['blocal' ], InsertMode .insert_values ,
523- objs ['B' ]
517+ dmda , sobjs ['blocal' ], insert_values , objs ['B' ]
524518 ])
525519
526520 dm_local_to_global_end = petsc_call ('DMLocalToGlobalEnd' , [
527- dmda , sobjs ['blocal' ], InsertMode .insert_values ,
528- objs ['B' ]
521+ dmda , sobjs ['blocal' ], insert_values , objs ['B' ]
529522 ])
530523
531524 vec_restore_array = petsc_call (
@@ -822,13 +815,12 @@ def _whole_formfunc_body(self, body):
822815 'DMGetLocalVector' , [dmda , Byref (objs ['xloc' ])]
823816 )
824817
825- global_to_local_begin = petsc_call (
826- 'DMGlobalToLocalBegin' , [dmda , objs ['X' ],
827- InsertMode .insert_values , objs ['xloc' ]]
828- )
818+ global_to_local_begin = petsc_call ('DMGlobalToLocalBegin' , [
819+ dmda , objs ['X' ], insert_values , objs ['xloc' ]
820+ ])
829821
830822 global_to_local_end = petsc_call ('DMGlobalToLocalEnd' , [
831- dmda , objs ['X' ], InsertMode . insert_values , objs ['xloc' ]
823+ dmda , objs ['X' ], insert_values , objs ['xloc' ]
832824 ])
833825
834826 dm_get_local_yvec = petsc_call (
@@ -856,11 +848,11 @@ def _whole_formfunc_body(self, body):
856848 )
857849
858850 dm_local_to_global_begin = petsc_call ('DMLocalToGlobalBegin' , [
859- dmda , objs ['floc' ], InsertMode . add_values , objs ['F' ]
851+ dmda , objs ['floc' ], add_values , objs ['F' ]
860852 ])
861853
862854 dm_local_to_global_end = petsc_call ('DMLocalToGlobalEnd' , [
863- dmda , objs ['floc' ], InsertMode . add_values , objs ['F' ]
855+ dmda , objs ['floc' ], add_values , objs ['F' ]
864856 ])
865857
866858 dm_restore_local_xvec = petsc_call (
@@ -1033,7 +1025,7 @@ def _submat_callback_body(self):
10331025 [
10341026 objs ['submat_arr' ].indexed [sb .linear_idx ],
10351027 'MATOP_MULT' ,
1036- MatShellSetOp (matvec_lookup [sb .name ].name , void , void ),
1028+ MatShellSetOp (matvec_lookup [sb .name ].name , VOID . _dtype , VOID . _dtype ),
10371029 ],
10381030 )
10391031 for sb in nonzero_submats if sb .name in matvec_lookup
@@ -1120,3 +1112,18 @@ def zero_vector(vec):
11201112 Set all entries of a PETSc vector to zero.
11211113 """
11221114 return petsc_call ('VecSet' , [vec , 0.0 ])
1115+
1116+
1117+ def get_user_struct_fields (iet ):
1118+ fields = [f .function for f in FindSymbols ('basics' ).visit (iet )]
1119+ from devito .types .basic import LocalType
1120+ avoid = (Temp , TempArray , LocalType )
1121+ fields = [f for f in fields if not isinstance (f .function , avoid )]
1122+ fields = [
1123+ f for f in fields if not (f .is_Dimension and not (f .is_Time or f .is_Modulo ))
1124+ ]
1125+ return fields
1126+
1127+
1128+ insert_values = InsertMode .insert_values
1129+ add_values = InsertMode .add_values
0 commit comments