Skip to content

Commit aa15de0

Browse files
committed
compiler/dsl: Extend get_info functions and add tests
1 parent 7ec8bf0 commit aa15de0

7 files changed

Lines changed: 126 additions & 66 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def _gen_struct_decl(self, obj, masked=()):
292292
for i, (n, ct) in zip(fields, ctype._fields_):
293293
try:
294294
entries.append(self._gen_value(i, 0, masked=('const',)))
295-
295+
296296
except AttributeError:
297297
cstr = self.ccode(ct)
298298
if ct is c_restrict_void_p:

devito/petsc/iet/logging.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ class PetscLogger:
1212
"""
1313
Class for PETSc loggers that collect solver related statistics.
1414
"""
15-
def __init__(self, level, get_info=None, **kwargs):
16-
self.function_list = get_info or []
15+
def __init__(self, level, get_info=[], **kwargs):
16+
self.function_list = get_info
1717

1818
self.sobjs = kwargs.get('solver_objs')
1919
self.sreg = kwargs.get('sregistry')
@@ -28,6 +28,8 @@ def __init__(self, level, get_info=None, **kwargs):
2828
'kspgettolerances',
2929
'kspgetconvergedreason',
3030
'kspgettype',
31+
'kspgettype',
32+
'kspgetnormtype',
3133
# SNES specific
3234
'snesgetiterationnumber',
3335
]

devito/petsc/logging.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from collections import namedtuple, OrderedDict
1+
import os
2+
from collections import namedtuple
23
from dataclasses import dataclass
3-
from functools import cached_property
4-
from cgen import Struct, Value
54

65
from devito.types import CompositeObject
76

8-
from devito.petsc.types import PetscInt, PetscScalar, KSPType
7+
from devito.petsc.types import (
8+
PetscInt, PetscScalar, KSPType, KSPConvergedReason, KSPNormType
9+
)
910
from devito.petsc.utils import petsc_type_mappings, fixed_petsc_type_mappings
1011

1112

