Skip to content

Commit 0189eb1

Browse files
committed
compiler/dsl: Add get_info functionality to petscsolve
1 parent f2c32f0 commit 0189eb1

7 files changed

Lines changed: 169 additions & 26 deletions

File tree

devito/operator/profiling.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,11 @@ def summary(self, args, dtype, params, reduce_over=None):
213213
else:
214214
summary.add(name, None, time)
215215

216+
# Add the language specific summary if necessary
217+
mapper_func = language_summary_mapper.get(self.language)
218+
if mapper_func:
219+
summary.add_language_summary(self.language, mapper_func(params))
220+
216221
return summary
217222

218223

devito/petsc/iet/logging.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,27 @@ class PetscLogger:
1212
"""
1313
Class for PETSc loggers that collect solver related statistics.
1414
"""
15-
def __init__(self, level, **kwargs):
15+
def __init__(self, level, get_info=None, **kwargs):
16+
self.function_list = get_info or []
17+
1618
self.sobjs = kwargs.get('solver_objs')
1719
self.sreg = kwargs.get('sregistry')
20+
1821
self.section_mapper = kwargs.get('section_mapper', {})
1922
self.inject_solve = kwargs.get('inject_solve', None)
2023

21-
self.function_list = []
22-
2324
if level <= PERF:
24-
self.function_list.extend([
25+
funcs = [
2526
# KSP specific
2627
'kspgetiterationnumber',
2728
'kspgettolerances',
2829
'kspgetconvergedreason',
2930
# SNES specific
3031
'snesgetiterationnumber',
31-
])
32+
]
33+
for f in funcs:
34+
if f not in self.function_list:
35+
self.function_list.append(f)
3236

3337
# TODO: To be extended with if level <= DEBUG: ...
3438

@@ -56,7 +60,7 @@ def petsc_option_mapper(self):
5660
>>> self.petsc_option_mapper
5761
{
5862
'KSPGetIterationNumber': {'kspits': kspits0},
59-
'KSPGetTolerances': {'rtol': rtol0, 'abstol': abstol0, ...}
63+
'KSPGetTolerances': {'rtol': rtol0, 'atol': atol0, ...}
6064
}
6165
"""
6266
opts = {}

devito/petsc/iet/passes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def __init__(self, inject_solve, iters, comm, section_mapper, **kwargs):
170170
self.iters = iters
171171
self.comm = comm
172172
self.section_mapper = section_mapper
173+
self.get_info = inject_solve.expr.rhs.get_info
173174
self.kwargs = kwargs
174175
self.coupled = isinstance(inject_solve.expr.rhs.field_data, MultipleFieldData)
175176
self.common_kwargs = {
@@ -217,7 +218,9 @@ def solve(self):
217218
@cached_property
218219
def logger(self):
219220
log_level = dl.logger.level
220-
return PetscLogger(log_level, **self.common_kwargs)
221+
return PetscLogger(
222+
log_level, self.get_info, **self.common_kwargs
223+
)
221224

222225
@cached_property
223226
def calls(self):

devito/petsc/logging.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ def __init__(self, **kwargs):
1717
def __getitem__(self, key):
1818
return self._properties[key.lower()]
1919

20+
def __len__(self):
21+
return len(self._properties)
22+
2023
def __repr__(self):
2124
return f"PetscEntry({', '.join(f'{k}={v}' for k, v in self.kwargs.items())})"
2225

@@ -63,7 +66,9 @@ def petsc_entry(self, petscinfo):
6366
Create a named tuple entry for the given PetscInfo object,
6467
containing the values for each PETSc function call.
6568
"""
66-
funcs = self._functions
69+
funcs = [
70+
petsc_return_variable_dict[f].name for f in petscinfo.function_list
71+
]
6772
values = tuple(getattr(petscinfo, c) for c in funcs)
6873
return PetscEntry(**{k: v for k, v in zip(funcs, values)})
6974

