1515from devito .passes .iet .linearization import Stride
1616
1717from devito .petsc .iet .nodes import PETScCallable , MatShellSetOp , petsc_call
18- from devito .petsc .types import DMCast , MainUserStruct , CallbackUserStruct , PetscObjectCast
18+ from devito .petsc .types import (
19+ DMCast , MainUserStruct , CallbackUserStruct , PetscObjectCast , PetscInt
20+ )
1921from devito .petsc .iet .type_builder import objs
2022from devito .petsc .types .macros import petsc_func_begin_user
2123from devito .petsc .types .modes import InsertMode
22- from devito .petsc .types .object import Counter
2324
2425
2526class BaseCallbackBuilder :
@@ -47,7 +48,7 @@ def __init__(self, **kwargs):
4748 self ._F_efunc = None
4849 self ._b_efunc = None
4950 self ._count_bc_efunc = None
50- self ._point_bc_efunc = None
51+ self ._set_point_bc_efunc = None
5152 self ._J_efuncs = []
5253 self ._initial_guess_efuncs = []
5354
@@ -646,43 +647,65 @@ def _create_initial_guess_body(self, body):
646647
647648 def _make_constrain_bc (self ):
648649 """
649- To constrain essential boundary nodes, two additional callbacks are required.
650- This method constructs the corresponding efuncs: `CountBCs` and `SetPointBCs`.
650+ Constructs the `CountBCs` and `SetPointBCs` efuncs. Works for both
651+ single- and multi-field: all fields' expressions are compiled together
652+ (clustering may fuse loops) and a single callback is emitted for each.
651653 """
652- increment_exprs = self .field_data .constrain_bc .increment_exprs
653- point_bc_exprs = self .field_data .constrain_bc .point_bc_exprs
654+ constrain_bc = self .field_data .constrain_bc
654655 sobjs = self .solver_objs
655656
656- # Compile `increment_exprs` into an IET via recursive compilation
657+ # Normalize to dict {target: ConstrainBC}
658+ if isinstance (constrain_bc , dict ):
659+ constrain_bc_dict = constrain_bc
660+ else :
661+ constrain_bc_dict = {self .field_data .target : constrain_bc }
662+ targets = list (constrain_bc_dict .keys ())
663+
664+ all_increment_exprs = [
665+ e for t in targets for e in constrain_bc_dict [t ].increment_exprs
666+ ]
657667 irs0 , _ = self .rcompile (
658- increment_exprs , options = {'mpi' : False }, sregistry = self . sregistry ,
659- concretize_mapper = self .concretize_mapper
668+ all_increment_exprs , options = {'mpi' : False },
669+ sregistry = self . sregistry , concretize_mapper = self .concretize_mapper
660670 )
661- # Compile `point_bc_exprs` into an IET via recursive compilation
671+ all_point_bc_exprs = [
672+ e for t in targets for e in constrain_bc_dict [t ].point_bc_exprs
673+ ]
662674 irs1 , _ = self .rcompile (
663- point_bc_exprs , options = {'mpi' : False }, sregistry = self .sregistry ,
664- concretize_mapper = self .concretize_mapper
665- )
666- count_bc_body = self ._create_count_bc_body (
667- List (body = irs0 .uiet .body )
675+ all_point_bc_exprs , options = {'mpi' : False },
676+ sregistry = self .sregistry , concretize_mapper = self .concretize_mapper
668677 )
678+
679+ pairs = [
680+ (sobjs [f'numBCPtr_{ t .name } ' ], constrain_bc_dict [t ].counter )
681+ for t in targets
682+ ]
683+ count_bc_body = self ._create_count_bc_body (List (body = irs0 .uiet .body ), pairs )
669684 set_point_bc_body = self ._create_set_point_bc_body (
670- List (body = irs1 .uiet .body )
685+ List (body = irs1 .uiet .body ), constrain_bc_dict
671686 )
687+
688+ numBCPtr_params = tuple (sobjs [f'numBCPtr_{ t .name } ' ] for t in targets )
689+ numBC_params = tuple (sobjs [f'numBC_{ t .name } ' ] for t in targets )
690+
672691 cb0 = self ._make_petsc_callable (
673692 'CountBCs' , count_bc_body ,
674- parameters = (sobjs ['callbackdm' ], sobjs [ 'numBCPtr' ])
693+ parameters = (sobjs ['callbackdm' ],) + numBCPtr_params
675694 )
676695 cb1 = self ._make_petsc_callable (
677696 'SetPointBCs' , set_point_bc_body ,
678- parameters = (sobjs ['callbackdm' ], sobjs [ 'numBC' ])
697+ parameters = (sobjs ['callbackdm' ],) + numBC_params
679698 )
680699 self ._count_bc_efunc = cb0
700+ self ._set_point_bc_efunc = cb1
681701 self ._efuncs [cb0 .name ] = cb0
682- self ._point_bc_efunc = cb1
683702 self ._efuncs [cb1 .name ] = cb1
684703
685- def _create_count_bc_body (self , body ):
704+ def _create_count_bc_body (self , body , pairs ):
705+ """
706+ Generic CountBCs body. `pairs` is a list of (numBCPtr, counter) tuples,
707+ one per field. All fields are handled in a single callback body.
708+ """
686709 objs = self .objs
687710 sobjs = self .solver_objs
688711
@@ -698,26 +721,29 @@ def _create_count_bc_body(self, body):
698721 'DMGetApplicationContext' , [dmda , Byref (ctx ._C_symbol )]
699722 )
700723
701- # TODO: change names
702- deref_ptr = DummyExpr (Counter , Deref (sobjs ['numBCPtr' ]))
703- move_ptr = DummyExpr (Deref (sobjs ['numBCPtr' ]), Counter )
724+ deref_ptrs = tuple (
725+ DummyExpr (counter , Deref (numBCPtr )) for numBCPtr , counter in pairs
726+ )
727+ move_ptrs = tuple (
728+ DummyExpr (Deref (numBCPtr ), counter ) for numBCPtr , counter in pairs
729+ )
704730
705- # Force the struct definition to appear at the very start, since
706- # stacks, allocs etc may rely on its information
707731 struct_definition = [Definition (ctx ), dm_get_app_context ]
708732
709- body = body ._rebuild (body .body + ( move_ptr ,) )
733+ body = body ._rebuild (body .body + move_ptrs )
710734
711735 body = self ._make_callable_body (
712- body , standalones = struct_definition , stacks = ( deref_ptr ,)
736+ body , standalones = struct_definition , stacks = deref_ptrs
713737 )
714- # Replace non-function data with pointer to data in struct
715738 subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for
716739 i in fields if not isinstance (i .function , AbstractFunction )}
717740
718741 return Uxreplace (subs ).visit (body )
719742
720- def _create_set_point_bc_body (self , body ):
743+ def _create_set_point_bc_body (self , body , constrain_bc_dict ):
744+ """Single-field SetPointBCs body. `constrain_bc_dict` has one entry."""
745+ (target , constrain_bc ), = constrain_bc_dict .items ()
746+ tname = target .name
721747 linsolve_expr = self .inject_solve .expr .rhs
722748 objs = self .objs
723749 sobjs = self .solver_objs
@@ -739,58 +765,43 @@ def _create_set_point_bc_body(self, body):
739765 )
740766 petsc_obj_comm = Call ('PetscObjectComm' , arguments = [PetscObjectCast (dmda )])
741767 is_create_general = petsc_call (
742- 'ISCreateGeneral' , [petsc_obj_comm , sobjs ['numBC' ], sobjs ['bcPointsArr' ],
743- 'PETSC_OWN_POINTER' , Byref (sobjs ['bcPointsIS' ])]
768+ 'ISCreateGeneral' ,
769+ [petsc_obj_comm , sobjs [f'numBC_{ tname } ' ], sobjs [f'bcPointsArr_{ tname } ' ],
770+ 'PETSC_OWN_POINTER' , Byref (sobjs ['bcPointsIS' ])]
744771 )
745772 malloc_bc_points_arr = petsc_call (
746- 'PetscMalloc1' , [sobjs ['numBC' ], Byref (sobjs ['bcPointsArr' ]._C_symbol )]
773+ 'PetscMalloc1' ,
774+ [sobjs [f'numBC_{ tname } ' ], Byref (sobjs [f'bcPointsArr_{ tname } ' ]._C_symbol )]
747775 )
748-
749776 malloc_bc_points = petsc_call (
750777 'PetscMalloc1' , [1 , Byref (sobjs ['bcPoints' ]._C_symbol )]
751778 )
752-
753779 dummy_expr = DummyExpr (sobjs ['bcPoints' ].indexed [0 ], sobjs ['bcPointsIS' ])
754-
755780 set_point_bc = petsc_call (
756781 'DMDASetPointBC' , [dmda , 1 , sobjs ['bcPoints' ], Null ]
757782 )
758783 body = body ._rebuild (
759784 body = (
760785 (malloc_bc_points_arr ,)
761786 + body .body
762- + (
763- is_create_general ,
764- malloc_bc_points ,
765- dummy_expr ,
766- set_point_bc ,
767- )
787+ + (is_create_general , malloc_bc_points , dummy_expr , set_point_bc ,)
768788 )
769789 )
770- stacks = (
771- dm_get_local_info ,
772- )
773790
774- # Dereference function data in struct
775791 derefs = dereference_funcs (ctx , fields )
776-
777- # Force the struct definition to appear at the very start, since
778- # stacks, allocs etc may rely on its information
779792 standalones = [
780793 Definition (ctx ),
781794 dm_get_app_context ,
782- Definition (sobjs ['k_iter ' ])
795+ Definition (sobjs [f'k_iter_ { tname } ' ])
783796 ]
784-
785797 body = self ._make_callable_body (
786- body , standalones = standalones , stacks = stacks + derefs
798+ body , standalones = standalones , stacks = ( dm_get_local_info ,) + derefs
787799 )
788800
789- # Replace non-function data with pointer to data in struct
790801 subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for
791802 i in fields if not isinstance (i .function , AbstractFunction )}
792-
793- subs [ Counter . _C_symbol ] = sobjs ['bcPointsArr ' ].indexed [sobjs ['k_iter ' ]]
803+ subs [ constrain_bc . counter . _C_symbol ] = \
804+ sobjs [f'bcPointsArr_ { tname } ' ].indexed [sobjs [f'k_iter_ { tname } ' ]]
794805
795806 return Uxreplace (subs ).visit (body )
796807
@@ -846,6 +857,105 @@ def __init__(self, **kwargs):
846857 def submatrices_callback (self ):
847858 return self ._submatrices_callback
848859
860+ def _create_set_point_bc_body (self , body , _constrain_bc_dict ):
861+ return self ._create_set_point_bc_body_coupled (body )
862+
863+ def _create_set_point_bc_body_coupled (self , body ):
864+ """
865+ Combined SetPointBCs body for all target fields. The body is compiled
866+ from all fields' point_bc_exprs together (loops may be fused by
867+ clustering). Per-field counter symbols are substituted with the
868+ corresponding bcPointsArr[k_iter] after assembly.
869+ """
870+ linsolve_expr = self .inject_solve .expr .rhs
871+ objs = self .objs
872+ sobjs = self .solver_objs
873+ constrain_bc = self .field_data .constrain_bc
874+ targets = self .field_data .targets
875+ nfields = len (targets )
876+ dmda = sobjs ['callbackdm' ]
877+ ctx = objs ['dummyctx' ]
878+
879+ dm_get_local_info = petsc_call (
880+ 'DMDAGetLocalInfo' , [dmda , Byref (linsolve_expr .localinfo )]
881+ )
882+ dm_get_app_context = petsc_call (
883+ 'DMGetApplicationContext' , [dmda , Byref (ctx ._C_symbol )]
884+ )
885+ petsc_obj_comm = Call ('PetscObjectComm' , arguments = [PetscObjectCast (dmda )])
886+
887+ body = self .time_dependence .uxreplace_time (body )
888+ fields = get_user_struct_fields (body )
889+ self ._struct_params .extend (fields )
890+
891+ bcPointsIS = sobjs ['bcPointsIS' ]
892+ bcCompsIS = sobjs ['bcCompsIS' ]
893+
894+ # Zero-initialise IS arrays (PetscCalloc1 sets pointers to NULL so
895+ # the automatic ISDestroy cleanup is safe even on early exit)
896+ is_array_mallocs = (
897+ petsc_call ('PetscCalloc1' , [nfields , Byref (bcPointsIS ._C_symbol )]),
898+ petsc_call ('PetscCalloc1' , [nfields , Byref (bcCompsIS ._C_symbol )]),
899+ )
900+ bc_arr_mallocs = tuple (
901+ petsc_call ('PetscMalloc1' ,
902+ [sobjs [f'numBC_{ t .name } ' ],
903+ Byref (sobjs [f'bcPointsArr_{ t .name } ' ]._C_symbol )])
904+ for t in targets
905+ )
906+
907+ is_creates , comp_creates = [], []
908+ for i , t in enumerate (targets ):
909+ tname = t .name
910+ is_creates .append (petsc_call (
911+ 'ISCreateGeneral' ,
912+ [petsc_obj_comm , sobjs [f'numBC_{ tname } ' ],
913+ sobjs [f'bcPointsArr_{ tname } ' ],
914+ 'PETSC_OWN_POINTER' , Byref (bcPointsIS .indexed [i ])]
915+ ))
916+ comp_arr = PetscInt (name = f'comp{ i } ' , initvalue = i )
917+ comp_creates .append (petsc_call (
918+ 'ISCreateGeneral' ,
919+ [petsc_obj_comm , 1 , Byref (comp_arr ),
920+ 'PETSC_COPY_VALUES' , Byref (bcCompsIS .indexed [i ])]
921+ ))
922+
923+ set_point_bc = petsc_call (
924+ 'DMDASetPointBC' , [dmda , nfields , bcPointsIS , bcCompsIS ]
925+ )
926+
927+ body = body ._rebuild (body = (
928+ is_array_mallocs
929+ + bc_arr_mallocs
930+ + body .body
931+ + tuple (is_creates )
932+ + tuple (comp_creates )
933+ + (set_point_bc ,)
934+ ))
935+
936+ derefs = dereference_funcs (ctx , fields )
937+ k_defs = [Definition (sobjs [f'k_iter_{ t .name } ' ]) for t in targets ]
938+ comp_defs = [
939+ Definition (PetscInt (name = f'comp{ i } ' , initvalue = i ))
940+ for i in range (nfields )
941+ ]
942+ standalones = [Definition (ctx ), dm_get_app_context ] + k_defs + comp_defs
943+
944+ body = self ._make_callable_body (
945+ body , standalones = standalones ,
946+ stacks = (dm_get_local_info ,) + derefs
947+ )
948+
949+ # Struct substitutions + per-field counter → bcArr[k_iter]
950+ subs = {i ._C_symbol : FieldFromPointer (i ._C_symbol , ctx ) for
951+ i in fields if not isinstance (i .function , AbstractFunction )}
952+ for t in targets :
953+ tname = t .name
954+ subs [constrain_bc [t ].counter ._C_symbol ] = \
955+ sobjs [f'bcPointsArr_{ tname } ' ].indexed [sobjs [f'k_iter_{ tname } ' ]]
956+
957+ return Uxreplace (subs ).visit (body )
958+
849959 @property
850960 def jacobian (self ):
851961 return self .inject_solve .expr .rhs .field_data .jacobian
@@ -866,6 +976,8 @@ def _make_core(self):
866976 self ._make_options_callback ()
867977 self ._make_whole_matvec ()
868978 self ._make_whole_formfunc ()
979+ if self .field_data .constrain_bc :
980+ self ._make_constrain_bc ()
869981 self ._make_user_struct_efunc ()
870982 self ._create_destroy_submatrix ()
871983 self ._create_submatrices ()
@@ -1138,31 +1250,29 @@ def _submat_callback_body(self):
11381250 )
11391251
11401252 i = Dimension (name = 'i' )
1141- tmpvec = sobjs ['tmpvec' ]
1253+ tvec = sobjs ['tmpvec' ]
11421254
11431255 row_idx = DummyExpr (objs ['rowidx' ], IntDiv (i , objs ['dof' ]))
11441256 col_idx = DummyExpr (objs ['colidx' ], Mod (i , objs ['dof' ]))
11451257
1146- # Query constrained global size from each sub-DM via a temporary Vec.
1147- # For unconstrained sub-DMs this is equivalent to M*N; for constrained
1148- # (BC-excluded) sub-DMs it returns the reduced size automatically.
1258+ # Query global size from each sub-DM via a temporary Vec.
11491259 get_row_vec = petsc_call (
1150- 'DMGetGlobalVector' , [objs ['Subdms' ].indexed [objs ['rowidx' ]], Byref (tmpvec )]
1260+ 'DMGetGlobalVector' , [objs ['Subdms' ].indexed [objs ['rowidx' ]], Byref (tvec )]
11511261 )
11521262 get_row_size = petsc_call (
1153- 'VecGetSize' , [tmpvec , Byref (objs ['subblockrows' ])]
1263+ 'VecGetSize' , [tvec , Byref (objs ['subblockrows' ])]
11541264 )
11551265 restore_row_vec = petsc_call (
1156- 'DMRestoreGlobalVector' , [objs ['Subdms' ].indexed [objs ['rowidx' ]], Byref (tmpvec )]
1266+ 'DMRestoreGlobalVector' , [objs ['Subdms' ].indexed [objs ['rowidx' ]], Byref (tvec )]
11571267 )
11581268 get_col_vec = petsc_call (
1159- 'DMGetGlobalVector' , [objs ['Subdms' ].indexed [objs ['colidx' ]], Byref (tmpvec )]
1269+ 'DMGetGlobalVector' , [objs ['Subdms' ].indexed [objs ['colidx' ]], Byref (tvec )]
11601270 )
11611271 get_col_size = petsc_call (
1162- 'VecGetSize' , [tmpvec , Byref (objs ['subblockcols' ])]
1272+ 'VecGetSize' , [tvec , Byref (objs ['subblockcols' ])]
11631273 )
11641274 restore_col_vec = petsc_call (
1165- 'DMRestoreGlobalVector' , [objs ['Subdms' ].indexed [objs ['colidx' ]], Byref (tmpvec )]
1275+ 'DMRestoreGlobalVector' , [objs ['Subdms' ].indexed [objs ['colidx' ]], Byref (tvec )]
11661276 )
11671277
11681278 mat_create = petsc_call ('MatCreate' , [sobjs ['comm' ], Byref (objs ['block' ])])
0 commit comments