@@ -79,18 +80,18 @@ def _add_properties(self):
7980
For each function name in `self._functions` (e.g., 'KSPGetIterationNumber'),
8081
dynamically add a property to the class with the same name.
8182
82-
Each property returns an OrderedDict that maps each PetscKey to the
83+
Each property returns a dict mapping each PetscKey to the
8384
result of looking up that function on the corresponding PetscEntry,
8485
if the function exists on that entry.
8586
"""
8687
def make_property(function):
8788
def getter(self):
88-
return OrderedDict(
89-
(k, getattr(v, function))
89+
return {
90+
k: getattr(v, function)
9091
for k, v in self.items()
9192
# Only include entries that have the function
9293
if hasattr(v, function)
93-
)
94+
}
9495
return property(getter)
9596

9697
for f in self._functions:
@@ -141,27 +142,25 @@ def __init__(self, name, pname, petsc_option_mapper, sobjs, section_mapper,
141142
self.formatted_prefix = inject_solve.expr.rhs.formatted_prefix
142143

143144
mapper = {v: k for k, v in petsc_type_mappings.items()}
144-
145145
pfields = []
146-
self._tmp_fields = []
147-
148-
# obj_mapper is e.g {'kspits': kspits0}
149-
# TODO: think this can just be for (x,y) in ....
150-
for obj_mapper in petsc_option_mapper.values():
151-
for petsc_option in obj_mapper.values():
152-
self._tmp_fields.append(petsc_option)
153-
# petsc_type is e.g. 'PetscInt', 'PetscScalar', 'KSPType'
154-
petsc_type = str(petsc_option.dtype)
155-
if petsc_type in mapper:
156-
ctype = mapper[petsc_type]
157-
else:
158-
ctype = fixed_petsc_type_mappings[petsc_type]
159-
pfields.append((petsc_option.name, ctype))
146+
147+
# All petsc options needed to form the PetscInfo struct
148+
# e.g (kspits0, rtol0, atol0, ...)
149+
self._fields = [i for j in petsc_option_mapper.values() for i in j.values()]
150+
151+
for petsc_option in self._fields:
152+
# petsc_type is e.g. 'PetscInt', 'PetscScalar', 'KSPType'
153+
petsc_type = str(petsc_option.dtype)
154+
if petsc_type in mapper:
155+
ctype = mapper[petsc_type]
156+
else:
157+
ctype = fixed_petsc_type_mappings[petsc_type]
158+
pfields.append((petsc_option.name, ctype))
160159
super().__init__(name, pname, pfields)
161160

162161
@property
163162
def fields(self):
164-
return self._tmp_fields
163+
return self._fields
165164

166165
@property
167166
def section(self):
@@ -184,8 +183,10 @@ def __getattr__(self, attr):
184183
# Maps the petsc_option to its generated variable name e.g {'its': its0}
185184
obj_mapper = self.petsc_option_mapper[attr]
186185

187-
# Helper to get the value from the profiling struct
188-
get_val = lambda v: getattr(self.value._obj, v.name)
186+
# Decode the value if it is a bytes object
187+
def decode_if_bytes(val):
188+
return str(os.fsdecode(val)) if isinstance(val, bytes) else val
189+
get_val = lambda v: decode_if_bytes(getattr(self.value._obj, v.name))
189190

190191
# - If the function returns a single value (e.g., KSPGetIterationNumber),
191192
# return that value directly.
@@ -228,7 +229,7 @@ class PetscReturnVariable:
228229
),
229230
'kspgetconvergedreason': PetscReturnVariable(
230231
name='KSPGetConvergedReason',
231-
variable_type=(PetscInt,),
232+
variable_type=(KSPConvergedReason,),
232233
input_params='ksp',
233234
output_param=('reason',),
234235
),
@@ -238,7 +239,13 @@ class PetscReturnVariable:
238239
input_params='ksp',
239240
output_param=('ksptype',),
240241
),
241-
# SNES specific
242+
'kspgetnormtype': PetscReturnVariable(
243+
name='KSPGetNormType',
244+
variable_type=(KSPNormType,),
245+
input_params='ksp',
246+
output_param=('kspnormtype',),
247+
),
248+
# SNES specific -> will be extended when non-linear solvers are supported
242249
'snesgetiterationnumber': PetscReturnVariable(
243250
name='SNESGetIterationNumber',
244251
variable_type=(PetscInt,),

devito/petsc/solve.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
# TODO: Rename this to petsc_solve, petscsolve?
1616
def PETScSolve(target_exprs, target=None, solver_parameters=None,
17-
options_prefix=None, get_info=None):
17+
options_prefix=None, get_info=[]):
1818
"""
1919
Returns a symbolic expression representing a linear PETSc solver,
2020
enriched with all the necessary metadata for execution within an `Operator`.
@@ -68,13 +68,13 @@ def PETScSolve(target_exprs, target=None, solver_parameters=None,
6868
6969
get_info : list[str], optional
7070
A list of PETSc API functions to collect statistics from the solver.
71-
For example, `['kspgetiterationnumber', 'kspgettolerances']`.
71+
For example, `['kspgetiterationnumber', 'kspgettolerances']`.
7272
Capitalisation does not matter; e.g. `'KSPGetIterationNumber'` and
7373
`'kspgetiterationnumber'` are treated the same.
7474
7575
List of available functions:
7676
- ['kspgetiterationnumber', 'kspgettolerances', 'kspgetconvergedreason',
77-
'snesgetiterationnumber']
77+
'kspgettype', 'kspgetnormtype', 'snesgetiterationnumber']
7878
7979
Returns
8080
-------
@@ -92,7 +92,7 @@ def PETScSolve(target_exprs, target=None, solver_parameters=None,
9292

9393
class InjectSolve:
9494
def __init__(self, solver_parameters=None, target_exprs=None, options_prefix=None,
95-
get_info=None):
95+
get_info=[]):
9696
self.solver_parameters = linear_solver_parameters(solver_parameters)
9797
self.time_mapper = None
9898
self.target_exprs = target_exprs

