Skip to content

Commit ce92e4d

Browse files
committed
misc: Clean up
1 parent 2bc855d commit ce92e4d

7 files changed

Lines changed: 55 additions & 39 deletions

File tree

.github/workflows/pytest-petsc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ jobs:
7272
run: |
7373
${{ env.RUN_CMD }} python3 -c "from devito import configuration; print(''.join(['%s: %s \n' % (k, v) for (k, v) in configuration.items()]))"
7474
75-
- name: Test with pytest - non parallel
75+
- name: Test with pytest - serial
7676
run: |
7777
${{ env.RUN_CMD }} mpiexec -n 1 pytest -m "not parallel" --cov --cov-config=.coveragerc --cov-report=xml ${{ env.TESTS }}
7878

devito/petsc/iet/logging.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@ def __init__(self, level, **kwargs):
2222

2323
if level <= PERF:
2424
self.function_list.extend([
25+
# KSP specific
2526
'kspgetiterationnumber',
27+
'kspgettolerances',
28+
'kspgetconvergedreason',
29+
# SNES specific
2630
'snesgetiterationnumber',
27-
'kspgettolerances'
2831
])
2932

3033
# TODO: To be extended with if level <= DEBUG: ...

devito/petsc/iet/passes.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,21 @@ def lower_petsc(iet, **kwargs):
6969
# Map PETScSolve to its Section (for logging)
7070
section_mapper = MapNodes(Section, PetscMetaData, 'groupby').visit(iet)
7171

72-
# Prefixes within a single Operator should not be duplicated
72+
# Prefixes within the same Operator should not be duplicated
7373
prefixes = [d.expr.rhs.user_prefix for d in data if d.expr.rhs.user_prefix]
7474
duplicates = {p for p in prefixes if prefixes.count(p) > 1}
7575

76-
# TODO: Avoid the other exception raised given it is an iet_pass?
76+
# TODO: Avoid the other exception raised - think due to exception being
77+
# raised inside the @iet_pass?
7778
if duplicates:
7879
dup_list = ", ".join(repr(p) for p in sorted(duplicates))
7980
raise ValueError(
8081
f"The following `options_prefix` values are duplicated "
8182
f"among your PETScSolves. Ensure each one is unique: {dup_list}"
8283
)
8384

84-
# List of calls to clear options from the global PETSc options database.
85-
# These are executed at the end of the Operator.
85+
# List of `Call`s to clear options from the global PETSc options database,
86+
# executed at the end of the Operator.
8687
clear_options = []
8788

8889
for iters, (inject_solve,) in inject_solve_mapper.items():

devito/petsc/iet/routines.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def _make_callable_body(self, body, stacks=(), casts=()):
136136

