Skip to content

Commit 2c57450

Browse files
committed
compiler: Use petsctools to process solver params and start callback to set petsc options
1 parent 17f0dc5 commit 2c57450

5 files changed

Lines changed: 135 additions & 26 deletions

File tree

devito/petsc/iet/routines.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from devito.petsc.types import (PETScArray, PetscBundle, DM, Mat, CallbackVec, Vec,
2222
KSP, PC, SNES, PetscInt, StartPtr, PointerIS, PointerDM,
2323
VecScatter, DMCast, JacobianStruct, SubMatrixStruct,
24-
CallbackDM)
24+
CallbackDM, PetscBool)
2525

2626

2727
class CBBuilder:
@@ -37,10 +37,12 @@ def __init__(self, **kwargs):
3737
self.objs = kwargs.get('objs')
3838
self.solver_objs = kwargs.get('solver_objs')
3939
self.inject_solve = kwargs.get('inject_solve')
40+
self.solver_parameters = self.inject_solve.expr.rhs.solver_parameters
4041

4142
self._efuncs = OrderedDict()
4243
self._struct_params = []
4344

45+
self._options_efunc = None
4446
self._main_matvec_callback = None
4547
self._user_struct_callback = None
4648
self._F_efunc = None
@@ -105,13 +107,44 @@ def target(self):
105107
return self.field_data.target
106108

107109
def _make_core(self):
110+
self._petsc_options_callback()
108111
self._make_matvec(self.field_data.jacobian)
109112
self._make_formfunc()
110113
self._make_formrhs()
111114
if self.field_data.initial_guess.exprs:
112115
self._make_initial_guess()
113116
self._make_user_struct_callback()
114117

118+
def _petsc_options_callback(self):
119+
objs = self.objs
120+
params = self.solver_parameters
121+
Null = objs['Null']
122+
123+
has_names = ()
124+
125+
# TODO: improve
126+
for k, v in params.items():
127+
is_set = PetscBool(self.sregistry.make_name(prefix='set'))
128+
has_name = petsc_call('PetscOptionsHasName', [
129+
Null, Null, String(k), Byref(is_set)])
130+
has_names += (has_name,)
131+
132+
body = CallableBody(
133+
List(body=has_names),
134+
init=(objs['begin_user'],),
135+
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
136+
)
137+
138+
objs = self.objs
139+
cb = PETScCallable(
140+
self.sregistry.make_name(prefix='SetPetscOptions'),
141+
body,
142+
retval=objs['err'],
143+
parameters=()
144+
)
145+
self._options_efunc = cb
146+
self._efuncs[cb.name] = cb
147+
115148
def _make_matvec(self, jacobian, prefix='MatMult'):
116149
# Compile `matvecs` into an IET via recursive compilation
117150
matvecs = jacobian.matvecs
@@ -1237,6 +1270,10 @@ def _setup(self):
12371270
'SNESSetOptionsPrefix', [sobjs['snes'], sobjs['snesprefix']]
12381271
) if sobjs['options_prefix'] else None
12391272

1273+
set_options = petsc_call(
1274+
self.cbbuilder._options_efunc.name, []
1275+
)
1276+
12401277
snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda])
12411278

12421279
create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(sobjs['Jac'])])
@@ -1281,9 +1318,9 @@ def _setup(self):
12811318
solver_params['ksp_max_it']]
12821319
)
12831320

1284-
ksp_set_type = petsc_call(
1285-
'KSPSetType', [sobjs['ksp'], solver_mapper[solver_params['ksp_type']]]
1286-
)
1321+
# ksp_set_type = petsc_call(
1322+
# 'KSPSetType', [sobjs['ksp'], solver_mapper[solver_params['ksp_type']]]
1323+
# )
12871324

