Skip to content

Commit 154a4dc

Browse files
committed
compiler: Some development with the bc callback automation
1 parent abdcc91 commit 154a4dc

7 files changed

Lines changed: 219 additions & 50 deletions

File tree

devito/petsc/equations.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from sympy import Eq
22
from devito.symbolics import retrieve_indexed, retrieve_dimensions
3-
from devito.petsc.types.equation import ConstrainEssentialBC
3+
from devito.petsc.types.equation import ConstrainBC
44
from devito.types.dimension import CustomBoundSubDimension
55
from devito import Min, Max
66

@@ -22,7 +22,8 @@ def constrain_essential_bcs(expressions, **kwargs):
2222

2323
# build mapper
2424
for e in expressions:
25-
if not isinstance(e, ConstrainEssentialBC):
25+
# from IPython import embed; embed()
26+
if not isinstance(e, ConstrainBC):
2627
continue
2728

2829
indexeds = retrieve_indexed(e)
@@ -64,7 +65,7 @@ def constrain_essential_bcs(expressions, **kwargs):
6465

6566
# build new expressions
6667
for e in expressions:
67-
if isinstance(e, ConstrainEssentialBC):
68+
if isinstance(e, ConstrainBC):
6869
new_e = e.subs(mapper)
6970
new_exprs.append(new_e)
7071

devito/petsc/iet/builder.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from devito.ir.iet import DummyExpr, BlankLine
55
from devito.symbolics import (Byref, FieldFromPointer, VOID,
6-
FieldFromComposite, Null)
6+
FieldFromComposite, Null, String)
77

