Skip to content

Commit e97d380

Browse files
committed
compiler: Progress on petscoptions callbacks
1 parent 2c57450 commit e97d380

9 files changed

Lines changed: 108 additions & 60 deletions

File tree

devito/petsc/iet/passes.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
11
import cgen as c
22
import numpy as np
33
from functools import cached_property
4+
from sympy import Not
45

56
from devito.passes.iet.engine import iet_pass
67
from devito.ir.iet import (Transformer, MapNodes, Iteration, BlankLine,
78
DummyExpr, CallableBody, List, Call, Callable,
8-
FindNodes, Section)
9-
from devito.symbolics import Byref, Macro, FieldFromPointer
9+
FindNodes, Section, Conditional)
10+
from devito.symbolics import Byref, FieldFromPointer, Macro
1011
from devito.types import Symbol, Scalar
1112
from devito.types.basic import DataSymbol
1213
from devito.tools import frozendict
14+
import devito.logger as dl
15+
1316
from devito.petsc.types import (PetscMPIInt, PetscErrorCode, MultipleFieldData,
1417
PointerIS, Mat, CallbackVec, Vec, CallbackMat, SNES,
1518
DummyArg, PetscInt, PointerDM, PointerMat, MatReuse,
1619
CallbackPointerIS, CallbackPointerDM, JacobianStruct,
17-
SubMatrixStruct, Initialize, Finalize, ArgvSymbol)
18-
from devito.petsc.types.macros import petsc_func_begin_user
19-
from devito.petsc.iet.nodes import PetscMetaData
20+
SubMatrixStruct, Initialize, Finalize, ArgvSymbol,
21+
ConstCharPtr, PetscBool)
22+
from devito.petsc.types.macros import petsc_func_begin_user, Null
23+
from devito.petsc.iet.nodes import PetscMetaData, PETScCallable
2024
from devito.petsc.utils import core_metadata, petsc_languages
2125
from devito.petsc.iet.routines import (CBBuilder, CCBBuilder, BaseObjectBuilder,
2226
CoupledObjectBuilder, BaseSetup, CoupledSetup,
@@ -25,8 +29,6 @@
2529
from devito.petsc.iet.logging import PetscLogger
2630
from devito.petsc.iet.utils import petsc_call, petsc_call_mpi
2731

28-
import devito.logger as dl
29-
3032

3133
@iet_pass
3234
def lower_petsc(iet, **kwargs):
@@ -69,9 +71,14 @@ def lower_petsc(iet, **kwargs):
6971
# Map PETScSolve to its Section (for logging)
7072
section_mapper = MapNodes(Section, PetscMetaData, 'groupby').visit(iet)
7173

74+
# Utility callback for setting PETSc options.
75+
# This generic function can be reused across all PETScSolves.
76+
set_solver_option(efuncs)
77+
# from IPython import embed; embed()
78+
7279
for iters, (inject_solve,) in inject_solve_mapper.items():
7380

74-
builder = Builder(inject_solve, objs, iters, comm, section_mapper, **kwargs)
81+
builder = Builder(inject_solve, iters, comm, section_mapper, **kwargs)
7582

7683
setup.extend(builder.solver_setup.calls)
7784

@@ -80,13 +87,14 @@ def lower_petsc(iet, **kwargs):
8087

8188
efuncs.update(builder.cbbuilder.efuncs)
8289

83-
populate_matrix_context(efuncs, objs)
90+
populate_matrix_context(efuncs)
8491

8592
iet = Transformer(subs).visit(iet)
8693

8794
body = core + tuple(setup) + iet.body.body
8895
body = iet.body._rebuild(body=body)
8996
iet = iet._rebuild(body=body)
97+
# from IPython import embed; embed()
9098
metadata = {**core_metadata(), 'efuncs': tuple(efuncs.values())}
9199
return iet, metadata
92100

@@ -131,7 +139,7 @@ class Builder:
131139
returning subclasses of the objects initialised in __init__,
132140
depending on the properties of `inject_solve`.
133141
"""
134-
def __init__(self, inject_solve, objs, iters, comm, section_mapper, **kwargs):
142+
def __init__(self, inject_solve, iters, comm, section_mapper, **kwargs):
135143
self.inject_solve = inject_solve
136144
self.objs = objs
137145
self.iters = iters
@@ -191,7 +199,7 @@ def calls(self):
191199
return List(body=self.solve.calls+self.logger.calls)
192200

193201

194-
def populate_matrix_context(efuncs, objs):
202+
def populate_matrix_context(efuncs):
195203
if not objs['dummyefunc'] in efuncs.values():
196204
return
197205

@@ -205,7 +213,7 @@ def populate_matrix_context(efuncs, objs):
205213
)
206214
body = CallableBody(
207215
List(body=[subdms_expr, fields_expr]),
208-
init=(objs['begin_user'],),
216+
init=(petsc_func_begin_user,),
209217
retstmt=tuple([Call('PetscFunctionReturn', arguments=[0])])
210218
)
211219
name = 'PopulateMatContext'
@@ -215,6 +223,33 @@ def populate_matrix_context(efuncs, objs):
215223
)
216224

217225

226+
def set_solver_option(efuncs):
227+
228+
option = ConstCharPtr(name='option', is_const=True)
229+
value = ConstCharPtr(name='value', is_const=True)
230+
set = PetscBool(name='set')
231+
232+
body = List(body=[
233+
petsc_call('PetscOptionsHasName', [Null, Null, option, Byref(set)]),
234+
Conditional(Not(set), petsc_call('PetscOptionsSetValue', [Null, option, value]))
235+
])
236+
237+
body = CallableBody(
238+
body,
239+
init=(petsc_func_begin_user,),
240+
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
241+
)
242+
243+
cb = PETScCallable(
244+
'SetPetscOption',
245+
body,
246+
retval=objs['err'],
247+
parameters=(option, value)
248+
)
249+
# from IPython import embed; embed()
250+
efuncs[cb.name] = cb
251+
252+
218253
subdms = PointerDM(name='subdms')
219254
fields = PointerIS(name='fields')
220255
submats = PointerMat(name='submats')
@@ -262,13 +297,8 @@ def populate_matrix_context(efuncs, objs):
262297
fields=[subdms, fields, submats], modifier=' *'
263298
),
264299
'subctx': SubMatrixStruct(fields=[rows, cols]),
265-
'Null': Macro('NULL'),
266300
'dummyctx': Symbol('lctx'),
267301
'dummyptr': DummyArg('dummy'),
268302
'dummyefunc': Symbol('dummyefunc'),
269303
'dof': PetscInt('dof'),
270-
'begin_user': c.Line('PetscFunctionBeginUser;'),
271304
})
272-
273-
# Move to macros file?
274-
Null = Macro('NULL')

devito/petsc/iet/routines.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
KSP, PC, SNES, PetscInt, StartPtr, PointerIS, PointerDM,
2323
VecScatter, DMCast, JacobianStruct, SubMatrixStruct,
2424
CallbackDM, PetscBool)
25+
from devito.petsc.types.macros import petsc_func_begin_user, Null
2526

2627

2728
class CBBuilder:
@@ -118,20 +119,18 @@ def _make_core(self):
118119
def _petsc_options_callback(self):
119120
objs = self.objs
120121
params = self.solver_parameters
121-
Null = objs['Null']
122+
options_prefix = self.inject_solve.expr.rhs.options_prefix
122123

123-
has_names = ()
124+
body = []
124125

126+
# from IPython import embed; embed()
125127
# TODO: improve
126128
for k, v in params.items():
127-
is_set = PetscBool(self.sregistry.make_name(prefix='set'))
128-
has_name = petsc_call('PetscOptionsHasName', [
129-
Null, Null, String(k), Byref(is_set)])
130-
has_names += (has_name,)
129+
body.append(petsc_call('SetPetscOption', [String("-"+options_prefix+k), String(v)]))
131130

132131
body = CallableBody(
133-
List(body=has_names),
134-
init=(objs['begin_user'],),
132+
List(body=body),
133+
init=(petsc_func_begin_user,),
135134
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
136135
)
137136

@@ -281,7 +280,7 @@ def _create_matvec_body(self, body, jacobian):
281280

282281
body = CallableBody(
283282
List(body=body),
284-
init=(objs['begin_user'],),
283+
init=(petsc_func_begin_user,),
285284
stacks=stacks+derefs,
286285
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
287286
)
@@ -420,7 +419,7 @@ def _create_formfunc_body(self, body):
420419

421420
body = CallableBody(
422421
List(body=body),
423-
init=(objs['begin_user'],),
422+
init=(petsc_func_begin_user,),
424423
stacks=stacks+derefs,
425424
retstmt=(Call('PetscFunctionReturn', arguments=[0]),))
426425

@@ -530,7 +529,7 @@ def _create_form_rhs_body(self, body):
530529

531530
body = CallableBody(
532531
List(body=[body]),
533-
init=(objs['begin_user'],),
532+
init=(petsc_func_begin_user,),
534533
stacks=stacks+derefs,
535534
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
536535
)
@@ -608,7 +607,7 @@ def _create_initial_guess_body(self, body):
608607

609608
body = CallableBody(
610609
List(body=[body]),
611-
init=(objs['begin_user'],),
610+
init=(petsc_func_begin_user,),
612611
stacks=stacks+derefs,
613612
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
614613
)
@@ -635,7 +634,7 @@ def _make_user_struct_callback(self):
635634
for i in mainctx.callback_fields
636635
]
637636
struct_callback_body = CallableBody(
638-
List(body=body), init=(self.objs['begin_user'],),
637+
List(body=body), init=(petsc_func_begin_user,),
639638
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
640639
)
641640
cb = Callable(
@@ -759,7 +758,7 @@ def _whole_matvec_body(self):
759758
)
760759
return CallableBody(
761760
List(body=(ctx_main, zero_y_memory, BlankLine) + calls),
762-
init=(objs['begin_user'],),
761+
init=(petsc_func_begin_user,),
763762
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
764763
)
765764

@@ -894,7 +893,7 @@ def _whole_formfunc_body(self, body):
894893

895894
formfunc_body = CallableBody(
896895
List(body=body),
897-
init=(objs['begin_user'],),
896+
init=(petsc_func_begin_user,),
898897
stacks=stacks+derefs,
899898
casts=(f_soa, x_soa),
900899
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
@@ -943,7 +942,6 @@ def _submat_callback_body(self):
943942

944943
get_ctx = petsc_call('MatShellGetContext', [objs['J'], Byref(objs['ljacctx'])])
945944

946-
Null = objs['Null']
947945
dm_get_info = petsc_call(
948946
'DMDAGetInfo', [
949947
sobjs['callbackdm'], Null, Byref(sobjs['M']), Byref(sobjs['N']),
@@ -1046,7 +1044,7 @@ def _submat_callback_body(self):
10461044

10471045
return CallableBody(
10481046
List(body=tuple(body)),
1049-
init=(objs['begin_user'],),
1047+
init=(petsc_func_begin_user,),
10501048
stacks=(get_ctx, deref_subdm),
10511049
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
10521050
)
@@ -1105,7 +1103,7 @@ def _build(self):
11051103
'localsize': PetscInt(sreg.make_name(prefix='localsize')),
11061104
'dmda': DM(sreg.make_name(prefix='da'), dofs=len(targets)),
11071105
'callbackdm': CallbackDM(sreg.make_name(prefix='dm')),
1108-
'snesprefix': String((options_prefix or '') + '_'),
1106+
'snesprefix': String(options_prefix or ''),
11091107
'options_prefix': options_prefix,
11101108
}
11111109
base_dict['comm'] = self.comm
@@ -1262,7 +1260,7 @@ def _setup(self):
12621260

12631261
dmda = sobjs['dmda']
12641262

1265-
solver_params = self.inject_solve.expr.rhs.solver_parameters
1263+
# solver_params = self.inject_solve.expr.rhs.solver_parameters
12661264

12671265
snes_create = petsc_call('SNESCreate', [sobjs['comm'], Byref(sobjs['snes'])])
12681266

@@ -1283,7 +1281,7 @@ def _setup(self):
12831281

12841282
snes_set_jac = petsc_call(
12851283
'SNESSetJacobian', [sobjs['snes'], sobjs['Jac'],
1286-
sobjs['Jac'], 'MatMFFDComputeJacobian', objs['Null']]
1284+
sobjs['Jac'], 'MatMFFDComputeJacobian', Null]
12871285
)
12881286

12891287
global_x = petsc_call('DMCreateGlobalVector',
@@ -1312,16 +1310,17 @@ def _setup(self):
13121310
snes_get_ksp = petsc_call('SNESGetKSP',
13131311
[sobjs['snes'], Byref(sobjs['ksp'])])
13141312

1315-
ksp_set_tols = petsc_call(
1316-
'KSPSetTolerances', [sobjs['ksp'], solver_params['ksp_rtol'],
1317-
solver_params['ksp_atol'], solver_params['ksp_divtol'],
1318-
solver_params['ksp_max_it']]
1319-
)
1313+
# ksp_set_tols = petsc_call(
1314+
# 'KSPSetTolerances', [sobjs['ksp'], solver_params['ksp_rtol'],
1315+
# solver_params['ksp_atol'], solver_params['ksp_divtol'],
1316+
# solver_params['ksp_max_it']]
1317+
# )
13201318

13211319
# ksp_set_type = petsc_call(
13221320
# 'KSPSetType', [sobjs['ksp'], solver_mapper[solver_params['ksp_type']]]
13231321
# )
13241322

1323+
# TODO: can drop this
13251324
ksp_get_pc = petsc_call(
13261325
'KSPGetPC', [sobjs['ksp'], Byref(sobjs['pc'])]
13271326
)
@@ -1339,7 +1338,7 @@ def _setup(self):
13391338
formfunc = self.cbbuilder._F_efunc
13401339
formfunc_operation = petsc_call(
13411340
'SNESSetFunction',
1342-
[sobjs['snes'], objs['Null'], FormFunctionCallback(formfunc.name, void, void),
1341+
[sobjs['snes'], Null, FormFunctionCallback(formfunc.name, void, void),
13431342
self.snes_ctx]
13441343
)
13451344

@@ -1373,7 +1372,7 @@ def _setup(self):
13731372
get_local_size,
13741373
global_b,
13751374
snes_get_ksp,
1376-
ksp_set_tols,
1375+
# ksp_set_tols,
13771376
ksp_get_pc,
13781377
pc_set_type,
13791378
ksp_set_from_ops,
@@ -1430,7 +1429,7 @@ def _create_dmda(self, dmda):
14301429
stencil_width = self.field_data.space_order
14311430

14321431
args.append(stencil_width)
1433-
args.extend([objs['Null']]*nspace_dims)
1432+
args.extend([Null]*nspace_dims)
14341433

14351434
# The distributed array object
14361435
args.append(Byref(dmda))
@@ -1465,7 +1464,7 @@ def _setup(self):
14651464

14661465
snes_set_jac = petsc_call(
14671466
'SNESSetJacobian', [sobjs['snes'], sobjs['Jac'],
1468-
sobjs['Jac'], 'MatMFFDComputeJacobian', objs['Null']]
1467+
sobjs['Jac'], 'MatMFFDComputeJacobian', Null]
14691468
)
14701469

14711470
global_x = petsc_call('DMCreateGlobalVector',
@@ -1506,7 +1505,7 @@ def _setup(self):
15061505
formfunc = self.cbbuilder._F_efunc
15071506
formfunc_operation = petsc_call(
15081507
'SNESSetFunction',
1509-
[sobjs['snes'], objs['Null'], FormFunctionCallback(formfunc.name, void, void),
1508+
[sobjs['snes'], Null, FormFunctionCallback(formfunc.name, void, void),
15101509
self.snes_ctx]
15111510
)
15121511

@@ -1529,7 +1528,7 @@ def _setup(self):
15291528

15301529
create_field_decomp = petsc_call(
15311530
'DMCreateFieldDecomposition',
1532-
[dmda, Byref(sobjs['nfields']), objs['Null'], Byref(sobjs['fields']),
1531+
[dmda, Byref(sobjs['nfields']), Null, Byref(sobjs['fields']),
15331532
Byref(sobjs['subdms'])]
15341533
)
15351534
submat_cb = self.cbbuilder.submatrices_callback
@@ -1710,7 +1709,7 @@ def _execute_solve(self):
17101709
),
17111710
petsc_call(
17121711
'VecScatterCreate',
1713-
[xglob, field, target_xglob, self.objs['Null'], Byref(s)]
1712+
[xglob, field, target_xglob, Null, Byref(s)]
17141713
),
17151714
petsc_call(
17161715
'VecScatterBegin',
@@ -1738,7 +1737,7 @@ def _execute_solve(self):
17381737
)
17391738
)
17401739

1741-
snes_solve = (petsc_call('SNESSolve', [sobjs['snes'], objs['Null'], xglob]),)
1740+
snes_solve = (petsc_call('SNESSolve', [sobjs['snes'], Null, xglob]),)
17421741

17431742
return (struct_assignment,) + pre_solve + snes_solve + post_solve + (BlankLine,)
17441743

devito/petsc/iet/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from devito.petsc.iet.nodes import PetscMetaData, PETScCall
21
from devito.ir.equations import OpPetsc
32
from devito.ir.iet import Dereference, FindSymbols, Uxreplace
43
from devito.types.basic import AbstractFunction
54

5+
from devito.petsc.iet.nodes import PetscMetaData, PETScCall
6+
67

78
def petsc_call(specific_call, call_args):
89
return PETScCall('PetscCall', [PETScCall(specific_call, arguments=call_args)])

0 commit comments

Comments
 (0)