Skip to content

Commit ec819f3

Browse files
committed
compiler: Support constrain_bcs for mixed petscsolves
1 parent 7ce9bd4 commit ec819f3

6 files changed

Lines changed: 248 additions & 85 deletions

File tree

devito/petsc/iet/builder.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,12 +361,15 @@ def _create_dmda_calls(self, dmda):
361361
dm_setup = petsc_call('DMSetUp', [dmda])
362362
dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL'])
363363

364+
targets = self.field_data.targets
364365
count_bcs = petsc_call(
365-
self.callback_builder._count_bc_efunc.name, [dmda, Byref(sobjs['numBC'])]
366+
self.callback_builder._count_bc_efunc.name,
367+
[dmda] + [Byref(sobjs[f'numBC_{t.name}']) for t in targets]
366368
)
367369

368370
set_point_bcs = petsc_call(
369-
self.callback_builder._point_bc_efunc.name, [dmda, sobjs['numBC']]
371+
self.callback_builder._set_point_bc_efunc.name,
372+
[dmda] + [sobjs[f'numBC_{t.name}'] for t in targets]
370373
)
371374

372375
get_local_section = petsc_call(

devito/petsc/iet/callbacks.py

Lines changed: 175 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
from devito.passes.iet.linearization import Stride
1616

1717
from devito.petsc.iet.nodes import PETScCallable, MatShellSetOp, petsc_call
18-
from devito.petsc.types import DMCast, MainUserStruct, CallbackUserStruct, PetscObjectCast
18+
from devito.petsc.types import (
19+
DMCast, MainUserStruct, CallbackUserStruct, PetscObjectCast, PetscInt
20+
)
1921
from devito.petsc.iet.type_builder import objs
2022
from devito.petsc.types.macros import petsc_func_begin_user
2123
from devito.petsc.types.modes import InsertMode
22-
from devito.petsc.types.object import Counter
2324

2425

2526
class BaseCallbackBuilder:
@@ -47,7 +48,7 @@ def __init__(self, **kwargs):
4748
self._F_efunc = None
4849
self._b_efunc = None
4950
self._count_bc_efunc = None
50-
self._point_bc_efunc = None
51+
self._set_point_bc_efunc = None
5152
self._J_efuncs = []
5253
self._initial_guess_efuncs = []
5354

@@ -646,43 +647,65 @@ def _create_initial_guess_body(self, body):
646647

647648
def _make_constrain_bc(self):
648649
"""
649-
To constrain essential boundary nodes, two additional callbacks are required.
650-
This method constructs the corresponding efuncs: `CountBCs` and `SetPointBCs`.
650+
Constructs the `CountBCs` and `SetPointBCs` efuncs. Works for both
651+
single- and multi-field: all fields' expressions are compiled together
652+
(clustering may fuse loops) and a single callback is emitted for each.
651653
"""
652-
increment_exprs = self.field_data.constrain_bc.increment_exprs
653-
point_bc_exprs = self.field_data.constrain_bc.point_bc_exprs
654+
constrain_bc = self.field_data.constrain_bc
654655
sobjs = self.solver_objs
655656

656-
# Compile `increment_exprs` into an IET via recursive compilation
657+
# Normalize to dict {target: ConstrainBC}
658+
if isinstance(constrain_bc, dict):
659+
constrain_bc_dict = constrain_bc
660+
else:
661+
constrain_bc_dict = {self.field_data.target: constrain_bc}
662+
targets = list(constrain_bc_dict.keys())
663+
664+
all_increment_exprs = [
665+
e for t in targets for e in constrain_bc_dict[t].increment_exprs
666+
]
657667
irs0, _ = self.rcompile(
658-
increment_exprs, options={'mpi': False}, sregistry=self.sregistry,
659-
concretize_mapper=self.concretize_mapper
668+
all_increment_exprs, options={'mpi': False},
669+
sregistry=self.sregistry, concretize_mapper=self.concretize_mapper
660670
)
661-
# Compile `point_bc_exprs` into an IET via recursive compilation
671+
all_point_bc_exprs = [
672+
e for t in targets for e in constrain_bc_dict[t].point_bc_exprs
673+
]
662674
irs1, _ = self.rcompile(
663-
point_bc_exprs, options={'mpi': False}, sregistry=self.sregistry,
664-
concretize_mapper=self.concretize_mapper
665-
)
666-
count_bc_body = self._create_count_bc_body(
667-
List(body=irs0.uiet.body)
675+
all_point_bc_exprs, options={'mpi': False},
676+
sregistry=self.sregistry, concretize_mapper=self.concretize_mapper
668677
)
678+
679+
pairs = [
680+
(sobjs[f'numBCPtr_{t.name}'], constrain_bc_dict[t].counter)
681+
for t in targets
682+
]
683+
count_bc_body = self._create_count_bc_body(List(body=irs0.uiet.body), pairs)
669684
set_point_bc_body = self._create_set_point_bc_body(
670-
List(body=irs1.uiet.body)
685+
List(body=irs1.uiet.body), constrain_bc_dict
671686
)
687+
688+
numBCPtr_params = tuple(sobjs[f'numBCPtr_{t.name}'] for t in targets)
689+
numBC_params = tuple(sobjs[f'numBC_{t.name}'] for t in targets)
690+
672691
cb0 = self._make_petsc_callable(
673692
'CountBCs', count_bc_body,
674-
parameters=(sobjs['callbackdm'], sobjs['numBCPtr'])
693+
parameters=(sobjs['callbackdm'],) + numBCPtr_params
675694
)
676695
cb1 = self._make_petsc_callable(
677696
'SetPointBCs', set_point_bc_body,
678-
parameters=(sobjs['callbackdm'], sobjs['numBC'])
697+
parameters=(sobjs['callbackdm'],) + numBC_params
679698
)
680699
self._count_bc_efunc = cb0
700+
self._set_point_bc_efunc = cb1
681701
self._efuncs[cb0.name] = cb0
682-
self._point_bc_efunc = cb1
683702
self._efuncs[cb1.name] = cb1
684703

685-
def _create_count_bc_body(self, body):
704+
def _create_count_bc_body(self, body, pairs):
705+
"""
706+
Generic CountBCs body. `pairs` is a list of (numBCPtr, counter) tuples,
707+
one per field. All fields are handled in a single callback body.
708+
"""
686709
objs = self.objs
687710
sobjs = self.solver_objs
688711

@@ -698,26 +721,29 @@ def _create_count_bc_body(self, body):
698721
'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)]
699722
)
700723

