Skip to content

Commit 7718ce5

Browse files
committed
Change SolveExpr name to SolverMetaData and optimise logger
1 parent 727ed35 commit 7718ce5

7 files changed

Lines changed: 23 additions & 22 deletions

File tree

devito/petsc/clusters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from devito.tools import timed_pass
2-
from devito.petsc.types import SolveExpr
2+
from devito.petsc.types import SolverMetaData
33

44

55
@timed_pass()
@@ -19,7 +19,7 @@ def petsc_lift(clusters):
1919
"""
2020
processed = []
2121
for c in clusters:
22-
if isinstance(c.exprs[0].rhs, SolveExpr):
22+
if isinstance(c.exprs[0].rhs, SolverMetaData):
2323
ispace = c.ispace.lift(c.exprs[0].rhs.field_data.space_dimensions)
2424
processed.append(c.rebuild(ispace=ispace))
2525
else:

devito/petsc/iet/logging.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ class PetscLogger:
1313
"""
1414
Class for PETSc loggers that collect solver related statistics.
1515
"""
16-
def __init__(self, level, get_info=[], **kwargs):
16+
# TODO: Update docstring with kwargs
17+
def __init__(self, level, **kwargs):
1718

18-
self.query_functions = get_info
19+
self.query_functions = kwargs.get('get_info', [])
1920
self.sobjs = kwargs.get('solver_objs')
2021
self.sreg = kwargs.get('sregistry')
2122
self.section_mapper = kwargs.get('section_mapper', {})
@@ -32,9 +33,9 @@ def __init__(self, level, get_info=[], **kwargs):
3233
# SNES specific
3334
'snesgetiterationnumber',
3435
]
35-
self.query_functions.extend(
36-
[f for f in funcs if f not in self.query_functions]
37-
)
36+
self.query_functions = set(self.query_functions)
37+
self.query_functions.update(funcs)
38+
self.query_functions = sorted(list(self.query_functions))
3839

3940
# TODO: To be extended with if level <= DEBUG: ...
4041

devito/petsc/iet/passes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def solve(self):
203203
def logger(self):
204204
log_level = devito.logger.logger.level
205205
return PetscLogger(
206-
log_level, self.get_info, **self.common_kwargs
206+
log_level, get_info=self.get_info, **self.common_kwargs
207207
)
208208

209209
@cached_property

devito/petsc/iet/routines.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1726,10 +1726,10 @@ class TimeDependent(NonTimeDependent):
17261726
for each `SNESSolve` at every time step, don't require the time loop, but
17271727
may still need access to data from other time steps.
17281728
- All `Function` objects are passed through the initial lowering via the
1729-
`SolveExpr` object, ensuring the correct time loop is generated
1729+
`SolverMetaData` object, ensuring the correct time loop is generated
17301730
in the main kernel.
17311731
- Another mapper is created based on the modulo dimensions
1732-
generated by the `SolveExpr` object in the main kernel
1732+
generated by the `SolverMetaData` object in the main kernel
17331733
(e.g., {time: time, t: t0, t + 1: t1}).
17341734
- These two mappers are used to generate a final mapper `symb_to_moddim`
17351735
(e.g. {tau0: t0, tau1: t1}) which is used at the IET level to

devito/petsc/solve.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from devito.types.equation import PetscEq
22
from devito.tools import as_tuple
3-
from devito.petsc.types import (LinearSolveExpr, PETScArray, DMDALocalInfo,
3+
from devito.petsc.types import (LinearSolverMetaData, PETScArray, DMDALocalInfo,
44
FieldData, MultipleFieldData, Jacobian, Residual,
55
MixedResidual, MixedJacobian, InitialGuess)
66
from devito.petsc.types.equation import EssentialBC
@@ -103,7 +103,7 @@ def build_expr(self):
103103
target, funcs, field_data = self.linear_solve_args()
104104
# Placeholder expression for inserting calls to the solver
105105

106-
linear_solve = LinearSolveExpr(
106+
linear_solve = LinearSolverMetaData(
107107
funcs,
108108
solver_parameters=self.solver_parameters,
109109
field_data=field_data,

devito/petsc/types/types.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class GetArgs(MetaData):
3232
pass
3333

3434

35-
class SolveExpr(MetaData):
35+
class SolverMetaData(MetaData):
3636
"""
3737
A symbolic expression passed through the Operator, containing the metadata
3838
needed to execute the PETSc solver.
@@ -59,7 +59,7 @@ def __new__(cls, expr, solver_parameters=None,
5959
obj.localinfo = localinfo
6060
obj.user_prefix = user_prefix
6161
obj.formatted_prefix = formatted_prefix
62-
obj.get_info = get_info
62+
obj.get_info = get_info if get_info is not None else []
6363
return obj
6464

6565
def __repr__(self):
@@ -76,7 +76,7 @@ def _hashable_content(self):
7676
return (self.expr, self.formatted_prefix, self.solver_parameters)
7777

7878
def __eq__(self, other):
79-
return (isinstance(other, SolveExpr) and
79+
return (isinstance(other, SolverMetaData) and
8080
self.expr == other.expr and
8181
self.formatted_prefix == other.formatted_prefix
8282
and self.solver_parameters == other.solver_parameters)
@@ -92,15 +92,15 @@ def eval(cls, *args):
9292
func = Reconstructable._rebuild
9393

9494

95-
class LinearSolveExpr(SolveExpr):
95+
class LinearSolverMetaData(SolverMetaData):
9696
"""
9797
Linear problems are handled by setting the SNESType to 'ksponly',
9898
enabling a unified interface for both linear and nonlinear solvers.
9999
"""
100100
pass
101101

102102

103-
class NonLinearSolveExpr(SolveExpr):
103+
class NonLinearSolverMetaData(SolverMetaData):
104104
"""
105105
TODO: Non linear solvers are not yet supported.
106106
"""
@@ -109,7 +109,7 @@ class NonLinearSolveExpr(SolveExpr):
109109

110110
class FieldData:
111111
"""
112-
Metadata for a single `target` field passed to `SolveExpr`.
112+
Metadata for a single `target` field passed to `SolverMetaData`.
113113
Used to interface with PETSc SNES solvers at the IET level.
114114
115115
Parameters
@@ -134,8 +134,8 @@ def __init__(self, target=None, jacobian=None, residual=None,
134134
petsc_precision = dtype_mapper[petsc_variables['PETSC_PRECISION']]
135135
if self._target.dtype != petsc_precision:
136136
raise TypeError(
137-
f"Your target dtype must match the precision of your "
138-
f"PETSc configuration. "
137+
"Your target dtype must match the precision of your "
138+
"PETSc configuration. "
139139
f"Expected {petsc_precision}, but got {self._target.dtype}."
140140
)
141141
self._jacobian = jacobian
@@ -182,7 +182,7 @@ def targets(self):
182182

183183
class MultipleFieldData(FieldData):
184184
"""
185-
Metadata class passed to `SolveExpr`, for mixed-field problems,
185+
Metadata class passed to `SolverMetaData`, for mixed-field problems,
186186
where the solution vector spans multiple `targets`. Used to interface
187187
with PETSc SNES solvers at the IET level.
188188

devito/petsc/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def generate_time_mapper(funcs):
123123
level to align with the `TimeDimension` and `ModuloDimension` objects
124124
present in the initial lowering.
125125
NOTE: All functions used in PETSc callback functions are attached to
126-
the `SolveExpr` object, which is passed through the initial lowering
126+
the `SolverMetaData` object, which is passed through the initial lowering
127127
(and subsequently dropped and replaced with calls to run the solver).
128128
Therefore, the appropriate time loop will always be correctly generated inside
129129
the main kernel.

0 commit comments

Comments
 (0)