Skip to content

Commit eb8043f

Browse files
committed
compiler: Start extending the PetscSummary
1 parent 1d8de25 commit eb8043f

11 files changed

Lines changed: 474 additions & 178 deletions

File tree

devito/petsc/iet/logging.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,32 +23,54 @@ def __init__(self, level, **kwargs):
2323
if level <= PERF:
2424
self.function_list.extend([
2525
'kspgetiterationnumber',
26-
'snesgetiterationnumber'
26+
'snesgetiterationnumber',
27+
'kspgettolerances'
2728
])
2829

2930
# TODO: To be extended with if level <= DEBUG: ...
3031

32+
# from IPython import embed; embed()
33+
# if str(self.inject_solve.expr.rhs.solver_parameters['ksp_rtol']) == '1e-15':
34+
3135
name = self.sreg.make_name(prefix='petscinfo')
3236
pname = self.sreg.make_name(prefix='petscprofiler')
3337

3438
self.statstruct = PetscInfo(
35-
name, pname, self.logobjs, self.sobjs,
39+
name, pname, self.petsc_option_mapper, self.sobjs,
3640
self.section_mapper, self.inject_solve,
3741
self.function_list
3842
)
43+
# else:
44+
# name = self.sreg.make_name(prefix='petscinfooo')
45+
# pname = self.sreg.make_name(prefix='petscprofilerrrr')
46+
47+
# self.statstruct = PetscInfo(
48+
# name, pname, self.petsc_option_mapper, self.sobjs,
49+
# self.section_mapper, self.inject_solve,
50+
# self.function_list
51+
# )
52+
53+
# from IPython import embed; embed() # noqa: E402
54+
55+
# @property
56+
# def statstruct(self):
57+
# return self._statstruct
3958

4059
@cached_property
41-
def logobjs(self):
60+
def petsc_option_mapper(self):
4261
"""
4362
Create PETSc objects specifically needed for logging solver statistics.
63+
64+
ADD EXTENDED DOCSTRING
4465
"""
45-
return {
46-
info.name: info.variable_type(
47-
self.sreg.make_name(prefix=info.output_param)
48-
)
49-
for func_name in self.function_list
50-
for info in [petsc_return_variable_dict[func_name]]
51-
}
66+
opts = {}
67+
for func_name in self.function_list:
68+
info = petsc_return_variable_dict[func_name]
69+
opts[info.name] = {}
70+
for vtype, out in zip(info.variable_type, info.output_param, strict=True):
71+
opts[info.name][out] = vtype(self.sreg.make_name(prefix=out))
72+
73+
return opts
5274