701-
# TODO: change names
702-
deref_ptr = DummyExpr(Counter, Deref(sobjs['numBCPtr']))
703-
move_ptr = DummyExpr(Deref(sobjs['numBCPtr']), Counter)
724+
deref_ptrs = tuple(
725+
DummyExpr(counter, Deref(numBCPtr)) for numBCPtr, counter in pairs
726+
)
727+
move_ptrs = tuple(
728+
DummyExpr(Deref(numBCPtr), counter) for numBCPtr, counter in pairs
729+
)
704730

705-
# Force the struct definition to appear at the very start, since
706-
# stacks, allocs etc may rely on its information
707731
struct_definition = [Definition(ctx), dm_get_app_context]
708732

709-
body = body._rebuild(body.body + (move_ptr,))
733+
body = body._rebuild(body.body + move_ptrs)
710734

711735
body = self._make_callable_body(
712-
body, standalones=struct_definition, stacks=(deref_ptr,)
736+
body, standalones=struct_definition, stacks=deref_ptrs
713737
)
714-
# Replace non-function data with pointer to data in struct
715738
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for
716739
i in fields if not isinstance(i.function, AbstractFunction)}
717740

718741
return Uxreplace(subs).visit(body)
719742

720-
def _create_set_point_bc_body(self, body):
743+
def _create_set_point_bc_body(self, body, constrain_bc_dict):
744+
"""Single-field SetPointBCs body. `constrain_bc_dict` has one entry."""
745+
(target, constrain_bc), = constrain_bc_dict.items()
746+
tname = target.name
721747
linsolve_expr = self.inject_solve.expr.rhs
722748
objs = self.objs
723749
sobjs = self.solver_objs
@@ -739,58 +765,43 @@ def _create_set_point_bc_body(self, body):
739765
)
740766
petsc_obj_comm = Call('PetscObjectComm', arguments=[PetscObjectCast(dmda)])
741767
is_create_general = petsc_call(
742-
'ISCreateGeneral', [petsc_obj_comm, sobjs['numBC'], sobjs['bcPointsArr'],
743-
'PETSC_OWN_POINTER', Byref(sobjs['bcPointsIS'])]
768+
'ISCreateGeneral',
769+
[petsc_obj_comm, sobjs[f'numBC_{tname}'], sobjs[f'bcPointsArr_{tname}'],
770+
'PETSC_OWN_POINTER', Byref(sobjs['bcPointsIS'])]
744771
)
745772
malloc_bc_points_arr = petsc_call(
746-
'PetscMalloc1', [sobjs['numBC'], Byref(sobjs['bcPointsArr']._C_symbol)]
773+
'PetscMalloc1',
774+
[sobjs[f'numBC_{tname}'], Byref(sobjs[f'bcPointsArr_{tname}']._C_symbol)]
747775
)
748-
749776
malloc_bc_points = petsc_call(
750777
'PetscMalloc1', [1, Byref(sobjs['bcPoints']._C_symbol)]
751778
)
752-
753779
dummy_expr = DummyExpr(sobjs['bcPoints'].indexed[0], sobjs['bcPointsIS'])
754-
755780
set_point_bc = petsc_call(
756781
'DMDASetPointBC', [dmda, 1, sobjs['bcPoints'], Null]
757782
)
758783
body = body._rebuild(
759784
body=(
760785
(malloc_bc_points_arr,)
761786
+ body.body
762-
+ (
763-
is_create_general,
764-
malloc_bc_points,
765-
dummy_expr,
766-
set_point_bc,
767-
)
787+
+ (is_create_general, malloc_bc_points, dummy_expr, set_point_bc,)
768788
)
769789
)
770-
stacks = (
771-
dm_get_local_info,
772-
)
773790