88
from devito.petsc.iet.nodes import (
99
FormFunctionCallback, MatShellSetOp, PETScCall, petsc_call
@@ -343,20 +343,25 @@ def _setup(self):
343343

344344
class ConstrainedBCMixin:
345345
"""
346+
not really a mixin?
346347
"""
347348
def _create_dmda_calls(self, dmda):
348349
sobjs = self.solver_objs
349350
# TODO: CLEAN UP
350351
dmda_create = self._create_dmda(dmda)
351352
# TODO: probs need to set the dm options prefix the same as snes?
352353
# don't hardcode this probs? - the dm needs to be specific to the solver as well
353-
da_create_section = petsc_call('PetscOptionsSetValue', [Null, '-da_use_section', Null])
354+
da_create_section = petsc_call('PetscOptionsSetValue', [Null, String("-da_use_section"), Null])
354355
dm_set_from_opts = petsc_call('DMSetFromOptions', [dmda])
355356
dm_setup = petsc_call('DMSetUp', [dmda])
356357
dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL'])
357358

358-
set_constraints = petsc_call(
359-
self.callback_builder._constrain_bc_efunc.name, [dmda]
359+
count_bcs = petsc_call(
360+
self.callback_builder._count_bc_efunc.name, [dmda, Byref(sobjs['numBC'])]
361+
)
362+
363+
set_point_bcs = petsc_call(
364+
self.callback_builder._point_bc_efunc.name, [dmda, Byref(sobjs['numBC'])]
360365
)
361366

362367
get_local_section = petsc_call('DMGetLocalSection', [dmda, Byref(sobjs['lsection'])])
@@ -377,7 +382,8 @@ def _create_dmda_calls(self, dmda):
377382
dm_set_from_opts,
378383
dm_setup,
379384
dm_mat_type,
380-
set_constraints,
385+
count_bcs,
386+
set_point_bcs,
381387
get_local_section,
382388
get_point_sf,
383389
create_global_section,

devito/petsc/iet/callbacks.py

Lines changed: 88 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from devito.petsc.iet.type_builder import objs
1818
from devito.petsc.types.macros import petsc_func_begin_user
1919
from devito.petsc.types.modes import InsertMode
20+
from devito.petsc.types.object import TempSymb
2021

2122

2223
class BaseCallbackBuilder:
@@ -44,7 +45,8 @@ def __init__(self, **kwargs):
4445
self._user_struct_callback = None
4546
self._F_efunc = None
4647
self._b_efunc = None
47-
self._constrain_bc_efunc = None
48+
self._count_bc_efunc = None
49+
self._point_bc_efunc = None
4850

4951
self._J_efuncs = []
5052
# TODO: isn't there only ever one of these per solver so why is it a list?
@@ -649,25 +651,39 @@ def _create_initial_guess_body(self, body):
649651
return Uxreplace(subs).visit(body)
650652

651653
def _make_constrain_bc(self):
652-
exprs = self.field_data.constrain_bc.exprs
654+
increment_exprs = self.field_data.constrain_bc.increment_exprs
655+
point_bc_exprs = self.field_data.constrain_bc.point_bc_exprs
653656
sobjs = self.solver_objs
654657
objs = self.objs
655658

656659
# Compile constrain `eqns` into an IET via recursive compilation
657-
irs, _ = self.rcompile(
658-
exprs, options={'mpi': False}, sregistry=self.sregistry,
660+
irs0, _ = self.rcompile(
661+
increment_exprs, options={'mpi': False}, sregistry=self.sregistry,
659662
concretize_mapper=self.concretize_mapper
660663
)
661-
body = self._create_constrain_bc_body(
662-
List(body=irs.uiet.body)
664+
# Compile constrain `eqns` into an IET via recursive compilation
665+
irs1, _ = self.rcompile(
666+
point_bc_exprs, options={'mpi': False}, sregistry=self.sregistry,
667+
concretize_mapper=self.concretize_mapper
663668
)
664-
cb = self._make_petsc_callable(
665-
'ConstrainBCs', body, parameters=(sobjs['callbackdm'],)
669+
count_bc_body = self._create_count_bc_body(
670+
List(body=irs0.uiet.body)
666671
)
667-
self._constrain_bc_efunc = cb
668-
self._efuncs[cb.name] = cb
672+
set_point_bc_body = self._create_set_point_bc_body(
673+
List(body=irs1.uiet.body)
674+
)
675+
cb0 = self._make_petsc_callable(
676+
'CountBCs', count_bc_body, parameters=(sobjs['callbackdm'], sobjs['numBCPtr'])
677+
)
678+
cb1 = self._make_petsc_callable(
679+
'SetPointBCs', set_point_bc_body, parameters=(sobjs['callbackdm'], sobjs['numBC'])
680+
)
681+
self._count_bc_efunc = cb0
682+
self._efuncs[cb0.name] = cb0
683+
self._point_bc_efunc = cb1
684+
self._efuncs[cb1.name] = cb1
669685

670-
def _create_constrain_bc_body(self, body):
686+
def _create_count_bc_body(self, body):
671687
linsolve_expr = self.inject_solve.expr.rhs
672688
objs = self.objs
673689
sobjs = self.solver_objs
@@ -676,11 +692,6 @@ def _create_constrain_bc_body(self, body):
676692
dmda = sobjs['callbackdm']
677693
ctx = objs['dummyctx']
678694

679-
x_arr = self.field_data.arrays[target]['x']
680-
681-
vec_get_array = petsc_call(
682-
'VecGetArray', [objs['xloc'], Byref(x_arr._C_symbol)]
683-
)
684695

685696
dm_get_local_info = petsc_call(
686697
'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)]
@@ -695,15 +706,14 @@ def _create_constrain_bc_body(self, body):
695706
'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)]
696707
)
697708

698-
vec_restore_array = petsc_call(
699-
'VecRestoreArray', [objs['xloc'], Byref(x_arr._C_symbol)]
700-
)
709+
# dummyexpr = Dereference(self.target, sobjs['numBCPtr'])
701710

702-
body = body._rebuild(body=body.body + (vec_restore_array,))
711+
# body = body._rebuild(body=body.body)
712+
713+
body = body._rebuild(body.body)
703714

704715
stacks = (
705-
vec_get_array,
706-
dm_get_local_info
716+
dm_get_local_info,
707717
)
708718

709719
# Dereference function data in struct
@@ -720,6 +730,62 @@ def _create_constrain_bc_body(self, body):
720730
# Replace non-function data with pointer to data in struct
721731
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for
722732
i in fields if not isinstance(i.function, AbstractFunction)}
733+
734+
# subs[]
735+
# subs[self.target] = sobjs['numBC']
736+
737+
subs[TempSymb._C_symbol] = sobjs['numBCPtr']._C_symbol
738+
739+
# from IPython import embed; embed()
740+
741+
return Uxreplace(subs).visit(body)
742+
743+
def _create_set_point_bc_body(self, body):
744+
linsolve_expr = self.inject_solve.expr.rhs
745+
objs = self.objs
746+
sobjs = self.solver_objs
747+
target = self.target
748+
749+
dmda = sobjs['callbackdm']
750+
ctx = objs['dummyctx']
751+
752+
753+
dm_get_local_info = petsc_call(
754+
'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)]
755+
)
756+
757+
body = self.time_dependence.uxreplace_time(body)
758+
759+
fields = get_user_struct_fields(body)
760+
self._struct_params.extend(fields)
761+
762+
dm_get_app_context = petsc_call(
763+
'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)]
764+
)
765+
766+
# body = body._rebuild(body=body.body)
767+
768+
stacks = (
769+
dm_get_local_info,
770+
)
771+
772+
# Dereference function data in struct
773+
derefs = dereference_funcs(ctx, fields)
774+
775+
# Force the struct definition to appear at the very start, since
776+
# stacks, allocs etc may rely on its information
777+
struct_definition = [Definition(ctx), dm_get_app_context, Definition(sobjs['k_iter'])]
778+
779+
body = self._make_callable_body(
780+
body, standalones=struct_definition, stacks=stacks+derefs
781+
)
782+
783+
# Replace non-function data with pointer to data in struct
784+
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for
785+
i in fields if not isinstance(i.function, AbstractFunction)}
786+
787+
788+
subs[TempSymb._C_symbol] = sobjs['bcPointsArr'].indexed[sobjs['k_iter']]
723789

724790
return Uxreplace(subs).visit(body)
725791

devito/petsc/iet/type_builder.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
PetscBundle, DM, Mat, CallbackVec, Vec, KSP, PC, SNES, PetscInt, StartPtr,
99
PointerIS, PointerDM, VecScatter, JacobianStruct, SubMatrixStruct, CallbackDM,
1010
PetscMPIInt, PetscErrorCode, PointerMat, MatReuse, CallbackPointerDM,
11-
CallbackPointerIS, CallbackMat, DummyArg, NofSubMats, PetscSectionGlobal, PetscSectionLocal, PetscSF
11+
CallbackPointerIS, CallbackMat, DummyArg, NofSubMats, PetscSectionGlobal, PetscSectionLocal, PetscSF,
12+
PetscIntPtr, CallbackPetscInt, CallbackPointerPetscInt
1213
)
1314

1415

@@ -210,6 +211,19 @@ def _extend_build(self, base_dict):
210211
)
211212
base_dict['sf'] = PetscSF(
212213
name=sreg.make_name(prefix='sf')
214+
)
215+
name = sreg.make_name(prefix='numBC')
216+
base_dict['numBC'] = PetscInt(
217+
name=name, initvalue=0
218+
)
219+
base_dict['numBCPtr'] = CallbackPetscInt(
220+
name=sreg.make_name(prefix='numBCPtr'), initvalue=0
221+
)
222+
base_dict['bcPointsArr'] = CallbackPointerPetscInt(
223+
name=sreg.make_name(prefix='bcPointsArr')
224+
)
225+
base_dict['k_iter'] = PetscInt(
226+
name='k_iter', initvalue=0
213227
)
214228
return base_dict
215229

devito/petsc/types/equation.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from devito.types.equation import Eq
1+
from devito.types.equation import Eq, Inc
22

33

44
__all__ = ['EssentialBC']
@@ -45,10 +45,28 @@ class ZeroColumn(EssentialBC):
4545
pass
4646

4747

48-
class ConstrainEssentialBC(EssentialBC):
49-
"""
50-
Equation used to constrain nodes marked by EssentialBCs
51-
inside a PetscSection. This type of equation is generated inside
52-
petscsolve if the user sets `constrain_bcs=True`.
53-
"""
48+
class ConstrainBC(EssentialBC):
5449
pass
50+
51+
52+
# class NoOfEssentialBC(ConstrainBC, Inc):
53+
# """
54+
# Equation used count essential boundary condition nodes.
55+
# This type of equation is generated inside
56+
# petscsolve if the user sets `constrain_bcs=True`.
57+
# """
58+
# def __new__(cls, *args, **kwargs):
59+
# return Inc.__new__(Inc, *args, **kwargs)
60+
61+
62+
class NoOfEssentialBC(ConstrainBC, Inc):
63+
"""Equation used count essential boundary condition nodes.
64+
This type of equation is generated inside
65+
petscsolve if the user sets `constrain_bcs=True`."""
66+
def __new__(cls, *args, **kwargs):
67+
return Inc.__new__(Inc, *args, **kwargs)
68+
69+
70+
class PointEssentialBC(ConstrainBC):
71+
pass
72+

0 commit comments

Comments
 (0)