devito/petsc/types/object.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ class KSPType(PetscObject):
117117
dtype = CustomDtype('KSPType')
118118

119119

120+
class KSPNormType(PetscObject):
121+
dtype = CustomDtype('KSPNormType')
122+
123+
120124
class CallbackSNES(PetscObject):
121125
"""
122126
PETSc SNES : Non-Linear Systems Solvers.

devito/petsc/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,13 @@ def get_petsc_type_mappings():
9595
petsc_type_mappings = get_petsc_type_mappings()
9696

9797

98+
# NOTE: These mappings are only used when constructing ctypes.Structures
99+
# that wrap PETSc objects. In the generated C code, the fields will still
100+
# appear as the actual PETSc types.
98101
fixed_petsc_type_mappings = {
99102
'KSPType': ctypes.c_char_p,
103+
'KSPConvergedReason': ctypes.c_int,
104+
'KSPNormType': ctypes.c_int,
100105
}
101106

102107

tests/test_petsc.py

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44
import os
55
import re
6-
from collections import OrderedDict
76

87
from conftest import skipif
98
from devito import (Grid, Function, TimeFunction, Eq, Operator,
@@ -1437,7 +1436,6 @@ def test_logging(self, log_level):
14371436

14381437
assert snesits0 == snesits1 == snesits2 == snesits3
14391438

1440-
assert isinstance(snesits0, OrderedDict)
14411439
assert len(snesits0) == 1
14421440
key, value = next(iter(snesits0.items()))
14431441
assert str(key) == "PetscKey(name='section0', options_prefix='poisson')"
@@ -1813,20 +1811,25 @@ class TestGetInfo:
18131811
iterations to converge.
18141812
"""
18151813
@skipif('petsc')
1816-
def test_get_info(self):
1817-
1814+
def setup_class(self):
1815+
"""
1816+
Setup grid, functions and equations shared across
1817+
tests in this class
1818+
"""
18181819
grid = Grid(shape=(11, 11), dtype=np.float64)
1819-
functions = [Function(name=n, grid=grid, space_order=2)
1820-
for n in ['e', 'f']]
1821-
e, f = functions
1822-
eq = Eq(e.laplace, f)
1820+
self.e, self.f, self.g, self.h = [
1821+
Function(name=n, grid=grid, space_order=2)
1822+
for n in ['e', 'f', 'g', 'h']
1823+
]
1824+
self.eq1 = Eq(self.e.laplace, self.f)
1825+
self.eq2 = Eq(self.g.laplace, self.h)
18231826