@@ -171,7 +176,7 @@ def __getattr__(self, attr):
171176
# return that value directly.
172177
# - If the function returns multiple values (e.g., KSPGetTolerances),
173178
# return a dictionary mapping each output name to its value,
174-
# e.g., {'rtol': val0, 'abstol': val1, ...}.
179+
# e.g., {'rtol': val0, 'atol': val1, ...}.
175180
if len(obj_mapper) == 1:
176181
return get_val(next(iter(obj_mapper.values())))
177182
return {k: get_val(v) for k, v in obj_mapper.items()}
@@ -204,8 +209,7 @@ class PetscReturnVariable:
204209
name='KSPGetTolerances',
205210
variable_type=(PetscScalar, PetscScalar, PetscScalar, PetscInt),
206211
input_params='ksp',
207-
# TODO: check if maxits is max_its in command line
208-
output_param=('rtol', 'atol', 'dtol', 'maxits'),
212+
output_param=('rtol', 'atol', 'divtol', 'max_it'),
209213
),
210214
'kspgetconvergedreason': PetscReturnVariable(
211215
name='KSPGetConvergedReason',

devito/petsc/solve.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414

1515
# TODO: Rename this to petsc_solve, petscsolve?
16-
def PETScSolve(target_exprs, target=None, solver_parameters=None, options_prefix=None):
16+
def PETScSolve(target_exprs, target=None, solver_parameters=None,
17+
options_prefix=None, get_info=None):
1718
"""
1819
Returns a symbolic expression representing a linear PETSc solver,
1920
enriched with all the necessary metadata for execution within an `Operator`.
@@ -45,29 +46,37 @@ def PETScSolve(target_exprs, target=None, solver_parameters=None, options_prefix
4546
solver_parameters : dict, optional
4647
PETSc solver options.
4748
49+
options_prefix : str, optional
50+
Prefix for the solver, used to configure options via the command line. If not
51+
provided, a default prefix is generated by Devito.
52+
53+
get_info : list[str], optional
54+
A list of PETSc API functions to collect statistics from the solver.
55+
For example, `['kspgetiterationnumber', 'kspgettolerances']`.
56+
4857
Returns
4958
-------
5059
Eq:
5160
A symbolic expression that wraps the linear solver.
5261
This can be passed directly to a Devito Operator.
5362
"""
5463
if target is not None:
55-
return InjectSolve(
56-
solver_parameters, {target: target_exprs}, options_prefix
57-
).build_expr()
64+
return InjectSolve(solver_parameters, {target: target_exprs},
65+
options_prefix, get_info).build_expr()
5866
else:
59-
return InjectMixedSolve(
60-
solver_parameters, target_exprs, options_prefix
61-
).build_expr()
67+
return InjectMixedSolve(solver_parameters, target_exprs,
68+
options_prefix, get_info).build_expr()
6269

6370

6471
class InjectSolve:
65-
def __init__(self, solver_parameters=None, target_exprs=None, options_prefix=None):
72+
def __init__(self, solver_parameters=None, target_exprs=None, options_prefix=None,
73+
get_info=None):
6674
self.solver_parameters = linear_solver_parameters(solver_parameters)
6775
self.time_mapper = None
6876
self.target_exprs = target_exprs
6977
self.user_prefix = options_prefix
7078
self.formatted_prefix = format_options_prefix(options_prefix)
79+
self.get_info = get_info
7180

7281
def build_expr(self):
7382
target, funcs, field_data = self.linear_solve_args()
@@ -80,7 +89,8 @@ def build_expr(self):
8089
time_mapper=self.time_mapper,
8190
localinfo=localinfo,
8291
user_prefix=self.user_prefix,
83-
formatted_prefix=self.formatted_prefix
92+
formatted_prefix=self.formatted_prefix,
93+
get_info=self.get_info
8494
)
8595
return PetscEq(target, linear_solve)
8696

