Skip to content

Commit 2bc855d

Browse files
committed
compiler: Add utility function inside petsc routines
1 parent 586b77e commit 2bc855d

10 files changed

Lines changed: 128 additions & 488 deletions

File tree

devito/petsc/iet/logging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ def calls(self):
7777

7878
input = self.sobjs[return_variable.input_params]
7979
output_params = self.petsc_option_mapper[return_variable.name].values()
80-
outputs = [Byref(i) for i in output_params]
80+
by_ref_output = [Byref(i) for i in output_params]
8181

8282
calls.append(
83-
petsc_call(return_variable.name, [input] + outputs)
83+
petsc_call(return_variable.name, [input] + by_ref_output)
8484
)
8585
# TODO: Perform a PetscCIntCast here?
8686
exprs = [

devito/petsc/iet/passes.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,19 +69,22 @@ def lower_petsc(iet, **kwargs):
6969
# Map PETScSolve to its Section (for logging)
7070
section_mapper = MapNodes(Section, PetscMetaData, 'groupby').visit(iet)
7171

72-
# prefixes within a single Operator should not be duplicated
72+
# Prefixes within a single Operator should not be duplicated
7373
prefixes = [d.expr.rhs.user_prefix for d in data if d.expr.rhs.user_prefix]
7474
duplicates = {p for p in prefixes if prefixes.count(p) > 1}
7575

76-
# TODO: How to avoid the other exception raised given it is an iet_pass?
76+
# TODO: Avoid the other exception raised given it is an iet_pass?
7777
if duplicates:
78-
dup_list = ", ".join(sorted(duplicates))
78+
dup_list = ", ".join(repr(p) for p in sorted(duplicates))
7979
raise ValueError(
8080
f"The following `options_prefix` values are duplicated "
8181
f"among your PETScSolves. Ensure each one is unique: {dup_list}"
8282
)
8383

84+
# List of calls to clear options from the global PETSc options database.
85+
# These are executed at the end of the Operator.
8486
clear_options = []
87+
8588
for iters, (inject_solve,) in inject_solve_mapper.items():
8689

8790
builder = Builder(inject_solve, iters, comm, section_mapper, **kwargs)
@@ -98,7 +101,6 @@ def lower_petsc(iet, **kwargs):
98101
),))
99102

100103
populate_matrix_context(efuncs)
101-
102104
iet = Transformer(subs).visit(iet)
103105
body = core + tuple(setup) + iet.body.body + tuple(clear_options)
104106
body = iet.body._rebuild(body=body)

devito/petsc/iet/routines.py

Lines changed: 62 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,40 @@ def _make_core(self):
117117
self._make_initial_guess()
118118
self._make_user_struct_callback()
119119

120+
def _make_petsc_callable(self, prefix, body, parameters=()):
121+
return PETScCallable(
122+
self.sregistry.make_name(prefix=prefix),
123+
body,
124+
retval=self.objs['err'],
125+
parameters=parameters
126+
)
127+
128+
def _make_callable_body(self, body, stacks=(), casts=()):
129+
return CallableBody(
130+
List(body=body),
131+
init=(petsc_func_begin_user,),
132+
stacks=stacks,
133+
casts=casts,
134+
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
135+
)
136+
120137
def _make_options_callback(self):
121-
objs = self.objs
138+
"""
139+
Create two callbacks: one to set PETSc options and one for
140+
to clear them.
141+
142+
Options are only set/cleared if they were not specifed via
143+
command line arguments.
144+
"""
122145
params = self.solver_parameters
123146
prefix = self.inject_solve.expr.rhs.formatted_prefix
124147

125-
set_body = []
126-
clear_body = []
148+
set_body, clear_body = [], []
127149

128150
for k, v in params.items():
129151
option = f'-{prefix}{k}'
130152
if option in sys.argv:
131-
# Ensures that the command line options take priority
153+
# Ensures that the command line args take priority
132154
continue
133155
option_name = String(option)
134156
option_value = Null if v is None else String(str(v))
@@ -139,31 +161,12 @@ def _make_options_callback(self):
139161
petsc_call('PetscOptionsClearValue', [Null, option_name])
140162
)
141163