12881325
ksp_get_pc = petsc_call(
12891326
'KSPGetPC', [sobjs['ksp'], Byref(sobjs['pc'])]
@@ -1326,6 +1363,7 @@ def _setup(self):
13261363
base_setup = dmda_calls + (
13271364
snes_create,
13281365
snes_options_prefix,
1366+
set_options,
13291367
snes_set_dm,
13301368
create_matrix,
13311369
snes_set_jac,
@@ -1336,7 +1374,6 @@ def _setup(self):
13361374
global_b,
13371375
snes_get_ksp,
13381376
ksp_set_tols,
1339-
ksp_set_type,
13401377
ksp_get_pc,
13411378
pc_set_type,
13421379
ksp_set_from_ops,

devito/petsc/types/object.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ class PetscInt(PetscObject):
9797
dtype = CustomIntType('PetscInt')
9898

9999

100+
class PetscBool(PetscObject):
101+
"""
102+
"""
103+
dtype = CustomDtype('PetscBool')
104+
105+
100106
class KSP(PetscObject):
101107
"""
102108
PETSc KSP : Linear Systems Solvers.

devito/petsc/types/types.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22

33
from itertools import chain
44
from functools import cached_property
5+
from petsctools import flatten_parameters
56

67
from devito.tools import Reconstructable, sympy_mutex, as_tuple, frozendict
78
from devito.tools.dtypes_lowering import dtype_mapper
8-
from devito.petsc.utils import petsc_variables
99
from devito.symbolics.extraction import separate_eqn, generate_targets, centre_stencil
10-
from devito.petsc.types.equation import EssentialBC, ZeroRow, ZeroColumn
1110
from devito.types.equation import Eq
1211
from devito.operations.solve import eval_time_derivatives
1312

13+
from devito.petsc.utils import petsc_variables
14+
from devito.petsc.types.equation import EssentialBC, ZeroRow, ZeroColumn
15+
1416

1517
class MetaData(sympy.Function, Reconstructable):
1618
def __new__(cls, expr, **kwargs):
@@ -63,6 +65,7 @@ class LinearSolveExpr(MetaData):
6365
__rkwargs__ = ('solver_parameters', 'field_data', 'time_mapper',
6466
'localinfo', 'options_prefix')
6567

68+
# TODO: Will be extended
6669
defaults = {
6770
'ksp_type': 'gmres',
6871
'pc_type': 'jacobi',
@@ -76,17 +79,19 @@ def __new__(cls, expr, solver_parameters=None,
7679
field_data=None, time_mapper=None, localinfo=None,
7780
options_prefix=None, **kwargs):
7881

79-
if solver_parameters is None:
80-
solver_parameters = cls.defaults
81-
else:
82-
for key, val in cls.defaults.items():
83-
solver_parameters[key] = solver_parameters.get(key, val)
82+
# TODO: move into a function
83+
flattened_params = flatten_parameters(solver_parameters or {})
84+
processed = cls.defaults.copy()
85+
processed.update(flattened_params)
86+
processed = {k: str(v) for k, v in processed.items()}
87+
# TODO: attach options_prefix to the parameters
88+
# TODO: need to add the "-" to the beginning of each key -> use petsctools etc
8489

8590
with sympy_mutex:
8691
obj = sympy.Function.__new__(cls, expr)
8792

8893
obj._expr = expr
89-
obj._solver_parameters = solver_parameters
94+
obj._solver_parameters = processed
9095
obj._field_data = field_data if field_data else FieldData()
9196
obj._time_mapper = time_mapper
9297
obj._localinfo = localinfo

examples/petsc/solver_options.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ int main(int argc, char **argv)
2121
PetscCallMPI(MPI_Comm_size(PETSC_COMM_WORLD, &size));
2222

2323
PetscCall(PetscOptionsSetValue(NULL, "-poisson_ksp_type", "cg"));
24-
PetscCall(PetscOptionsInsert(NULL, &argc, &argv, NULL));
24+
// PetscCall(PetscOptionsInsert(NULL, &argc, &argv, NULL));
2525

2626
PetscCall(SNESCreate(PETSC_COMM_WORLD, &snes));
2727
PetscCall(SNESSetOptionsPrefix(snes, "poisson_"));
28+
// PetscCall(SNESSetType(snes, SNESKSPONLY));
29+
// PetscCall(PetscOptionsSetValue(NULL, "-poisson_snes_type", "snesksponly"));
2830
PetscCall(SNESSetFromOptions(snes));
2931

3032
PetscCall(VecCreate(PETSC_COMM_WORLD, &x));
@@ -44,7 +46,7 @@ int main(int argc, char **argv)
4446
PetscCall(KSPGetPC(ksp, &pc));
4547
PetscCall(PCSetType(pc, PCNONE));
4648
PetscCall(KSPSetTolerances(ksp, 1.e-4, PETSC_CURRENT, PETSC_CURRENT, 20));
47-
49+
PetscCall(KSPSetFromOptions(ksp));
4850

4951
// PetscCall(VecSet(x, pfive));
5052
PetscCall(SNESSolve(snes, NULL, x));

examples/petsc/solver_options.py

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from devito.petsc import PETScSolve
77
from devito.petsc.initialize import PetscInitialize
88
import petsctools
9-
from petsctools import get_commandline_options
9+
from petsctools import get_commandline_options, OptionsManager, flatten_parameters
1010
configuration['compiler'] = 'custom'
1111
os.environ['CC'] = 'mpicc'
1212

@@ -34,20 +34,79 @@
3434

3535
eq = Eq(v, u.laplace, subdomain=grid.interior)
3636

37-
petsc = PETScSolve([eq], u)
37+
solver = PETScSolve([eq], u, solver_parameters={'ksp_rtol': 1e-8}, options_prefix='poisson')
3838

3939
with switchconfig(language='petsc'):
40-
op = Operator(petsc)
40+
op = Operator(solver)
4141
op.apply()
42+
print(op.ccode)
43+
44+
45+
# import sys
46+
# petsctools.options._commandline_options = sys.argv[1:]
47+
# tmp = get_commandline_options()
48+
# print("Command line options:", tmp)
49+
50+
51+
52+
# class DevitoOptionsManager(OptionsManager):
53+
# """
54+
# """
55+
# def __init__(self, parameters, options_prefix):
56+
# if parameters is None:
57+
# parameters = {}
58+
# else:
59+
# # Convert nested dicts
60+
# parameters = flatten_parameters(parameters)
61+
# if options_prefix is None:
62+
# self.options_prefix = "firedrake_%d_" % next(self.count)
63+
# self.parameters = parameters
64+
# self.to_delete = set(parameters)
65+
# else:
66+
# if len(options_prefix) and not options_prefix.endswith("_"):
67+
# options_prefix += "_"
68+
# self.options_prefix = options_prefix
69+
# # Remove those options from the dict that were passed on
70+
# # the commandline.
71+
# self.parameters = {
72+
# k: v
73+
# for k, v in parameters.items()
74+
# if options_prefix + k not in get_commandline_options()
75+
# }
76+
# self.to_delete = set(self.parameters)
77+
# # Now update parameters from options, so that they're
78+
# # available to solver setup (for, e.g., matrix-free).
79+
# # Can't ask for the prefixed guy in the options object,
80+
# # since that does not DTRT for flag options.
81+
# # for k, v in self.options_object.getAll().items():
82+
# # if k.startswith(self.options_prefix):
83+
# # self.parameters[k[len(self.options_prefix):]] = v
84+
85+
# for k, v in get_commandline_options():
86+
# if k.startswith(self.options_prefix):
87+
# self.parameters[k[len(self.options_prefix):]] = v
88+
# self._setfromoptions = False
89+
90+
91+
92+
# options_manager = DevitoOptionsManager(solver.rhs.solver_parameters, solver.rhs.options_prefix)
93+
94+
95+
# nested = {"ksp_type": "cg",
96+
# "pc_type": "fieldsplit",
97+
# "fieldsplit_0": {"ksp_type": "gmres",
98+
# "pc_type": "hypre",
99+
# "ksp_rtol": 1e-5},
100+
# "fieldsplit_1": {"ksp_type": "richardson",
101+
# "pc_type": "ilu"}}
102+
103+
# tmp = flatten_parameters(nested)
104+
# # from IPython import embed; embed()
105+
106+
# # convert all values into strings
107+
108+
# tmp_str = {k: str(v) for k, v in tmp.items()}
42109

43-
# print(op.ccode)
44110

45-
# print(grid.shape)
46-
47-
48-
import sys
49-
petsctools.options._commandline_options = sys.argv[1:]
50-
tmp = get_commandline_options()
51-
print("Command line options:", tmp)
52111

53112

0 commit comments

Comments
 (0)