Skip to content

Commit 7c69c4d

Browse files
authored
dsl/compiler: Extend solver_parameters handling in PETScSolve and update logging (#2718)
* compiler: Utilise petsctools and start extending solver parameters functionality examples: Add example using argparse in conjunction with PetscInitialize * compiler: Edit argument to PetscInitialize * compiler: Use petsctools to process solver params and start callback to set petsc options * compiler: Progress on petscoptions callbacks * dsl/compiler: Re-factor solver params, add solver_parameters.py file * misc: Add tests * compiler: Start extending the PetscSummary * dsl: Fix hashing for solveexpr * misc: Add tests and clean up * compiler: Add utility function inside petsc routines * misc: Clean up * tests: Add command line tests with random prefixes * compiler: Create dummy op for petscgetargs * misc: Update requirements * misc: Create getargs Op inside function * compiler/dsl: Add get_info functionality to petscsolve * compiler/misc: Drop petscgetargs callback, add functions to logging/get_info, clean up and more tests * workflows: Fix serial wf run
1 parent 4139a3f commit 7c69c4d

17 files changed

Lines changed: 1244 additions & 472 deletions

File tree

.github/workflows/pytest-petsc.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,13 @@ jobs:
7272
run: |
7373
${{ env.RUN_CMD }} python3 -c "from devito import configuration; print(''.join(['%s: %s \n' % (k, v) for (k, v) in configuration.items()]))"
7474
75-
- name: Test with pytest
75+
- name: Test with pytest - serial
7676
run: |
77-
${{ env.RUN_CMD }} mpiexec -n 1 pytest --cov --cov-config=.coveragerc --cov-report=xml ${{ env.TESTS }}
77+
${{ env.RUN_CMD }} mpiexec -n 1 pytest -m "not parallel" --cov --cov-config=.coveragerc --cov-report=xml ${{ env.TESTS }}
78+
79+
- name: Test with pytest - parallel
80+
run: |
81+
${{ env.RUN_CMD }} python3 -m pytest --cov --cov-config=.coveragerc --cov-report=xml -m parallel ${{ env.TESTS }}
7882
7983
- name: Test examples
8084
run: |

devito/operator/profiling.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,11 @@ def summary(self, args, dtype, params, reduce_over=None):
213213
else:
214214
summary.add(name, None, time)
215215

216+
# Add the language-specific summary if necessary
217+
mapper_func = language_summary_mapper.get(self.language)
218+
if mapper_func:
219+
summary.add_language_summary(self.language, mapper_func(params))
220+
216221
return summary
217222

218223

@@ -342,7 +347,7 @@ def summary(self, args, dtype, params, reduce_over=None):
342347
# data transfers)
343348
summary.add_glb_fdlike('fdlike-nosetup', points, reduce_over_nosetup)
344349

345-
# Add the language specific summary if necessary
350+
# Add the language-specific summary if necessary
346351
mapper_func = language_summary_mapper.get(self.language)
347352
if mapper_func:
348353
summary.add_language_summary(self.language, mapper_func(params))

devito/petsc/clusters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from devito.tools import timed_pass
2-
from devito.petsc.types import LinearSolveExpr
2+
from devito.petsc.types import SolverMetaData
33

44

55
@timed_pass()
@@ -19,7 +19,7 @@ def petsc_lift(clusters):
1919
"""
2020
processed = []
2121
for c in clusters:
22-
if isinstance(c.exprs[0].rhs, LinearSolveExpr):
22+
if isinstance(c.exprs[0].rhs, SolverMetaData):
2323
ispace = c.ispace.lift(c.exprs[0].rhs.field_data.space_dimensions)
2424
processed.append(c.rebuild(ispace=ispace))
2525
else:

devito/petsc/iet/logging.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from devito.symbolics import Byref, FieldFromPointer
44
from devito.ir.iet import DummyExpr
55
from devito.logger import PERF
6+
from devito.tools import frozendict
67

78
from devito.petsc.iet.utils import petsc_call
89
from devito.petsc.logging import petsc_return_variable_dict, PetscInfo
@@ -12,43 +13,66 @@ class PetscLogger:
1213
"""
1314
Class for PETSc loggers that collect solver related statistics.
1415
"""
16+
# TODO: Update docstring with kwargs
1517
def __init__(self, level, **kwargs):
18+
19+
self.query_functions = kwargs.get('get_info', [])
1620
self.sobjs = kwargs.get('solver_objs')
1721
self.sreg = kwargs.get('sregistry')
1822
self.section_mapper = kwargs.get('section_mapper', {})
1923
self.inject_solve = kwargs.get('inject_solve', None)
2024

21-
self.function_list = []
22-
2325
if level <= PERF:
24-
self.function_list.extend([
26+
funcs = [
27+
# KSP specific
2528
'kspgetiterationnumber',
26-
'snesgetiterationnumber'
27-
])
29+
'kspgettolerances',
30+
'kspgetconvergedreason',
31+
'kspgettype',
32+
'kspgetnormtype',
33+
# SNES specific
34+
'snesgetiterationnumber',
35+
]
36+
self.query_functions = set(self.query_functions)
37+
self.query_functions.update(funcs)
38+
self.query_functions = sorted(list(self.query_functions))
2839

2940
# TODO: To be extended with if level <= DEBUG: ...
3041

3142
name = self.sreg.make_name(prefix='petscinfo')
3243
pname = self.sreg.make_name(prefix='petscprofiler')
3344

3445
self.statstruct = PetscInfo(
35-
name, pname, self.logobjs, self.sobjs,
46+
name, pname, self.petsc_option_mapper, self.sobjs,
3647
self.section_mapper, self.inject_solve,
37-
self.function_list
48+
self.query_functions
3849
)
3950

4051
@cached_property
41-
def logobjs(self):
52+
def petsc_option_mapper(self):
4253
"""
43-
Create PETSc objects specifically needed for logging solver statistics.
44-
"""
45-
return {
46-
info.name: info.variable_type(
47-
self.sreg.make_name(prefix=info.output_param)
48-
)
49-
for func_name in self.function_list
50-
for info in [petsc_return_variable_dict[func_name]]
54+
For each function in `self.query_functions`, look up its metadata in
55+
`petsc_return_variable_dict` and instantiate the corresponding PETSc logging
56+
variables with names from the symbol registry.
57+
58+
Example:
59+
--------
60+
>>> self.query_functions
61+
['kspgetiterationnumber', 'snesgetiterationnumber', 'kspgettolerances']
62+
63+
>>> self.petsc_option_mapper
64+
{
65+
'KSPGetIterationNumber': {'kspits': kspits0},
66+
'KSPGetTolerances': {'rtol': rtol0, 'atol': atol0, ...}
5167
}
68+
"""
69+
opts = {}
70+
for func_name in self.query_functions:
71+
info = petsc_return_variable_dict[func_name]
72+
opts[info.name] = {}
73+
for vtype, out in zip(info.variable_type, info.output_param, strict=True):
74+
opts[info.name][out] = vtype(self.sreg.make_name(prefix=out))
75+
return frozendict(opts)
5276

5377
@cached_property
5478
def calls(self):
@@ -58,20 +82,21 @@ def calls(self):
5882
"""
5983
struct = self.statstruct
6084
calls = []
61-
for param in self.function_list:
62-
param = petsc_return_variable_dict[param]
63-
64-
inputs = []
65-
for i in param.input_params:
66-
inputs.append(self.sobjs[i])
85+
for func_name in self.query_functions:
86+
return_variable = petsc_return_variable_dict[func_name]
6787

68-
logobj = self.logobjs[param.name]
88+
input = self.sobjs[return_variable.input_params]
89+
output_params = self.petsc_option_mapper[return_variable.name].values()
90+
by_ref_output = [Byref(i) for i in output_params]
6991

7092
calls.append(
71-
petsc_call(param.name, inputs + [Byref(logobj)])
93+
petsc_call(return_variable.name, [input] + by_ref_output)
7294
)
7395
# TODO: Perform a PetscCIntCast here?
74-
expr = DummyExpr(FieldFromPointer(logobj._C_symbol, struct), logobj._C_symbol)
75-
calls.append(expr)
96+
exprs = [
97+
DummyExpr(FieldFromPointer(i._C_symbol, struct), i._C_symbol)
98+
for i in output_params
99+
]
100+
calls.extend(exprs)
76101

77102
return tuple(calls)

devito/petsc/iet/passes.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
from devito.ir.iet import (Transformer, MapNodes, Iteration, BlankLine,
77
DummyExpr, CallableBody, List, Call, Callable,
88
FindNodes, Section)
9-
from devito.symbolics import Byref, Macro, FieldFromPointer
9+
from devito.symbolics import Byref, FieldFromPointer, Macro, Null
1010
from devito.types import Symbol, Scalar
1111
from devito.types.basic import DataSymbol
1212
from devito.tools import frozendict
13+
import devito.logger
14+
1315
from devito.petsc.types import (PetscMPIInt, PetscErrorCode, MultipleFieldData,
1416
PointerIS, Mat, CallbackVec, Vec, CallbackMat, SNES,
1517
DummyArg, PetscInt, PointerDM, PointerMat, MatReuse,
@@ -25,8 +27,6 @@
2527
from devito.petsc.iet.logging import PetscLogger
2628
from devito.petsc.iet.utils import petsc_call, petsc_call_mpi
2729

28-
import devito.logger as dl
29-
3030

3131
@iet_pass
3232
def lower_petsc(iet, **kwargs):
@@ -69,9 +69,24 @@ def lower_petsc(iet, **kwargs):
6969
# Map PETScSolve to its Section (for logging)
7070
section_mapper = MapNodes(Section, PetscMetaData, 'groupby').visit(iet)
7171

72+
# Prefixes within the same `Operator` should not be duplicated
73+
prefixes = [d.expr.rhs.user_prefix for d in data if d.expr.rhs.user_prefix]
74+
duplicates = {p for p in prefixes if prefixes.count(p) > 1}
75+
76+
if duplicates:
77+
dup_list = ", ".join(repr(p) for p in sorted(duplicates))
78+
raise ValueError(
79+
f"The following `options_prefix` values are duplicated "
80+
f"among your PETScSolves. Ensure each one is unique: {dup_list}"
81+
)
82+
83+
# List of `Call`s to clear options from the global PETSc options database,
84+
# executed at the end of the Operator.
85+
clear_options = []
86+
7287
for iters, (inject_solve,) in inject_solve_mapper.items():
7388

74-
builder = Builder(inject_solve, objs, iters, comm, section_mapper, **kwargs)
89+
builder = Builder(inject_solve, iters, comm, section_mapper, **kwargs)
7590

7691
setup.extend(builder.solver_setup.calls)
7792

@@ -80,11 +95,13 @@ def lower_petsc(iet, **kwargs):
8095

8196
efuncs.update(builder.cbbuilder.efuncs)
8297

83-
populate_matrix_context(efuncs, objs)
98+
clear_options.extend((petsc_call(
99+
builder.cbbuilder._clear_options_efunc.name, []
100+
),))
84101

102+
populate_matrix_context(efuncs)
85103
iet = Transformer(subs).visit(iet)
86-
87-
body = core + tuple(setup) + iet.body.body
104+
body = core + tuple(setup) + iet.body.body + tuple(clear_options)
88105
body = iet.body._rebuild(body=body)
89106
iet = iet._rebuild(body=body)
90107
metadata = {**core_metadata(), 'efuncs': tuple(efuncs.values())}
@@ -131,12 +148,13 @@ class Builder:
131148
returning subclasses of the objects initialised in __init__,
132149
depending on the properties of `inject_solve`.
133150
"""
134-
def __init__(self, inject_solve, objs, iters, comm, section_mapper, **kwargs):
151+
def __init__(self, inject_solve, iters, comm, section_mapper, **kwargs):
135152
self.inject_solve = inject_solve
136153
self.objs = objs
137154
self.iters = iters
138155
self.comm = comm
139156
self.section_mapper = section_mapper
157+
self.get_info = inject_solve.expr.rhs.get_info
140158
self.kwargs = kwargs
141159
self.coupled = isinstance(inject_solve.expr.rhs.field_data, MultipleFieldData)
142160
self.common_kwargs = {
@@ -183,15 +201,17 @@ def solve(self):
183201

184202
@cached_property
185203
def logger(self):
186-
log_level = dl.logger.level
187-
return PetscLogger(log_level, **self.common_kwargs)
204+
log_level = devito.logger.logger.level
205+
return PetscLogger(
206+
log_level, get_info=self.get_info, **self.common_kwargs
207+
)
188208

189209
@cached_property
190210
def calls(self):
191211
return List(body=self.solve.calls+self.logger.calls)
192212

193213

194-
def populate_matrix_context(efuncs, objs):
214+
def populate_matrix_context(efuncs):
195215
if not objs['dummyefunc'] in efuncs.values():
196216
return
197217

@@ -205,7 +225,7 @@ def populate_matrix_context(efuncs, objs):
205225
)
206226
body = CallableBody(
207227
List(body=[subdms_expr, fields_expr]),
208-
init=(objs['begin_user'],),
228+
init=(petsc_func_begin_user,),
209229
retstmt=tuple([Call('PetscFunctionReturn', arguments=[0])])
210230
)
211231
name = 'PopulateMatContext'
@@ -262,13 +282,8 @@ def populate_matrix_context(efuncs, objs):
262282
fields=[subdms, fields, submats], modifier=' *'
263283
),
264284
'subctx': SubMatrixStruct(fields=[rows, cols]),
265-
'Null': Macro('NULL'),
266285
'dummyctx': Symbol('lctx'),
267286
'dummyptr': DummyArg('dummy'),
268287
'dummyefunc': Symbol('dummyefunc'),
269288
'dof': PetscInt('dof'),
270-
'begin_user': c.Line('PetscFunctionBeginUser;'),
271289
})
272-
273-
# Move to macros file?
274-
Null = Macro('NULL')

0 commit comments

Comments
 (0)