142-
set_body = CallableBody(
143-
List(body=set_body),
144-
init=(petsc_func_begin_user,),
145-
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
146-
)
164+
set_body = self._make_callable_body(set_body)
165+
clear_body = self._make_callable_body(clear_body)
147166

148-
clear_body = CallableBody(
149-
List(body=clear_body),
150-
init=(petsc_func_begin_user,),
151-
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
152-
)
167+
set_callback = self._make_petsc_callable('SetPetscOptions', set_body)
168+
clear_callback = self._make_petsc_callable('ClearPetscOptions', clear_body)
153169

154-
set_callback = PETScCallable(
155-
self.sregistry.make_name(prefix='SetPetscOptions'),
156-
set_body,
157-
retval=objs['err'],
158-
parameters=()
159-
)
160-
161-
clear_callback = PETScCallable(
162-
self.sregistry.make_name(prefix='ClearPetscOptions'),
163-
clear_body,
164-
retval=objs['err'],
165-
parameters=()
166-
)
167170
self._set_options_efunc = set_callback
168171
self._efuncs[set_callback.name] = set_callback
169172
self._clear_options_efunc = clear_callback
@@ -179,13 +182,9 @@ def _make_matvec(self, jacobian, prefix='MatMult'):
179182
body = self._create_matvec_body(
180183
List(body=irs.uiet.body), jacobian
181184
)
182-
183185
objs = self.objs
184-
cb = PETScCallable(
185-
self.sregistry.make_name(prefix=prefix),
186-
body,
187-
retval=objs['err'],
188-
parameters=(objs['J'], objs['X'], objs['Y'])
186+
cb = self._make_petsc_callable(
187+
prefix, body, parameters=(objs['J'], objs['X'], objs['Y'])
189188
)
190189
self._J_efuncs.append(cb)
191190
self._efuncs[cb.name] = cb
@@ -303,12 +302,7 @@ def _create_matvec_body(self, body, jacobian):
303302
# Dereference function data in struct
304303
derefs = dereference_funcs(ctx, fields)
305304

306-
body = CallableBody(
307-
List(body=body),
308-
init=(petsc_func_begin_user,),
309-
stacks=stacks+derefs,
310-
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
311-
)
305+
body = self._make_callable_body(body, stacks=stacks+derefs)
312306

313307
# Replace non-function data with pointer to data in struct
314308
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields}
@@ -318,6 +312,7 @@ def _create_matvec_body(self, body, jacobian):
318312
return body
319313

320314
def _make_formfunc(self):
315+
objs = self.objs
321316
F_exprs = self.field_data.residual.F_exprs
322317
# Compile `F_exprs` into an IET via recursive compilation
323318
irs, _ = self.rcompile(
@@ -327,13 +322,9 @@ def _make_formfunc(self):
327322
body_formfunc = self._create_formfunc_body(
328323
List(body=irs.uiet.body)
329324
)
330-
objs = self.objs
331-
cb = PETScCallable(
332-
self.sregistry.make_name(prefix='FormFunction'),
333-
body_formfunc,
334-
retval=objs['err'],
335-
parameters=(objs['snes'], objs['X'], objs['F'], objs['dummyptr'])
336-
)
325+
parameters = (objs['snes'], objs['X'], objs['F'], objs['dummyptr'])
326+
cb = self._make_petsc_callable('FormFunction', body_formfunc, parameters)
327+
337328
self._F_efunc = cb
338329
self._efuncs[cb.name] = cb
339330

@@ -442,12 +433,7 @@ def _create_formfunc_body(self, body):
442433
# Dereference function data in struct
443434
derefs = dereference_funcs(ctx, fields)
444435

445-
body = CallableBody(
446-
List(body=body),
447-
init=(petsc_func_begin_user,),
448-
stacks=stacks+derefs,
449-
retstmt=(Call('PetscFunctionReturn', arguments=[0]),))
450-
436+
body = self._make_callable_body(body, stacks=stacks+derefs)
451437
# Replace non-function data with pointer to data in struct
452438
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields}
453439

