Skip to content

Commit c5cb7ee

Browse files
committed
dsl/compiler: Re-factor solver params, add solver_parameters.py file
1 parent e97d380 commit c5cb7ee

12 files changed

Lines changed: 205 additions & 232 deletions

File tree

devito/finite_differences/finite_difference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
# Number of digits for FD coefficients to avoid roundup errors and non-deterministic
1111
# code generation
12-
_PRECISION = 18
12+
_PRECISION = 9
1313

1414

1515
@check_input

devito/petsc/clusters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from devito.tools import timed_pass
2-
from devito.petsc.types import LinearSolveExpr
2+
from devito.petsc.types import SolveExpr
33

44

55
@timed_pass()
@@ -19,7 +19,7 @@ def petsc_lift(clusters):
1919
"""
2020
processed = []
2121
for c in clusters:
22-
if isinstance(c.exprs[0].rhs, LinearSolveExpr):
22+
if isinstance(c.exprs[0].rhs, SolveExpr):
2323
ispace = c.ispace.lift(c.exprs[0].rhs.field_data.space_dimensions)
2424
processed.append(c.rebuild(ispace=ispace))
2525
else:

devito/petsc/iet/passes.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,8 @@ def lower_petsc(iet, **kwargs):
7171
# Map PETScSolve to its Section (for logging)
7272
section_mapper = MapNodes(Section, PetscMetaData, 'groupby').visit(iet)
7373

74-
# Utility callback for setting PETSc options.
75-
# This generic function can be reused across all PETScSolves.
74+
# Callback used to set a PetscOption - used by all PETScSolves
7675
set_solver_option(efuncs)
77-
# from IPython import embed; embed()
7876

7977
for iters, (inject_solve,) in inject_solve_mapper.items():
8078

@@ -94,7 +92,6 @@ def lower_petsc(iet, **kwargs):
9492
body = core + tuple(setup) + iet.body.body
9593
body = iet.body._rebuild(body=body)
9694
iet = iet._rebuild(body=body)
97-
# from IPython import embed; embed()
9895
metadata = {**core_metadata(), 'efuncs': tuple(efuncs.values())}
9996
return iet, metadata
10097

@@ -228,7 +225,7 @@ def set_solver_option(efuncs):
228225
option = ConstCharPtr(name='option', is_const=True)
229226
value = ConstCharPtr(name='value', is_const=True)
230227
set = PetscBool(name='set')
231-
228+
232229
body = List(body=[
233230
petsc_call('PetscOptionsHasName', [Null, Null, option, Byref(set)]),
234231
Conditional(Not(set), petsc_call('PetscOptionsSetValue', [Null, option, value]))
@@ -246,7 +243,6 @@ def set_solver_option(efuncs):
246243
retval=objs['err'],
247244
parameters=(option, value)
248245
)
249-
# from IPython import embed; embed()
250246
efuncs[cb.name] = cb
251247

252248

devito/petsc/iet/routines.py

Lines changed: 25 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
MatShellSetOp, PetscMetaData)
1818
from devito.petsc.iet.utils import (petsc_call, petsc_struct, zero_vector,
1919
dereference_funcs, residual_bundle)
20-
from devito.petsc.utils import solver_mapper
2120
from devito.petsc.types import (PETScArray, PetscBundle, DM, Mat, CallbackVec, Vec,
2221
KSP, PC, SNES, PetscInt, StartPtr, PointerIS, PointerDM,
2322
VecScatter, DMCast, JacobianStruct, SubMatrixStruct,
24-
CallbackDM, PetscBool)
23+
CallbackDM)
2524
from devito.petsc.types.macros import petsc_func_begin_user, Null
2625

2726

@@ -119,14 +118,14 @@ def _make_core(self):
119118
def _petsc_options_callback(self):
120119
objs = self.objs
121120
params = self.solver_parameters
122-
options_prefix = self.inject_solve.expr.rhs.options_prefix
121+
prefix = self.inject_solve.expr.rhs.formatted_prefix
123122

124-
body = []
125-
126-
# from IPython import embed; embed()
127-
# TODO: improve
128-
for k, v in params.items():
129-
body.append(petsc_call('SetPetscOption', [String("-"+options_prefix+k), String(v)]))
123+
body = [
124+
petsc_call(
125+
'SetPetscOption', [String(f"-{prefix}{k}"), String(str(v))]
126+
)
127+
for k, v in params.items()
128+
]
130129

131130
body = CallableBody(
132131
List(body=body),
@@ -695,6 +694,7 @@ def _make_core(self):
695694
for sm in self.field_data.jacobian.nonzero_submatrices:
696695
self._make_matvec(sm, prefix=f'{sm.name}_MatMult')
697696

697+
self._petsc_options_callback()
698698
self._make_whole_matvec()
699699
self._make_whole_formfunc()
700700
self._make_user_struct_callback()
@@ -1089,7 +1089,7 @@ def _build(self):
10891089
targets = self.field_data.targets
10901090

10911091
snes_name = sreg.make_name(prefix='snes')
1092-
options_prefix = self.inject_solve.expr.rhs.options_prefix
1092+
formatted_prefix = self.inject_solve.expr.rhs.formatted_prefix
10931093

10941094
base_dict = {
10951095
'Jac': Mat(sreg.make_name(prefix='J')),
@@ -1103,9 +1103,9 @@ def _build(self):
11031103
'localsize': PetscInt(sreg.make_name(prefix='localsize')),
11041104
'dmda': DM(sreg.make_name(prefix='da'), dofs=len(targets)),
11051105
'callbackdm': CallbackDM(sreg.make_name(prefix='dm')),
1106-
'snesprefix': String(options_prefix or ''),
1107-
'options_prefix': options_prefix,
1106+
'snes_prefix': String(formatted_prefix),
11081107
}
1108+
11091109
base_dict['comm'] = self.comm
11101110
self._target_dependent(base_dict)
11111111
return self._extend_build(base_dict)
@@ -1244,6 +1244,7 @@ def __init__(self, **kwargs):
12441244
self.solver_objs = kwargs.get('solver_objs')
12451245
self.cbbuilder = kwargs.get('cbbuilder')
12461246
self.field_data = self.inject_solve.expr.rhs.field_data
1247+
self.formatted_prefix = self.inject_solve.expr.rhs.formatted_prefix
12471248
self.calls = self._setup()
12481249

12491250
@property
@@ -1255,18 +1256,14 @@ def snes_ctx(self):
12551256
return VOID(self.solver_objs['dmda'], stars='*')
12561257

12571258
def _setup(self):
1258-
objs = self.objs
12591259
sobjs = self.solver_objs
1260-
12611260
dmda = sobjs['dmda']
12621261

1263-
# solver_params = self.inject_solve.expr.rhs.solver_parameters
1264-
12651262
snes_create = petsc_call('SNESCreate', [sobjs['comm'], Byref(sobjs['snes'])])
12661263

12671264
snes_options_prefix = petsc_call(
1268-
'SNESSetOptionsPrefix', [sobjs['snes'], sobjs['snesprefix']]
1269-
) if sobjs['options_prefix'] else None
1265+
'SNESSetOptionsPrefix', [sobjs['snes'], sobjs['snes_prefix']]
1266+
) if self.formatted_prefix else None
12701267

12711268
set_options = petsc_call(
12721269
self.cbbuilder._options_efunc.name, []
@@ -1276,9 +1273,6 @@ def _setup(self):
12761273

12771274
create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(sobjs['Jac'])])
12781275

1279-
# NOTE: Assuming all solves are linear for now
1280-
snes_set_type = petsc_call('SNESSetType', [sobjs['snes'], 'SNESKSPONLY'])
1281-
12821276
snes_set_jac = petsc_call(
12831277
'SNESSetJacobian', [sobjs['snes'], sobjs['Jac'],
12841278
sobjs['Jac'], 'MatMFFDComputeJacobian', Null]
@@ -1295,6 +1289,7 @@ def _setup(self):
12951289
local_size = math.prod(
12961290
v for v, dim in zip(target.shape_allocated, target.dimensions) if dim.is_Space
12971291
)
1292+
# TODO: Check, maybe this should be VecCreateSeqWithArray
12981293
local_x = petsc_call('VecCreateMPIWithArray',
12991294
[sobjs['comm'], 1, local_size, 'PETSC_DECIDE',
13001295
field_from_ptr, Byref(sobjs['xlocal'])])
@@ -1310,26 +1305,6 @@ def _setup(self):
13101305
snes_get_ksp = petsc_call('SNESGetKSP',
13111306
[sobjs['snes'], Byref(sobjs['ksp'])])
13121307

1313-
# ksp_set_tols = petsc_call(
1314-
# 'KSPSetTolerances', [sobjs['ksp'], solver_params['ksp_rtol'],
1315-
# solver_params['ksp_atol'], solver_params['ksp_divtol'],
1316-
# solver_params['ksp_max_it']]
1317-
# )
1318-
1319-
# ksp_set_type = petsc_call(
1320-
# 'KSPSetType', [sobjs['ksp'], solver_mapper[solver_params['ksp_type']]]
1321-
# )
1322-
1323-
# TODO: can drop this
1324-
ksp_get_pc = petsc_call(
1325-
'KSPGetPC', [sobjs['ksp'], Byref(sobjs['pc'])]
1326-
)
1327-
1328-
# Even though the default will be jacobi, set to PCNONE for now
1329-
pc_set_type = petsc_call('PCSetType', [sobjs['pc'], 'PCNONE'])
1330-
1331-
ksp_set_from_ops = petsc_call('KSPSetFromOptions', [sobjs['ksp']])
1332-
13331308
matvec = self.cbbuilder.main_matvec_callback
13341309
matvec_operation = petsc_call(
13351310
'MatShellSetOperation',
@@ -1366,16 +1341,11 @@ def _setup(self):
13661341
snes_set_dm,
13671342
create_matrix,
13681343
snes_set_jac,
1369-
snes_set_type,
13701344
global_x,
13711345
local_x,
13721346
get_local_size,
13731347
global_b,
13741348
snes_get_ksp,
1375-
# ksp_set_tols,
1376-
ksp_get_pc,
1377-
pc_set_type,
1378-
ksp_set_from_ops,
13791349
matvec_operation,
13801350
formfunc_operation,
13811351
snes_set_options,
@@ -1400,7 +1370,6 @@ def _create_dmda_calls(self, dmda):
14001370
return dmda_create, dm_setup, dm_mat_type
14011371

14021372
def _create_dmda(self, dmda):
1403-
objs = self.objs
14041373
sobjs = self.solver_objs
14051374
grid = self.field_data.grid
14061375
nspace_dims = len(grid.dimensions)
@@ -1445,23 +1414,22 @@ def _setup(self):
14451414
# TODO: minimise code duplication with superclass
14461415
objs = self.objs
14471416
sobjs = self.solver_objs
1448-
14491417
dmda = sobjs['dmda']
1450-
solver_params = self.inject_solve.expr.rhs.solver_parameters
14511418

14521419
snes_create = petsc_call('SNESCreate', [sobjs['comm'], Byref(sobjs['snes'])])
14531420

14541421
snes_options_prefix = petsc_call(
1455-
'SNESSetOptionsPrefix', [sobjs['snes'], sobjs['snesprefix']]
1456-
) if sobjs['options_prefix'] else None
1422+
'SNESSetOptionsPrefix', [sobjs['snes'], sobjs['snes_prefix']]
1423+
) if self.formatted_prefix else None
1424+
1425+
set_options = petsc_call(
1426+
self.cbbuilder._options_efunc.name, []
1427+
)
14571428

14581429
snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda])
14591430

14601431
create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(sobjs['Jac'])])
14611432

1462-
# NOTE: Assuming all solves are linear for now
1463-
snes_set_type = petsc_call('SNESSetType', [sobjs['snes'], 'SNESKSPONLY'])
1464-
14651433
snes_set_jac = petsc_call(
14661434
'SNESSetJacobian', [sobjs['snes'], sobjs['Jac'],
14671435
sobjs['Jac'], 'MatMFFDComputeJacobian', Null]
@@ -1478,25 +1446,6 @@ def _setup(self):
14781446
snes_get_ksp = petsc_call('SNESGetKSP',
14791447
[sobjs['snes'], Byref(sobjs['ksp'])])
14801448

1481-
ksp_set_tols = petsc_call(
1482-
'KSPSetTolerances', [sobjs['ksp'], solver_params['ksp_rtol'],
1483-
solver_params['ksp_atol'], solver_params['ksp_divtol'],
1484-
solver_params['ksp_max_it']]
1485-
)
1486-
1487-
ksp_set_type = petsc_call(
1488-
'KSPSetType', [sobjs['ksp'], solver_mapper[solver_params['ksp_type']]]
1489-
)
1490-
1491-
ksp_get_pc = petsc_call(
1492-
'KSPGetPC', [sobjs['ksp'], Byref(sobjs['pc'])]
1493-
)
1494-
1495-
# Even though the default will be jacobi, set to PCNONE for now
1496-
pc_set_type = petsc_call('PCSetType', [sobjs['pc'], 'PCNONE'])
1497-
1498-
ksp_set_from_ops = petsc_call('KSPSetFromOptions', [sobjs['ksp']])
1499-
15001449
matvec = self.cbbuilder.main_matvec_callback
15011450
matvec_operation = petsc_call(
15021451
'MatShellSetOperation',
@@ -1569,19 +1518,14 @@ def _setup(self):
15691518
coupled_setup = dmda_calls + (
15701519
snes_create,
15711520
snes_options_prefix,
1521+
set_options,
15721522
snes_set_dm,
15731523
create_matrix,
15741524
snes_set_jac,
1575-
snes_set_type,
15761525
global_x,
15771526
local_x,
15781527
get_local_size,
15791528
snes_get_ksp,
1580-
ksp_set_tols,
1581-
ksp_set_type,
1582-
ksp_get_pc,
1583-
pc_set_type,
1584-
ksp_set_from_ops,
15851529
matvec_operation,
15861530
formfunc_operation,
15871531
snes_set_options,
@@ -1678,7 +1622,6 @@ def _execute_solve(self):
16781622
Assigns the required time iterators to the struct and executes
16791623
the necessary calls to execute the SNES solver.
16801624
"""
1681-
objs = self.objs
16821625
sobjs = self.solver_objs
16831626
xglob = sobjs['xglobal']
16841627

@@ -1794,10 +1737,10 @@ class TimeDependent(NonTimeDependent):
17941737
for each `SNESSolve` at every time step, don't require the time loop, but
17951738
may still need access to data from other time steps.
17961739
- All `Function` objects are passed through the initial lowering via the
1797-
`LinearSolveExpr` object, ensuring the correct time loop is generated
1740+
`SolveExpr` object, ensuring the correct time loop is generated
17981741
in the main kernel.
17991742
- Another mapper is created based on the modulo dimensions
1800-
generated by the `LinearSolveExpr` object in the main kernel
1743+
generated by the `SolveExpr` object in the main kernel
18011744
(e.g., {time: time, t: t0, t + 1: t1}).
18021745
- These two mappers are used to generate a final mapper `symb_to_moddim`
18031746
(e.g. {tau0: t0, tau1: t1}) which is used at the IET level to

devito/petsc/logging.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ def section(self):
142142

143143
@property
144144
def summary_key(self):
145-
return (self.section, self.sobjs['options_prefix'])
145+
user_prefix = self.inject_solve.expr.rhs.user_prefix
146+
return (self.section, user_prefix)
146147

147148
def __getattr__(self, attr):
148149
if attr in self.logobjs.keys():

0 commit comments

Comments
 (0)