1827+
@skipif('petsc')
1828+
def test_get_info(self):
18241829
get_info = ['kspgetiterationnumber', 'snesgetiterationnumber']
1825-
18261830
petsc = PETScSolve(
1827-
eq, target=e, options_prefix='pde1', get_info=get_info
1831+
self.eq1, target=self.e, options_prefix='pde1', get_info=get_info
18281832
)
1829-
18301833
with switchconfig(language='petsc'):
18311834
op = Operator(petsc)
18321835
summary = op.apply()
@@ -1846,18 +1849,10 @@ def test_get_info_with_logging(self, log_level):
18461849
"""
18471850
Test that `get_info` works correctly when logging is enabled.
18481851
"""
1849-
grid = Grid(shape=(11, 11), dtype=np.float64)
1850-
functions = [Function(name=n, grid=grid, space_order=2)
1851-
for n in ['e', 'f']]
1852-
e, f = functions
1853-
eq = Eq(e.laplace, f)
1854-
18551852
get_info = ['kspgetiterationnumber']
1856-
18571853
petsc = PETScSolve(
1858-
eq, target=e, options_prefix='pde1', get_info=get_info
1854+
self.eq1, target=self.e, options_prefix='pde1', get_info=get_info
18591855
)
1860-
18611856
with switchconfig(language='petsc', log_level=log_level):
18621857
op = Operator(petsc)
18631858
summary = op.apply()
@@ -1878,26 +1873,17 @@ def test_different_solvers(self):
18781873
Test that `get_info` works correctly when multiple solvers are used
18791874
within the same Operator.
18801875
"""
1881-
grid = Grid(shape=(11, 11), dtype=np.float64)
1882-
functions = [Function(name=n, grid=grid, space_order=2)
1883-
for n in ['e', 'f', 'g', 'h']]
1884-
e, f, g, h = functions
1885-
1886-
eq1 = Eq(e.laplace, f)
1887-
eq2 = Eq(g.laplace, h)
1888-
18891876
# Create two PETScSolve instances with different get_info arguments
18901877

18911878
get_info_1 = ['kspgetiterationnumber']
18921879
get_info_2 = ['snesgetiterationnumber']
18931880

18941881
solver1 = PETScSolve(
1895-
eq1, target=e, options_prefix='pde1', get_info=get_info_1
1882+
self.eq1, target=self.e, options_prefix='pde1', get_info=get_info_1
18961883
)
18971884
solver2 = PETScSolve(
1898-
eq2, target=g, options_prefix='pde2', get_info=get_info_2
1885+
self.eq2, target=self.g, options_prefix='pde2', get_info=get_info_2
18991886
)
1900-
19011887
with switchconfig(language='petsc'):
19021888
op = Operator([solver1, solver2])
19031889
summary = op.apply()
@@ -1916,3 +1902,59 @@ def test_different_solvers(self):
19161902

19171903
assert not hasattr(entry2, "KSPGetIterationNumber")
19181904
assert hasattr(entry2, "SNESGetIterationNumber")
1905+
1906+
@skipif('petsc')
1907+
def test_case_insensitive(self):
1908+
"""
1909+
Test that `get_info` is case insensitive
1910+
"""
1911+
# Create a list with mixed cases
1912+
get_info = ['KSPGetIterationNumber', 'snesgetiterationnumber']
1913+
petsc = PETScSolve(
1914+
self.eq1, target=self.e, options_prefix='pde1', get_info=get_info
1915+
)
1916+
with switchconfig(language='petsc'):
1917+
op = Operator(petsc)
1918+
summary = op.apply()
1919+
1920+
petsc_summary = summary.petsc
1921+
entry = petsc_summary.get_entry('section0', 'pde1')
1922+
1923+
assert hasattr(entry, "KSPGetIterationNumber")
1924+
assert hasattr(entry, "SNESGetIterationNumber")
1925+
1926+
@skipif('petsc')
1927+
def test_get_ksp_type(self):
1928+
"""
1929+
Test that `get_info` can retrieve the KSP type as
1930+
a string.
1931+
"""
1932+
get_info = ['kspgettype']
1933+
solver1 = PETScSolve(
1934+
self.eq1, target=self.e, options_prefix='poisson1', get_info=get_info
1935+
)
1936+
solver2 = PETScSolve(
1937+
self.eq1, target=self.e, options_prefix='poisson2',
1938+
solver_parameters={'ksp_type': 'cg'}, get_info=get_info
1939+
)
1940+
with switchconfig(language='petsc'):
1941+
op = Operator([solver1, solver2])
1942+
summary = op.apply()
1943+
1944+
petsc_summary = summary.petsc
1945+
entry1 = petsc_summary.get_entry('section0', 'poisson1')
1946+
entry2 = petsc_summary.get_entry('section1', 'poisson2')
1947+
1948+
assert hasattr(entry1, "KSPGetType")
1949+
# Check the type matches the default in linear_solve_defaults
1950+
# since it has not been overridden
1951+
assert entry1.KSPGetType == linear_solve_defaults['ksp_type']
1952+
assert entry1['KSPGetType'] == linear_solve_defaults['ksp_type']
1953+
assert entry1['kspgettype'] == linear_solve_defaults['ksp_type']
1954+
1955+
# Test that the KSP type default is correctly overridden by the
1956+
# solver_parameters dictionary passed to solver2
1957+
assert hasattr(entry2, "KSPGetType")
1958+
assert entry2.KSPGetType == 'cg'
1959+
assert entry2['KSPGetType'] == 'cg'
1960+
assert entry2['kspgettype'] == 'cg'

0 commit comments

Comments
 (0)