774-
# Dereference function data in struct
775791
derefs = dereference_funcs(ctx, fields)
776-
777-
# Force the struct definition to appear at the very start, since
778-
# stacks, allocs etc may rely on its information
779792
standalones = [
780793
Definition(ctx),
781794
dm_get_app_context,
782-
Definition(sobjs['k_iter'])
795+
Definition(sobjs[f'k_iter_{tname}'])
783796
]
784-
785797
body = self._make_callable_body(
786-
body, standalones=standalones, stacks=stacks+derefs
798+
body, standalones=standalones, stacks=(dm_get_local_info,) + derefs
787799
)
788800

789-
# Replace non-function data with pointer to data in struct
790801
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for
791802
i in fields if not isinstance(i.function, AbstractFunction)}
792-
793-
subs[Counter._C_symbol] = sobjs['bcPointsArr'].indexed[sobjs['k_iter']]
803+
subs[constrain_bc.counter._C_symbol] = \
804+
sobjs[f'bcPointsArr_{tname}'].indexed[sobjs[f'k_iter_{tname}']]
794805

795806
return Uxreplace(subs).visit(body)
796807

@@ -846,6 +857,105 @@ def __init__(self, **kwargs):
846857
def submatrices_callback(self):
847858
return self._submatrices_callback
848859

860+
def _create_set_point_bc_body(self, body, _constrain_bc_dict):
861+
return self._create_set_point_bc_body_coupled(body)
862+
863+
def _create_set_point_bc_body_coupled(self, body):
864+
"""
865+
Combined SetPointBCs body for all target fields. The body is compiled
866+
from all fields' point_bc_exprs together (loops may be fused by
867+
clustering). Per-field counter symbols are substituted with the
868+
corresponding bcPointsArr[k_iter] after assembly.
869+
"""
870+
linsolve_expr = self.inject_solve.expr.rhs
871+
objs = self.objs
872+
sobjs = self.solver_objs
873+
constrain_bc = self.field_data.constrain_bc
874+
targets = self.field_data.targets
875+
nfields = len(targets)
876+
dmda = sobjs['callbackdm']
877+
ctx = objs['dummyctx']
878+
879+
dm_get_local_info = petsc_call(
880+
'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)]
881+
)
882+
dm_get_app_context = petsc_call(
883+
'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)]
884+
)
885+
petsc_obj_comm = Call('PetscObjectComm', arguments=[PetscObjectCast(dmda)])
886+
887+
body = self.time_dependence.uxreplace_time(body)
888+
fields = get_user_struct_fields(body)
889+
self._struct_params.extend(fields)
890+
891+
bcPointsIS = sobjs['bcPointsIS']
892+
bcCompsIS = sobjs['bcCompsIS']
893+
894+
# Zero-initialise IS arrays (PetscCalloc1 sets pointers to NULL so
895+
# the automatic ISDestroy cleanup is safe even on early exit)
896+
is_array_mallocs = (
897+
petsc_call('PetscCalloc1', [nfields, Byref(bcPointsIS._C_symbol)]),
898+
petsc_call('PetscCalloc1', [nfields, Byref(bcCompsIS._C_symbol)]),
899+
)
900+
bc_arr_mallocs = tuple(
901+
petsc_call('PetscMalloc1',
902+
[sobjs[f'numBC_{t.name}'],
903+
Byref(sobjs[f'bcPointsArr_{t.name}']._C_symbol)])
904+
for t in targets
905+
)
906+
907+
is_creates, comp_creates = [], []
908+
for i, t in enumerate(targets):
909+
tname = t.name
910+
is_creates.append(petsc_call(
911+
'ISCreateGeneral',
912+
[petsc_obj_comm, sobjs[f'numBC_{tname}'],
913+
sobjs[f'bcPointsArr_{tname}'],
914+
'PETSC_OWN_POINTER', Byref(bcPointsIS.indexed[i])]
915+
))
916+
comp_arr = PetscInt(name=f'comp{i}', initvalue=i)
917+
comp_creates.append(petsc_call(
918+
'ISCreateGeneral',
919+
[petsc_obj_comm, 1, Byref(comp_arr),
920+
'PETSC_COPY_VALUES', Byref(bcCompsIS.indexed[i])]
921+
))
922+
923+
set_point_bc = petsc_call(
924+
'DMDASetPointBC', [dmda, nfields, bcPointsIS, bcCompsIS]
925+
)
926+
927+
body = body._rebuild(body=(
928+
is_array_mallocs
929+
+ bc_arr_mallocs
930+
+ body.body
931+
+ tuple(is_creates)
932+
+ tuple(comp_creates)
933+
+ (set_point_bc,)
934+
))
935+
936+
derefs = dereference_funcs(ctx, fields)
937+
k_defs = [Definition(sobjs[f'k_iter_{t.name}']) for t in targets]
938+
comp_defs = [
939+
Definition(PetscInt(name=f'comp{i}', initvalue=i))
940+
for i in range(nfields)
941+
]
942+
standalones = [Definition(ctx), dm_get_app_context] + k_defs + comp_defs
943+
944+
body = self._make_callable_body(
945+
body, standalones=standalones,
946+
stacks=(dm_get_local_info,) + derefs
947+
)
948+
949+
# Struct substitutions + per-field counter → bcArr[k_iter]
950+
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for
951+
i in fields if not isinstance(i.function, AbstractFunction)}
952+
for t in targets:
953+
tname = t.name
954+
subs[constrain_bc[t].counter._C_symbol] = \
955+
sobjs[f'bcPointsArr_{tname}'].indexed[sobjs[f'k_iter_{tname}']]
956+
957+
return Uxreplace(subs).visit(body)
958+
849959
@property
850960
def jacobian(self):
851961
return self.inject_solve.expr.rhs.field_data.jacobian
@@ -866,6 +976,8 @@ def _make_core(self):
866976
self._make_options_callback()
867977
self._make_whole_matvec()
868978
self._make_whole_formfunc()
979+
if self.field_data.constrain_bc:
980+
self._make_constrain_bc()
869981
self._make_user_struct_efunc()
870982
self._create_destroy_submatrix()
871983
self._create_submatrices()
@@ -1138,31 +1250,29 @@ def _submat_callback_body(self):
11381250
)
11391251

