1717from devito .petsc .iet .type_builder import objs
1818from devito .petsc .types .macros import petsc_func_begin_user
1919from devito .petsc .types .modes import InsertMode
20+ from devito .petsc .types .object import TempSymb
2021
2122
2223class BaseCallbackBuilder :
@@ -44,7 +45,8 @@ def __init__(self, **kwargs):
4445 self ._user_struct_callback = None
4546 self ._F_efunc = None
4647 self ._b_efunc = None
47- self ._constrain_bc_efunc = None
48+ self ._count_bc_efunc = None
49+ self ._point_bc_efunc = None
4850
4951 self ._J_efuncs = []
5052 # TODO: isn't there only ever one of these per solver so why is it a list?
@@ -649,25 +651,39 @@ def _create_initial_guess_body(self, body):
649651 return Uxreplace (subs ).visit (body )
650652
651653 def _make_constrain_bc (self ):
652- exprs = self .field_data .constrain_bc .exprs
654+ increment_exprs = self .field_data .constrain_bc .increment_exprs
655+ point_bc_exprs = self .field_data .constrain_bc .point_bc_exprs
653656 sobjs = self .solver_objs
654657 objs = self .objs
655658
656659 # Compile constrain `eqns` into an IET via recursive compilation
657- irs , _ = self .rcompile (
658- exprs , options = {'mpi' : False }, sregistry = self .sregistry ,
660+ irs0 , _ = self .rcompile (
661+ increment_exprs , options = {'mpi' : False }, sregistry = self .sregistry ,
659662 concretize_mapper = self .concretize_mapper
660663 )
661- body = self ._create_constrain_bc_body (
662- List (body = irs .uiet .body )
664+ # Compile constrain `eqns` into an IET via recursive compilation
665+ irs1 , _ = self .rcompile (
666+ point_bc_exprs , options = {'mpi' : False }, sregistry = self .sregistry ,
667+ concretize_mapper = self .concretize_mapper
663668 )
664- cb = self ._make_petsc_callable (
665- 'ConstrainBCs' , body , parameters = ( sobjs [ 'callbackdm' ], )
669+ count_bc_body = self ._create_count_bc_body (
670+ List ( body = irs0 . uiet . body )
666671 )
667- self ._constrain_bc_efunc = cb
668- self ._efuncs [cb .name ] = cb
672+ set_point_bc_body = self ._create_set_point_bc_body (
673+ List (body = irs1 .uiet .body )
674+ )
675+ cb0 = self ._make_petsc_callable (
676+ 'CountBCs' , count_bc_body , parameters = (sobjs ['callbackdm' ], sobjs ['numBCPtr' ])
677+ )
678+ cb1 = self ._make_petsc_callable (
679+ 'SetPointBCs' , set_point_bc_body , parameters = (sobjs ['callbackdm' ], sobjs ['numBC' ])
680+ )
681+ self ._count_bc_efunc = cb0
682+ self ._efuncs [cb0 .name ] = cb0
683+ self ._point_bc_efunc = cb1
684+ self ._efuncs [cb1 .name ] = cb1
669685
670- def _create_constrain_bc_body (self , body ):
686+ def _create_count_bc_body (self , body ):
671687 linsolve_expr = self .inject_solve .expr .rhs
672688 objs = self .objs
673689 sobjs = self .solver_objs
@@ -676,11 +692,6 @@ def _create_constrain_bc_body(self, body):
676692 dmda = sobjs ['callbackdm' ]
677693 ctx = objs ['dummyctx' ]
678694
679- x_arr = self .field_data .arrays [target ]['x' ]
680-
681- vec_get_array = petsc_call (
682- 'VecGetArray' , [objs ['xloc' ], Byref (x_arr ._C_symbol )]
683- )
684695
685696 dm_get_local_info = petsc_call (
686697 'DMDAGetLocalInfo' , [dmda , Byref (linsolve_expr .localinfo )]
@@ -695,15 +706,14 @@ def _create_constrain_bc_body(self, body):
695706 'DMGetApplicationContext' , [dmda , Byref (ctx ._C_symbol )]
696707 )
697708
698- vec_restore_array = petsc_call (
699- 'VecRestoreArray' , [objs ['xloc' ], Byref (x_arr ._C_symbol )]
700- )
709+ # dummyexpr = Dereference(self.target, sobjs['numBCPtr'])
701710
702- body = body ._rebuild (body = body .body + (vec_restore_array ,))
711+ # body = body._rebuild(body=body.body)
712+
713+ body = body ._rebuild (body .body )
703714
704715 stacks = (
705- vec_get_array ,
706- dm_get_local_info
716+ dm_get_local_info ,
707717 )
708718
709719 # Dereference function data in struct
@@ -720,6 +730,62 @@ def _create_constrain_bc_body(self, body):
720730 # Replace non-function data with pointer to data in struct
721731 subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for
722732 i in fields if not isinstance (i .function , AbstractFunction )}
733+
734+ # subs[]
735+ # subs[self.target] = sobjs['numBC']
736+
737+ subs [TempSymb ._C_symbol ] = sobjs ['numBCPtr' ]._C_symbol
738+
739+ # from IPython import embed; embed()
740+
741+ return Uxreplace (subs ).visit (body )
742+
743+ def _create_set_point_bc_body (self , body ):
744+ linsolve_expr = self .inject_solve .expr .rhs
745+ objs = self .objs
746+ sobjs = self .solver_objs
747+ target = self .target
748+
749+ dmda = sobjs ['callbackdm' ]
750+ ctx = objs ['dummyctx' ]
751+
752+
753+ dm_get_local_info = petsc_call (
754+ 'DMDAGetLocalInfo' , [dmda , Byref (linsolve_expr .localinfo )]
755+ )
756+
757+ body = self .time_dependence .uxreplace_time (body )
758+
759+ fields = get_user_struct_fields (body )
760+ self ._struct_params .extend (fields )
761+
762+ dm_get_app_context = petsc_call (
763+ 'DMGetApplicationContext' , [dmda , Byref (ctx ._C_symbol )]
764+ )
765+
766+ # body = body._rebuild(body=body.body)
767+
768+ stacks = (
769+ dm_get_local_info ,
770+ )
771+
772+ # Dereference function data in struct
773+ derefs = dereference_funcs (ctx , fields )
774+
775+ # Force the struct definition to appear at the very start, since
776+ # stacks, allocs etc may rely on its information
777+ struct_definition = [Definition (ctx ), dm_get_app_context , Definition (sobjs ['k_iter' ])]
778+
779+ body = self ._make_callable_body (
780+ body , standalones = struct_definition , stacks = stacks + derefs
781+ )
782+
783+ # Replace non-function data with pointer to data in struct
784+ subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for
785+ i in fields if not isinstance (i .function , AbstractFunction )}
786+
787+
788+ subs [TempSymb ._C_symbol ] = sobjs ['bcPointsArr' ].indexed [sobjs ['k_iter' ]]
723789
724790 return Uxreplace (subs ).visit (body )
725791
0 commit comments