137137
def _make_options_callback(self):
138138
"""
139-
Create two callbacks: one to set PETSc options and one for
139+
Create two callbacks: one to set PETSc options and one
140140
to clear them.
141141
142142
Options are only set/cleared if they were not specifed via
@@ -153,6 +153,7 @@ def _make_options_callback(self):
153153
# Ensures that the command line args take priority
154154
continue
155155
option_name = String(option)
156+
# For options without a value e.g `ksp_view`, pass Null
156157
option_value = Null if v is None else String(str(v))
157158
set_body.append(
158159
petsc_call('PetscOptionsSetValue', [Null, option_name, option_value])

devito/petsc/logging.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,12 @@ def __init__(self, name, pname, petsc_option_mapper, sobjs, section_mapper,
134134

135135
mapper = {v: k for k, v in petsc_type_mappings.items()}
136136

137-
fields = [
138-
(str(ptype), mapper[str(ptype._C_ctype)])
139-
for option in petsc_option_mapper.values() for ptype in option.values()
140-
]
137+
fields = []
138+
for obj_mapper in petsc_option_mapper.values():
139+
for petsc_option in obj_mapper.values():
140+
ctype = mapper[str(petsc_option._C_ctype)]
141+
fields.append((petsc_option.name, ctype))
142+
141143
super().__init__(name, pname, fields)
142144

143145
@property
@@ -164,24 +166,22 @@ def __getattr__(self, attr):
164166
# Helper to get the value from the profiling struct
165167
get_val = lambda v: getattr(self.value._obj, v.name)
166168

167-
# Return the value(s) for the given PETSc attribute:
168-
# - If there's only one output (e.g., for KSPGetIterationNumber),
169-
# return it directly.
170-
# - If there are multiple outputs (e.g., for KSPGetTolerances),
171-
# return a dictionary mapping each output name to
172-
# its value, e.g., {'rtol': val0, 'abstol': val1, ...}.
169+
# - If the function returns a single value (e.g., KSPGetIterationNumber),
170+
# return that value directly.
171+
# - If the function returns multiple values (e.g., KSPGetTolerances),
172+
# return a dictionary mapping each output name to its value,
173+
# e.g., {'rtol': val0, 'abstol': val1, ...}.
173174
if len(obj_mapper) == 1:
174175
return get_val(next(iter(obj_mapper.values())))
175176
return {k: get_val(v) for k, v in obj_mapper.items()}
176177

177178

178-
# TODO: change the lists to tuples
179179
@dataclass
180180
class PetscReturnVariable:
181181
name: str
182-
variable_type: list
182+
variable_type: tuple
183183
input_params: str
184-
output_param: list[str]
184+
output_param: tuple[str]
185185

186186

187187
# NOTE:
@@ -190,23 +190,32 @@ class PetscReturnVariable:
190190
# If any of the PETSc function signatures change (e.g., names, input/output parameters),
191191
# this dictionary must be updated accordingly.
192192

193+
# TODO: To be extended
193194
petsc_return_variable_dict = {
195+
# KSP specific
194196
'kspgetiterationnumber': PetscReturnVariable(
195197
name='KSPGetIterationNumber',
196-
variable_type=[PetscInt],
198+
variable_type=(PetscInt,),
199+
input_params='ksp',
200+
output_param=('kspits',)
201+
),
202+
'kspgettolerances': PetscReturnVariable(
203+
name='KSPGetTolerances',
204+
variable_type=(PetscScalar, PetscScalar, PetscScalar, PetscInt),
197205
input_params='ksp',
198-
output_param=['kspits']
206+
output_param=('rtol', 'abstol', 'dtol', 'maxits'),
199207
),
208+
'kspgetconvergedreason': PetscReturnVariable(
209+
name='KSPGetConvergedReason',
210+
variable_type=(PetscInt,),
211+
input_params='ksp',
212+
output_param=('reason',),
213+
),
214+
# SNES specific
200215
'snesgetiterationnumber': PetscReturnVariable(
201216
name='SNESGetIterationNumber',
202-
variable_type=[PetscInt],
217+
variable_type=(PetscInt,),
203218
input_params='snes',
204-
output_param=['snesits'],
219+
output_param=('snesits',),
205220
),
206-
'kspgettolerances': PetscReturnVariable(
207-
name='KSPGetTolerances',
208-
variable_type=[PetscScalar, PetscScalar, PetscScalar, PetscInt],
209-
input_params='ksp',
210-
output_param=['rtol', 'abstol', 'dtol', 'maxits'],
211-
)
212221
}

devito/petsc/solver_parameters.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040

4141
def linear_solver_parameters(solver_parameters):
42+
# Flatten parameters to support nested dictionaries
4243
flattened = flatten_parameters(solver_parameters or {})
4344
processed = linear_solve_defaults.copy()
4445
processed.update(flattened)

tests/test_petsc.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,8 +1476,7 @@ def test_logging_multiple_solves(self, log_level):
14761476
@pytest.mark.parametrize('log_level', ['PERF', 'DEBUG'])
14771477
def test_logging_user_prefixes(self, log_level):
14781478
"""
1479-
Test that the PetscSummary uses the user provided options prefix
1480-
if provided, otherwise it uses the default one provided by Devito.
1479+
Verify that `PetscSummary` uses the user provided `options_prefix` when given.
14811480
"""
14821481
grid = Grid(shape=(11, 11), dtype=np.float64)
14831482

@@ -1505,8 +1504,8 @@ def test_logging_user_prefixes(self, log_level):
15051504
@pytest.mark.parametrize('log_level', ['PERF', 'DEBUG'])
15061505
def test_logging_default_prefixes(self, log_level):
15071506
"""
1508-
Test that the PetscSummary uses the default options prefix
1509-
# generated by Devito
1507+
Verify that `PetscSummary` uses the default options prefix
1508+
provided by Devito if no user `options_prefix` is specified.
15101509
"""
15111510
grid = Grid(shape=(11, 11), dtype=np.float64)
15121511

@@ -1526,7 +1525,8 @@ def test_logging_default_prefixes(self, log_level):
15261525

15271526
petsc_summary = summary.petsc
15281527

1529-
# Note, if users want to use logging, they should really set the options_prefix
1528+
# Users should set a custom options_prefix if they want logging; otherwise,
1529+
# the default automatically generated prefix is used in the `PetscSummary`.
15301530
assert all(re.fullmatch(r"devito_\d+_", k.options_prefix) for k in petsc_summary)
15311531

15321532

@@ -1594,7 +1594,7 @@ def test_options_prefix(self):
15941594
def test_options_no_value(self):
15951595
"""
15961596
Test solver parameters that do not require a value, such as
1597-
`snes_view` and `ksp_view`
1597+
`snes_view` and `ksp_view`.
15981598
"""
15991599
solver = PETScSolve(
16001600
self.eq1, target=self.e, solver_parameters={'snes_view': None},
@@ -1629,6 +1629,8 @@ def test_tolerances(self, log_level):
16291629
entry = petsc_summary.get_entry('section0', 'solver')
16301630
tolerances = entry.KSPGetTolerances
16311631

1632+
# Test that the tolerances have been set correctly and therefore
1633+
# appear as expected in the `PetscSummary`.
16321634
assert tolerances['rtol'] == params['ksp_rtol']
16331635
assert tolerances['abstol'] == params['ksp_atol']
16341636
assert tolerances['dtol'] == params['ksp_divtol']
@@ -1672,7 +1674,7 @@ def test_error_if_same_prefix(self):
16721674

16731675
@skipif('petsc')
16741676
@pytest.mark.parametrize('log_level', ['PERF', 'DEBUG'])
1675-
def test_multiple_operators(self):
1677+
def test_multiple_operators(self, log_level):
16761678
"""
16771679
Verify that solver parameters are set correctly when multiple Operators
16781680
are created with PETScSolve instances sharing the same options_prefix.
@@ -1690,7 +1692,7 @@ def test_multiple_operators(self):
16901692
self.eq2, target=self.g, options_prefix='poisson',
16911693
solver_parameters={'ksp_rtol': '1e-12'}
16921694
)
1693-
with switchconfig(language='petsc'):
1695+
with switchconfig(language='petsc', log_level=log_level):
16941696
op1 = Operator(solver1)
16951697
op2 = Operator(solver2)
16961698
summary1 = op1.apply()
@@ -1708,7 +1710,6 @@ def test_multiple_operators(self):
17081710
# TODO: Add test to check that the command line args override anything set
17091711
# in the solver_parameters dictionary
17101712

1711-
# TODO: update names of all of these tests
17121713
# @skipif('petsc')
17131714
# def test_command_line_priority(self):
17141715
# """

0 commit comments

Comments
 (0)