11401252
i = Dimension(name='i')
1141-
tmpvec = sobjs['tmpvec']
1253+
tvec = sobjs['tmpvec']
11421254

11431255
row_idx = DummyExpr(objs['rowidx'], IntDiv(i, objs['dof']))
11441256
col_idx = DummyExpr(objs['colidx'], Mod(i, objs['dof']))
11451257

1146-
# Query constrained global size from each sub-DM via a temporary Vec.
1147-
# For unconstrained sub-DMs this is equivalent to M*N; for constrained
1148-
# (BC-excluded) sub-DMs it returns the reduced size automatically.
1258+
# Query global size from each sub-DM via a temporary Vec.
11491259
get_row_vec = petsc_call(
1150-
'DMGetGlobalVector', [objs['Subdms'].indexed[objs['rowidx']], Byref(tmpvec)]
1260+
'DMGetGlobalVector', [objs['Subdms'].indexed[objs['rowidx']], Byref(tvec)]
11511261
)
11521262
get_row_size = petsc_call(
1153-
'VecGetSize', [tmpvec, Byref(objs['subblockrows'])]
1263+
'VecGetSize', [tvec, Byref(objs['subblockrows'])]
11541264
)
11551265
restore_row_vec = petsc_call(
1156-
'DMRestoreGlobalVector', [objs['Subdms'].indexed[objs['rowidx']], Byref(tmpvec)]
1266+
'DMRestoreGlobalVector', [objs['Subdms'].indexed[objs['rowidx']], Byref(tvec)]
11571267
)
11581268
get_col_vec = petsc_call(
1159-
'DMGetGlobalVector', [objs['Subdms'].indexed[objs['colidx']], Byref(tmpvec)]
1269+
'DMGetGlobalVector', [objs['Subdms'].indexed[objs['colidx']], Byref(tvec)]
11601270
)
11611271
get_col_size = petsc_call(
1162-
'VecGetSize', [tmpvec, Byref(objs['subblockcols'])]
1272+
'VecGetSize', [tvec, Byref(objs['subblockcols'])]
11631273
)
11641274
restore_col_vec = petsc_call(
1165-
'DMRestoreGlobalVector', [objs['Subdms'].indexed[objs['colidx']], Byref(tmpvec)]
1275+
'DMRestoreGlobalVector', [objs['Subdms'].indexed[objs['colidx']], Byref(tvec)]
11661276
)
11671277

11681278
mat_create = petsc_call('MatCreate', [sobjs['comm'], Byref(objs['block'])])

0 commit comments

Comments
 (0)