5375
@cached_property
5476
def calls(self):
@@ -58,20 +80,21 @@ def calls(self):
5880
"""
5981
struct = self.statstruct
6082
calls = []
61-
for param in self.function_list:
62-
param = petsc_return_variable_dict[param]
63-
64-
inputs = []
65-
for i in param.input_params:
66-
inputs.append(self.sobjs[i])
67-
68-
logobj = self.logobjs[param.name]
83+
for func_name in self.function_list:
84+
return_variable = petsc_return_variable_dict[func_name]
6985

86+
input = self.sobjs[return_variable.input_params]
87+
output_params = self.petsc_option_mapper[return_variable.name].values()
88+
outputs = [Byref(i) for i in output_params]
89+
# from IPython import embed; embed()
7090
calls.append(
71-
petsc_call(param.name, inputs + [Byref(logobj)])
91+
petsc_call(return_variable.name, [input] + outputs)
7292
)
7393
# TODO: Perform a PetscCIntCast here?
74-
expr = DummyExpr(FieldFromPointer(logobj._C_symbol, struct), logobj._C_symbol)
75-
calls.append(expr)
94+
exprs = [
95+
DummyExpr(FieldFromPointer(i._C_symbol, struct), i._C_symbol)
96+
for i in output_params
97+
]
98+
calls.extend(exprs)
7699

77100
return tuple(calls)

devito/petsc/iet/passes.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,13 @@ def lower_petsc(iet, **kwargs):
7373

7474
# Generate a shared callback used by all PETScSolve instances to set
7575
# individual PetscOptions
76-
set_solver_option(efuncs)
76+
# set_solver_option(efuncs)
77+
# List of all callbacks that clear PetscOptions
78+
# from IPython import embed; embed()
79+
# TODO: throw a warning/error if the user passes a solver in with the same options_prefix
80+
# it's going to lead to weird solver option behaviour. Note, if you use the options_prefix across
81+
# different Operator runs, it will not be an issue
82+
clear_options = []
7783

7884
for iters, (inject_solve,) in inject_solve_mapper.items():
7985

@@ -86,11 +92,17 @@ def lower_petsc(iet, **kwargs):
8692

8793
efuncs.update(builder.cbbuilder.efuncs)
8894

95+
# clear_options.append(builder.cbbuilder._clear_options_efunc)
96+
# from IPython import embed; embed() # noqa: E402
97+
clear_options.extend((petsc_call(
98+
builder.cbbuilder._clear_options_efunc.name, []
99+
),))
100+
89101
populate_matrix_context(efuncs)
90102

91103
iet = Transformer(subs).visit(iet)
92-
93-
body = core + tuple(setup) + iet.body.body
104+
# from IPython import embed; embed()
105+
body = core + tuple(setup) + iet.body.body + tuple(clear_options)
94106
body = iet.body._rebuild(body=body)
95107
iet = iet._rebuild(body=body)
96108
metadata = {**core_metadata(), 'efuncs': tuple(efuncs.values())}

devito/petsc/iet/routines.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from collections import OrderedDict
23
from functools import cached_property
34
import math
@@ -42,7 +43,8 @@ def __init__(self, **kwargs):
4243
self._efuncs = OrderedDict()
4344
self._struct_params = []
4445

45-
self._options_efunc = None
46+
self._set_options_efunc = None
47+
self._clear_options_efunc = None
4648
self._main_matvec_callback = None
4749
self._user_struct_callback = None
4850
self._F_efunc = None
@@ -120,26 +122,48 @@ def _make_options_callback(self):
120122
params = self.solver_parameters
121123
prefix = self.inject_solve.expr.rhs.formatted_prefix
122124

123-
body = []
125+
set_body = []
126+
clear_body = []
127+
124128
for k, v in params.items():
125-
option_key = String(f"-{prefix}{k}")
129+
option = f'-{prefix}{k}'
130+
if option in sys.argv:
131+
# Ensures that the command line options take priority
132+
continue
133+
option_name = String(option)
126134
option_value = Null if v is None else String(str(v))
127-
body.append(petsc_call('SetPetscOption', [option_key, option_value]))
135+
set_body.append(petsc_call('PetscOptionsSetValue', [Null, option_name, option_value]))
136+
clear_body.append(petsc_call('PetscOptionsClearValue', [Null, option_name]))
128137

129-
body = CallableBody(
130-
List(body=body),
138+
set_body = CallableBody(
139+
List(body=set_body),
131140
init=(petsc_func_begin_user,),
132141
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
133142
)
134143

135-
cb = PETScCallable(
144+
clear_body = CallableBody(
145+
List(body=clear_body),
146+
init=(petsc_func_begin_user,),
147+
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
148+
)
149+
150+
set_callback = PETScCallable(
136151
self.sregistry.make_name(prefix='SetPetscOptions'),
137-
body,
152+
set_body,
138153
retval=objs['err'],
139154
parameters=()
140155
)
141-
self._options_efunc = cb
142-
self._efuncs[cb.name] = cb
156+
157+
clear_callback = PETScCallable(
158+
self.sregistry.make_name(prefix='ClearPetscOptions'),
159+
clear_body,
160+
retval=objs['err'],
161+
parameters=()
162+
)
163+
self._set_options_efunc = set_callback
164+
self._efuncs[set_callback.name] = set_callback
165+
self._clear_options_efunc = clear_callback
166+
self._efuncs[clear_callback.name] = clear_callback
143167

144168
def _make_matvec(self, jacobian, prefix='MatMult'):
145169
# Compile `matvecs` into an IET via recursive compilation
@@ -1264,7 +1288,7 @@ def _setup(self):
12641288
) if self.formatted_prefix else None
12651289

12661290
set_options = petsc_call(
1267-
self.cbbuilder._options_efunc.name, []
1291+
self.cbbuilder._set_options_efunc.name, []
12681292
)
12691293

12701294
snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda])
@@ -1421,7 +1445,7 @@ def _setup(self):
14211445
) if self.formatted_prefix else None
14221446

14231447
set_options = petsc_call(
1424-
self.cbbuilder._options_efunc.name, []
1448+
self.cbbuilder._set_options_efunc.name, []
14251449
)
14261450

14271451
snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda])

devito/petsc/logging.py

Lines changed: 67 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from devito.types import CompositeObject
55

6-
from devito.petsc.types import PetscInt
6+
from devito.petsc.types import PetscInt, PetscScalar
77
from devito.petsc.utils import petsc_type_mappings
88

99

@@ -63,7 +63,9 @@ def petsc_entry(self, petscinfo):
6363
containing the values for each PETSc function call.
6464
"""
6565
funcs = self._functions
66+
# from IPython import embed; embed()
6667
values = tuple(getattr(petscinfo, c) for c in funcs)
68+
# from IPython import embed; embed()
6769
return PetscEntry(**{k: v for k, v in zip(funcs, values)})
6870

6971
def _add_properties(self):
@@ -119,20 +121,28 @@ def __getitem__(self, key):
119121