devito/petsc/types/types.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,13 @@ class SolveExpr(MetaData):
3939
"""
4040
__rargs__ = ('expr',)
4141
__rkwargs__ = ('solver_parameters', 'field_data', 'time_mapper',
42-
'localinfo', 'user_prefix', 'formatted_prefix')
42+
'localinfo', 'user_prefix', 'formatted_prefix',
43+
'get_info')
4344

4445
def __new__(cls, expr, solver_parameters=None,
4546
field_data=None, time_mapper=None, localinfo=None,
46-
user_prefix=None, formatted_prefix=None, **kwargs):
47+
user_prefix=None, formatted_prefix=None,
48+
get_info=None, **kwargs):
4749

4850
with sympy_mutex:
4951
if isinstance(expr, tuple):
@@ -57,6 +59,7 @@ def __new__(cls, expr, solver_parameters=None,
5759
obj.localinfo = localinfo
5860
obj.user_prefix = user_prefix
5961
obj.formatted_prefix = formatted_prefix
62+
obj.get_info = get_info
6063
return obj
6164

6265
def __repr__(self):

tests/test_petsc.py

Lines changed: 118 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1448,8 +1448,8 @@ def test_logging(self, log_level):
14481448
tols = entry0.KSPGetTolerances
14491449
assert tols['rtol'] == linear_solve_defaults['ksp_rtol']
14501450
assert tols['atol'] == linear_solve_defaults['ksp_atol']
1451-
assert tols['dtol'] == linear_solve_defaults['ksp_divtol']
1452-
assert tols['maxits'] == linear_solve_defaults['ksp_max_it']
1451+
assert tols['divtol'] == linear_solve_defaults['ksp_divtol']
1452+
assert tols['max_it'] == linear_solve_defaults['ksp_max_it']
14531453

14541454
@skipif('petsc')
14551455
@pytest.mark.parametrize('log_level', ['PERF', 'DEBUG'])
@@ -1655,8 +1655,8 @@ def test_tolerances(self, log_level):
16551655
# appear as expected in the `PetscSummary`.
16561656
assert tolerances['rtol'] == params['ksp_rtol']
16571657
assert tolerances['atol'] == params['ksp_atol']
1658-
assert tolerances['dtol'] == params['ksp_divtol']
1659-
assert tolerances['maxits'] == params['ksp_max_it']
1658+
assert tolerances['divtol'] == params['ksp_divtol']
1659+
assert tolerances['max_it'] == params['ksp_max_it']
16601660

16611661
@skipif('petsc')
16621662
def test_clearing_options(self):
@@ -1805,3 +1805,117 @@ def test_solveexpr(self):
18051805
options_prefix='poisson3'
18061806
)
18071807
assert hash(petsc3.rhs) != hash(petsc4.rhs)
1808+
1809+
1810+
class TestGetInfo:
1811+
"""
1812+
Test the `get_info` optional argument to `PETScSolve`.
1813+
1814+
This argument can be used independently of the `log_level` to retrieve
1815+
specific information about the solve, such as the number of KSP
1816+
iterations to converge.
1817+
"""
1818+
@skipif('petsc')
1819+
def test_get_info(self):
1820+
1821+
grid = Grid(shape=(11, 11), dtype=np.float64)
1822+
functions = [Function(name=n, grid=grid, space_order=2)
1823+
for n in ['e', 'f']]
1824+
e, f = functions
1825+
eq = Eq(e.laplace, f)
1826+
1827+
get_info = ['kspgetiterationnumber', 'snesgetiterationnumber']
1828+
1829+
petsc = PETScSolve(
1830+
eq, target=e, options_prefix='pde1', get_info=get_info
1831+
)
1832+
1833+
with switchconfig(language='petsc'):
1834+
op = Operator(petsc)
1835+
summary = op.apply()
1836+
1837+
petsc_summary = summary.petsc
1838+
entry = petsc_summary.get_entry('section0', 'pde1')
1839+
1840+
# Verify that the entry contains only the requested info
1841+
# (since logging is not set)
1842+
assert len(entry) == 2
1843+
assert hasattr(entry, "KSPGetIterationNumber")
1844+
assert hasattr(entry, "SNESGetIterationNumber")
1845+
1846+
@skipif('petsc')
1847+
@pytest.mark.parametrize('log_level', ['PERF', 'DEBUG'])
1848+
def test_get_info_with_logging(self, log_level):
1849+
"""
1850+
Test that `get_info` works correctly when logging is enabled.
1851+
"""
1852+
grid = Grid(shape=(11, 11), dtype=np.float64)
1853+
functions = [Function(name=n, grid=grid, space_order=2)
1854+
for n in ['e', 'f']]
1855+
e, f = functions
1856+
eq = Eq(e.laplace, f)
1857+
1858+
get_info = ['kspgetiterationnumber']
1859+
1860+
petsc = PETScSolve(
1861+
eq, target=e, options_prefix='pde1', get_info=get_info
1862+
)
1863+
1864+
with switchconfig(language='petsc', log_level=log_level):
1865+
op = Operator(petsc)
1866+
summary = op.apply()
1867+
1868+
petsc_summary = summary.petsc
1869+
entry = petsc_summary.get_entry('section0', 'pde1')
1870+
1871+
# With logging enabled, the entry should include both the
1872+
# requested KSP iteration number and additional PETSc info
1873+
# (e.g., SNES iteration count logged at PERF/DEBUG).
1874+
assert len(entry) > 1
1875+
assert hasattr(entry, "KSPGetIterationNumber")
1876+
assert hasattr(entry, "SNESGetIterationNumber")
1877+
1878+
@skipif('petsc')
1879+
def test_different_solvers(self):
1880+
"""
1881+
Test that `get_info` works correctly when multiple solvers are used
1882+
within the same Operator.
1883+
"""
1884+
grid = Grid(shape=(11, 11), dtype=np.float64)
1885+
functions = [Function(name=n, grid=grid, space_order=2)
1886+
for n in ['e', 'f', 'g', 'h']]
1887+
e, f, g, h = functions
1888+
1889+
eq1 = Eq(e.laplace, f)
1890+
eq2 = Eq(g.laplace, h)
1891+
1892+
# Create two PETScSolve instances with different get_info arguments
1893+
1894+
get_info_1 = ['kspgetiterationnumber']
1895+
get_info_2 = ['snesgetiterationnumber']
1896+
1897+
solver1 = PETScSolve(
1898+
eq1, target=e, options_prefix='pde1', get_info=get_info_1
1899+
)
1900+
solver2 = PETScSolve(
1901+
eq2, target=g, options_prefix='pde2', get_info=get_info_2
1902+
)
1903+
1904+
with switchconfig(language='petsc'):
1905+
op = Operator([solver1, solver2])
1906+
summary = op.apply()
1907+
1908+
petsc_summary = summary.petsc
1909+
1910+
assert len(petsc_summary) == 2
1911+
assert len(petsc_summary.KSPGetIterationNumber) == 1
1912+
assert len(petsc_summary.SNESGetIterationNumber) == 1
1913+
1914+
entry1 = petsc_summary.get_entry('section0', 'pde1')
1915+
entry2 = petsc_summary.get_entry('section1', 'pde2')
1916+
1917+
assert hasattr(entry1, "KSPGetIterationNumber")
1918+
assert not hasattr(entry1, "SNESGetIterationNumber")
1919+
1920+
assert not hasattr(entry2, "KSPGetIterationNumber")
1921+
assert hasattr(entry2, "SNESGetIterationNumber")

0 commit comments

Comments
 (0)