@@ -466,11 +452,8 @@ def _make_formrhs(self):
466452
List(body=irs.uiet.body)
467453
)
468454
objs = self.objs
469-
cb = PETScCallable(
470-
self.sregistry.make_name(prefix='FormRHS'),
471-
body,
472-
retval=objs['err'],
473-
parameters=(sobjs['callbackdm'], objs['B'])
455+
cb = self._make_petsc_callable(
456+
'FormRHS', body, parameters=(sobjs['callbackdm'], objs['B'])
474457
)
475458
self._b_efunc = cb
476459
self._efuncs[cb.name] = cb
@@ -552,12 +535,7 @@ def _create_form_rhs_body(self, body):
552535
# Dereference function data in struct
553536
derefs = dereference_funcs(ctx, fields)
554537

555-
body = CallableBody(
556-
List(body=[body]),
557-
init=(petsc_func_begin_user,),
558-
stacks=stacks+derefs,
559-
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
560-
)
538+
body = self._make_callable_body([body], stacks=stacks+derefs)
561539

562540
# Replace non-function data with pointer to data in struct
563541
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for
@@ -568,6 +546,7 @@ def _create_form_rhs_body(self, body):
568546
def _make_initial_guess(self):
569547
exprs = self.field_data.initial_guess.exprs
570548
sobjs = self.solver_objs
549+
objs = self.objs
571550

572551
# Compile initital guess `eqns` into an IET via recursive compilation
573552
irs, _ = self.rcompile(
@@ -577,12 +556,8 @@ def _make_initial_guess(self):
577556
body = self._create_initial_guess_body(
578557
List(body=irs.uiet.body)
579558
)
580-
objs = self.objs
581-
cb = PETScCallable(
582-
self.sregistry.make_name(prefix='FormInitialGuess'),
583-
body,
584-
retval=objs['err'],
585-
parameters=(sobjs['callbackdm'], objs['xloc'])
559+
cb = self._make_petsc_callable(
560+
'FormInitialGuess', body, parameters=(sobjs['callbackdm'], objs['xloc'])
586561
)
587562
self._initial_guesses.append(cb)
588563
self._efuncs[cb.name] = cb
@@ -629,13 +604,7 @@ def _create_initial_guess_body(self, body):
629604

630605
# Dereference function data in struct
631606
derefs = dereference_funcs(ctx, fields)
632-
633-
body = CallableBody(
634-
List(body=[body]),
635-
init=(petsc_func_begin_user,),
636-
stacks=stacks+derefs,
637-
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
638-
)
607+
body = self._make_callable_body(body, stacks=stacks+derefs)
639608

640609
# Replace non-function data with pointer to data in struct
641610
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for
@@ -658,10 +627,7 @@ def _make_user_struct_callback(self):
658627
DummyExpr(FieldFromPointer(i._C_symbol, mainctx), i._C_symbol)
659628
for i in mainctx.callback_fields
660629
]
661-
struct_callback_body = CallableBody(
662-
List(body=body), init=(petsc_func_begin_user,),
663-
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
664-
)
630+
struct_callback_body = self._make_callable_body(body)
665631
cb = Callable(
666632
self.sregistry.make_name(prefix='PopulateUserContext'),
667633
struct_callback_body, self.objs['err'],
@@ -731,11 +697,9 @@ def _make_whole_matvec(self):
731697
objs = self.objs
732698
body = self._whole_matvec_body()
733699

734-
cb = PETScCallable(
735-
self.sregistry.make_name(prefix='WholeMatMult'),
736-
List(body=body),
737-
retval=objs['err'],
738-
parameters=(objs['J'], objs['X'], objs['Y'])
700+
parameters = (objs['J'], objs['X'], objs['Y'])
701+
cb = self._make_petsc_callable(
702+
'WholeMatMult', List(body=body), parameters=parameters
739703
)
740704
self._main_matvec_callback = cb
741705
self._efuncs[cb.name] = cb
@@ -782,13 +746,11 @@ def _whole_matvec_body(self):
782746
[objs['Y'], Deref(FieldFromPointer(rows, ctx)), Byref(Y)]
783747
),
784748
)
785-
return CallableBody(
786-
List(body=(ctx_main, zero_y_memory, BlankLine) + calls),
787-
init=(petsc_func_begin_user,),
788-
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
789-
)
749+
body = (ctx_main, zero_y_memory, BlankLine) + calls
750+
return self._make_callable_body(body)
790751

791752
def _make_whole_formfunc(self):
753+
objs = self.objs
792754
F_exprs = self.field_data.residual.F_exprs
793755
# Compile `F_exprs` into an IET via recursive compilation
794756
irs, _ = self.rcompile(
@@ -797,13 +759,11 @@ def _make_whole_formfunc(self):
797759
)
798760
body = self._whole_formfunc_body(List(body=irs.uiet.body))
799761

800-
objs = self.objs
801-
cb = PETScCallable(
802-
self.sregistry.make_name(prefix='WholeFormFunc'),
803-
body,
804-
retval=objs['err'],
805-
parameters=(objs['snes'], objs['X'], objs['F'], objs['dummyptr'])
762+
parameters = (objs['snes'], objs['X'], objs['F'], objs['dummyptr'])
763+
cb = self._make_petsc_callable(
764+
'WholeFormFunc', body, parameters=parameters
806765
)
766+
807767
self._F_efunc = cb
808768
self._efuncs[cb.name] = cb
809769

@@ -917,14 +877,10 @@ def _whole_formfunc_body(self, body):
917877
f_soa = PointerCast(fbundle)
918878
x_soa = PointerCast(xbundle)
919879

920-
formfunc_body = CallableBody(
921-
List(body=body),
922-
init=(petsc_func_begin_user,),
923-
stacks=stacks+derefs,
880+
formfunc_body = self._make_callable_body(
881+
body, stacks=stacks+derefs,
924882
casts=(f_soa, x_soa),
925-
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
926883
)
927-
928884
# Replace non-function data with pointer to data in struct
929885
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields}
930886

@@ -941,12 +897,9 @@ def _create_submatrices(self):
941897
objs['matreuse'],
942898
objs['Submats'],
943899
)
944-
cb = PETScCallable(
945-
self.sregistry.make_name(prefix='MatCreateSubMatrices'),
946-
body,
947-
retval=objs['err'],
948-
parameters=params
949-
)
900+
cb = self._make_petsc_callable(
901+
'MatCreateSubMatrices', body, parameters=params)
902+
950903
self._submatrices_callback = cb
951904
self._efuncs[cb.name] = cb
952905

@@ -1068,12 +1021,7 @@ def _submat_callback_body(self):
10681021
iteration,
10691022
] + matmult_op
10701023

1071-
return CallableBody(
1072-
List(body=tuple(body)),
1073-
init=(petsc_func_begin_user,),
1074-
stacks=(get_ctx, deref_subdm),
1075-
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
1076-
)
1024+
return self._make_callable_body(tuple(body), stacks=(get_ctx, deref_subdm))
10771025

10781026

10791027
class BaseObjectBuilder:

devito/petsc/logging.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ class PetscInfo(CompositeObject):
125125
def __init__(self, name, pname, petsc_option_mapper, sobjs, section_mapper,
126126
inject_solve, function_list):
127127

128-
# TODO: change name to match new name elsewehere
129128
self.petsc_option_mapper = petsc_option_mapper
130129
self.sobjs = sobjs
131130
self.section_mapper = section_mapper
@@ -162,12 +161,15 @@ def __getattr__(self, attr):
162161
# Maps the petsc_option to its generated variable name e.g {'its': its0}
163162
obj_mapper = self.petsc_option_mapper[attr]
164163

165-
# Helper to get the value from the petsc profiling struct
164+
# Helper to get the value from the profiling struct
166165
get_val = lambda v: getattr(self.value._obj, v.name)
167166

168-
# If there's only one value to retrieve for the given attribute, for example, KSPGetIterationNumber
169-
# we return it directly, otherwise we return a dictionary of all values e.g for KSPGetTolerances
170-
# we return {'rtol': val0, 'abstol': val1, ...}
167+
# Return the value(s) for the given PETSc attribute:
168+
# - If there's only one output (e.g., for KSPGetIterationNumber),
169+
# return it directly.
170+
# - If there are multiple outputs (e.g., for KSPGetTolerances),
171+
# return a dictionary mapping each output name to
172+
# its value, e.g., {'rtol': val0, 'abstol': val1, ...}.
171173
if len(obj_mapper) == 1:
172174
return get_val(next(iter(obj_mapper.values())))
173175
return {k: get_val(v) for k, v in obj_mapper.items()}

0 commit comments

Comments
 (0)