1111from devito .types .basic import AbstractFunction
1212from devito .types import Dimension , Temp , TempArray
1313from devito .tools import filter_ordered
14+ from devito .passes .iet .linearization import linearize_accesses
1415
1516from devito .petsc .iet .nodes import PETScCallable , MatShellSetOp , petsc_call
1617from devito .petsc .types import DMCast , MainUserStruct , CallbackUserStruct
@@ -28,6 +29,7 @@ def __init__(self, **kwargs):
2829
2930 self .rcompile = kwargs .get ('rcompile' , None )
3031 self .sregistry = kwargs .get ('sregistry' , None )
32+ self .options = kwargs .get ('options' , {})
3133 self .concretize_mapper = kwargs .get ('concretize_mapper' , {})
3234 self .time_dependence = kwargs .get ('time_dependence' )
3335 self .objs = kwargs .get ('objs' )
@@ -753,6 +755,25 @@ def _create_set_point_bc_body(self, body):
753755 'DMDAGetLocalInfo' , [dmda , Byref (linsolve_expr .localinfo )]
754756 )
755757
758+ import numpy as np
759+ if self .options ['index-mode' ] == 'int32' :
760+ dtype = np .int32
761+ else :
762+ dtype = np .int64
763+ from devito .passes .iet .linearization import Tracker
764+
765+ tracker = Tracker ('basic' , dtype , self .sregistry )
766+
767+ key = lambda f : f .name == 'bc'
768+ body = linearize_accesses (body , key0 = key , tracker = tracker )
769+
770+ # will only be findexeds 'indexeds'
771+ findexeds = FindSymbols ('indexeds' ).visit (body )
772+ mapper_findexeds = {i : i .linear_index for i in findexeds }
773+
774+ # from IPython import embed; embed()
775+
776+ # findexeds =
756777 body = self .time_dependence .uxreplace_time (body )
757778
758779 fields = get_user_struct_fields (body )
@@ -762,7 +783,12 @@ def _create_set_point_bc_body(self, body):
762783 'DMGetApplicationContext' , [dmda , Byref (ctx ._C_symbol )]
763784 )
764785
765- # body = body._rebuild(body=body.body)
786+ comm = sobjs ['comm' ]
787+ is_create_general = petsc_call (
788+ 'ISCreateGeneral' , [comm , sobjs ['numBC' ], sobjs ['bcPointsArr' ], 'PETSC_OWN_POINTER' ]
789+ )
790+
791+ body = body ._rebuild (body = body .body + (is_create_general ,))
766792
767793 stacks = (
768794 dm_get_local_info ,
@@ -786,7 +812,10 @@ def _create_set_point_bc_body(self, body):
786812
787813 subs [Counter ._C_symbol ] = sobjs ['bcPointsArr' ].indexed [sobjs ['k_iter' ]]
788814
789- return Uxreplace (subs ).visit (body )
815+ body = Uxreplace (mapper_findexeds ).visit (body )
816+ body = Uxreplace (subs ).visit (body )
817+
818+ return body
790819
791820 def _make_user_struct_callback (self ):
792821 """
0 commit comments