1212from devito .types .misc import PostIncrementIndex
1313from devito .types import Dimension , Temp , TempArray
1414from devito .tools import filter_ordered
15- from devito .passes .iet .linearization import linearize_accesses
15+ from devito .passes .iet .linearization import linearize_accesses , Stride
1616
1717from devito .petsc .iet .nodes import PETScCallable , MatShellSetOp , petsc_call
1818from devito .petsc .types import DMCast , MainUserStruct , CallbackUserStruct
@@ -771,6 +771,7 @@ def _create_set_point_bc_body(self, body):
771771 findexeds = FindSymbols ('indexeds' ).visit (body )
772772 mapper_findexeds = {i : i .linear_index for i in findexeds }
773773
774+
774775 # from IPython import embed; embed()
775776
776777 # findexeds =
@@ -791,7 +792,13 @@ def _create_set_point_bc_body(self, body):
791792 malloc = petsc_call (
792793 'PetscMalloc1' , [1 , sobjs ['bcPoints' ]]
793794 )
794- body = body ._rebuild (body = body .body + (is_create_general ,malloc ))
795+
796+ dummy_expr = DummyExpr (sobjs ['bcPoints' ].indexed [0 ], sobjs ['bcPointsIS' ])
797+
798+ set_point_bc = petsc_call (
799+ 'DMDASetPointBC' , [dmda , 1 , sobjs ['bcPoints' ], Null ]
800+ )
801+ body = body ._rebuild (body = body .body + (is_create_general , malloc , dummy_expr , set_point_bc ))
795802
796803 stacks = (
797804 dm_get_local_info ,
@@ -811,7 +818,6 @@ def _create_set_point_bc_body(self, body):
811818 # Replace non-function data with pointer to data in struct
812819 subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for
813820 i in fields if not isinstance (i .function , AbstractFunction )}
814-
815821
816822 subs [Counter ._C_symbol ] = sobjs ['bcPointsArr' ].indexed [sobjs ['k_iter' ]]
817823
@@ -1303,7 +1309,7 @@ def zero_vector(vec):
13031309def get_user_struct_fields (iet ):
13041310 fields = [f .function for f in FindSymbols ('basics' ).visit (iet )]
13051311 from devito .types .basic import LocalType
1306- avoid = (Temp , TempArray , LocalType , PostIncrementIndex )
1312+ avoid = (Temp , TempArray , LocalType , PostIncrementIndex , Stride )
13071313 fields = [f for f in fields if not isinstance (f .function , avoid )]
13081314 fields = [
13091315 f for f in fields if not (f .is_Dimension and not (f .is_Time or f .is_Modulo ))
0 commit comments