120122
class PetscInfo(CompositeObject):
121123

122-
__rargs__ = ('name', 'pname', 'logobjs', 'sobjs', 'section_mapper',
124+
__rargs__ = ('name', 'pname', 'petsc_option_mapper', 'sobjs', 'section_mapper',
123125
'inject_solve', 'function_list')
124126

125-
def __init__(self, name, pname, logobjs, sobjs, section_mapper,
127+
def __init__(self, name, pname, petsc_option_mapper, sobjs, section_mapper,
126128
inject_solve, function_list):
127129

128-
self.logobjs = logobjs
130+
# TODO: change name to match new name elsewehere
131+
self.petsc_option_mapper = petsc_option_mapper
129132
self.sobjs = sobjs
130133
self.section_mapper = section_mapper
131134
self.inject_solve = inject_solve
132135
self.function_list = function_list
133136

134137
mapper = {v: k for k, v in petsc_type_mappings.items()}
135-
fields = [(str(i), mapper[str(i._C_ctype)]) for i in logobjs.values()]
138+
139+
self.formatted_prefix = inject_solve.expr.rhs.formatted_prefix
140+
141+
fields = [
142+
(str(ptype), mapper[str(ptype._C_ctype)])
143+
for option in petsc_option_mapper.values() for ptype in option.values()
144+
]
145+
136146
super().__init__(name, pname, fields)
137147

138148
@property
@@ -143,20 +153,53 @@ def section(self):
143153
@property
144154
def summary_key(self):
145155
user_prefix = self.inject_solve.expr.rhs.user_prefix
156+
# TODO: this will be the case when using the default options prefix provided by Devito
157+
# if user_prefix is None:
158+
# user_prefix = self.formatted_prefix
146159
return (self.section, user_prefix)
147160

148161
def __getattr__(self, attr):
149-
if attr in self.logobjs.keys():
150-
return getattr(self.value._obj, self.logobjs[attr].name)
162+
if attr in self.petsc_option_mapper.keys():
163+
if len(self.petsc_option_mapper[attr].values()) > 1:
164+
tmp = {}
165+
for i, j in self.petsc_option_mapper[attr].items():
166+
tmp2 = getattr(self.value._obj, j.name)
167+
tmp[i] = tmp2
168+
return tmp
169+
else:
170+
# TODO: CLEANNN
171+
first_val = list(self.petsc_option_mapper[attr].values())[0]
172+
return getattr(self.value._obj, first_val.name)
173+
# from IPython import embed; embed() # noqa: E402
151174
raise AttributeError(f"{attr} not found in PETSc return variables")
152-
153-
175+
176+
# TODO: maybe just overrider _hashable_content??
177+
178+
# def _hashable_content(self):
179+
# """
180+
# Return a tuple of the formatted prefix and section for hashing.
181+
# This is used to ensure that two PetscInfo objects with the same
182+
# formatted prefix and section are considered equal.
183+
# """
184+
# return (self.name, self.dtype, self.inject_solve.expr.rhs)
185+
186+
# def __eq__(self, other):
187+
# if not isinstance(other, PetscInfo):
188+
# return NotImplemented
189+
# # return self.formatted_prefix == other.formatted_prefix
190+
# return self.inject_solve.expr.rhs == other.inject_solve.expr.rhs
191+
192+
# def __hash__(self):
193+
# return hash(self.inject_solve.expr.rhs)
194+
195+
196+
# TODO: change the lists to tuples
154197
@dataclass
155198
class PetscReturnVariable:
156199
name: str
157-
variable_type: None
158-
input_params: list
159-
output_param: str
200+
variable_type: list
201+
input_params: str
202+
output_param: list[str]
160203

161204

162205
# NOTE:
@@ -168,14 +211,20 @@ class PetscReturnVariable:
168211
petsc_return_variable_dict = {
169212
'kspgetiterationnumber': PetscReturnVariable(
170213
name='KSPGetIterationNumber',
171-
variable_type=PetscInt,
172-
input_params=['ksp'],
173-
output_param='kspiter'
214+
variable_type=[PetscInt],
215+
input_params='ksp',
216+
output_param=['kspiter']
174217
),
175218
'snesgetiterationnumber': PetscReturnVariable(
176219
name='SNESGetIterationNumber',
177-
variable_type=PetscInt,
178-
input_params=['snes'],
179-
output_param='snesiter',
220+
variable_type=[PetscInt],
221+
input_params='snes',
222+
output_param=['snesiter'],
223+
),
224+
'kspgettolerances': PetscReturnVariable(
225+
name='KSPGetTolerances',
226+
variable_type=[PetscScalar, PetscScalar, PetscScalar, PetscInt],
227+
input_params='ksp',
228+
output_param=['rtol', 'abstol', 'dtol', 'maxits'],
180229
)
181230
}

0 commit comments

Comments
 (0)