Skip to content

Commit a4863c8

Browse files
committed
misc: Drop petsc/iet/utils.py
1 parent 995dce0 commit a4863c8

10 files changed

Lines changed: 75 additions & 79 deletions

File tree

devito/petsc/iet/callback_builder.py

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,17 @@
33
from devito.ir.iet import (Call, FindSymbols, List, Uxreplace, CallableBody,
44
Dereference, DummyExpr, BlankLine, Callable, Iteration,
55
PointerCast, Definition)
6-
from devito.symbolics import (Byref, FieldFromPointer, IntDiv, Deref, Mod, String, Null)
6+
from devito.symbolics import (Byref, FieldFromPointer, IntDiv, Deref, Mod, String, Null, VOID)
77
from devito.symbolics.unevaluation import Mul
88
from devito.types.basic import AbstractFunction
9-
from devito.types import Dimension
9+
from devito.types import Dimension, Temp, TempArray
1010
from devito.tools import filter_ordered
1111

12-
from devito.petsc.iet.nodes import PETScCallable, MatShellSetOp
13-
from devito.petsc.iet.utils import (petsc_call, void, get_user_struct_fields)
12+
from devito.petsc.iet.nodes import PETScCallable, MatShellSetOp, petsc_call
1413
from devito.petsc.types import DMCast, MainUserStruct, CallbackUserStruct
1514
from devito.petsc.iet.object_builder import objs
1615
from devito.petsc.types.macros import petsc_func_begin_user
17-
from devito.petsc.types.strings import InsertMode
16+
from devito.petsc.types.modes import InsertMode
1817

1918

2019
class BaseCallback:
@@ -226,12 +225,11 @@ def _create_matvec_body(self, body, jacobian):
226225
)
227226

228227
global_to_local_begin = petsc_call(
229-
'DMGlobalToLocalBegin', [dmda, objs['X'],
230-
InsertMode.insert_values, xlocal]
228+
'DMGlobalToLocalBegin', [dmda, objs['X'], insert_values, xlocal]
231229
)
232230

233231
global_to_local_end = petsc_call('DMGlobalToLocalEnd', [
234-
dmda, objs['X'], InsertMode.insert_values, xlocal
232+
dmda, objs['X'], insert_values, xlocal
235233
])
236234

237235
dm_get_local_yvec = petsc_call(
@@ -261,11 +259,11 @@ def _create_matvec_body(self, body, jacobian):
261259
)
262260

263261
dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [
264-
dmda, ylocal, InsertMode.add_values, objs['Y']
262+
dmda, ylocal, add_values, objs['Y']
265263
])
266264

267265
dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [
268-
dmda, ylocal, InsertMode.add_values, objs['Y']
266+
dmda, ylocal, add_values, objs['Y']
269267
])
270268

271269
dm_restore_local_xvec = petsc_call(
@@ -373,13 +371,12 @@ def _create_formfunc_body(self, body):
373371
)
374372

375373
global_to_local_begin = petsc_call(
376-
'DMGlobalToLocalBegin', [dmda, objs['X'],
377-
InsertMode.insert_values, objs['xloc']]
374+
'DMGlobalToLocalBegin', [dmda, objs['X'], insert_values, objs['xloc']]
378375
)
379376

380-
global_to_local_end = petsc_call('DMGlobalToLocalEnd', [
381-
dmda, objs['X'], InsertMode.insert_values, objs['xloc']
382-
])
377+
global_to_local_end = petsc_call(
378+
'DMGlobalToLocalEnd', [dmda, objs['X'], insert_values, objs['xloc']]
379+
)
383380

384381
dm_get_local_yvec = petsc_call(
385382
'DMGetLocalVector', [dmda, Byref(objs['floc'])]
@@ -406,11 +403,11 @@ def _create_formfunc_body(self, body):
406403
)
407404

408405
dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [
409-
dmda, objs['floc'], InsertMode.add_values, objs['F']
406+
dmda, objs['floc'], add_values, objs['F']
410407
])
411408

412409
dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [
413-
dmda, objs['floc'], InsertMode.add_values, objs['F']
410+
dmda, objs['floc'], add_values, objs['F']
414411
])
415412

416413
dm_restore_local_xvec = petsc_call(
@@ -490,14 +487,12 @@ def _create_form_rhs_body(self, body):
490487
)
491488

492489
dm_global_to_local_begin = petsc_call(
493-
'DMGlobalToLocalBegin', [dmda, objs['B'],
494-
InsertMode.insert_values, sobjs['blocal']]
490+
'DMGlobalToLocalBegin', [dmda, objs['B'], insert_values, sobjs['blocal']]
495491
)
496492

497-
dm_global_to_local_end = petsc_call('DMGlobalToLocalEnd', [
498-
dmda, objs['B'], InsertMode.insert_values,
499-
sobjs['blocal']
500-
])
493+
dm_global_to_local_end = petsc_call(
494+
'DMGlobalToLocalEnd', [dmda, objs['B'], insert_values, sobjs['blocal']]
495+
)
501496

502497
b_arr = self.field_data.arrays[target]['b']
503498

@@ -519,13 +514,11 @@ def _create_form_rhs_body(self, body):
519514
)
520515

521516
dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [
522-
dmda, sobjs['blocal'], InsertMode.insert_values,
523-
objs['B']
517+
dmda, sobjs['blocal'], insert_values, objs['B']
524518
])
525519

526520
dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [
527-
dmda, sobjs['blocal'], InsertMode.insert_values,
528-
objs['B']
521+
dmda, sobjs['blocal'], insert_values, objs['B']
529522
])
530523

531524
vec_restore_array = petsc_call(
@@ -822,13 +815,12 @@ def _whole_formfunc_body(self, body):
822815
'DMGetLocalVector', [dmda, Byref(objs['xloc'])]
823816
)
824817

825-
global_to_local_begin = petsc_call(
826-
'DMGlobalToLocalBegin', [dmda, objs['X'],
827-
InsertMode.insert_values, objs['xloc']]
828-
)
818+
global_to_local_begin = petsc_call('DMGlobalToLocalBegin', [
819+
dmda, objs['X'], insert_values, objs['xloc']
820+
])
829821

830822
global_to_local_end = petsc_call('DMGlobalToLocalEnd', [
831-
dmda, objs['X'], InsertMode.insert_values, objs['xloc']
823+
dmda, objs['X'], insert_values, objs['xloc']
832824
])
833825

834826
dm_get_local_yvec = petsc_call(
@@ -856,11 +848,11 @@ def _whole_formfunc_body(self, body):
856848
)
857849

858850
dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [
859-
dmda, objs['floc'], InsertMode.add_values, objs['F']
851+
dmda, objs['floc'], add_values, objs['F']
860852
])
861853

862854
dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [
863-
dmda, objs['floc'], InsertMode.add_values, objs['F']
855+
dmda, objs['floc'], add_values, objs['F']
864856
])
865857

866858
dm_restore_local_xvec = petsc_call(
@@ -1033,7 +1025,7 @@ def _submat_callback_body(self):
10331025
[
10341026
objs['submat_arr'].indexed[sb.linear_idx],
10351027
'MATOP_MULT',
1036-
MatShellSetOp(matvec_lookup[sb.name].name, void, void),
1028+
MatShellSetOp(matvec_lookup[sb.name].name, VOID._dtype, VOID._dtype),
10371029
],
10381030
)
10391031
for sb in nonzero_submats if sb.name in matvec_lookup
@@ -1120,3 +1112,18 @@ def zero_vector(vec):
11201112
Set all entries of a PETSc vector to zero.
11211113
"""
11221114
return petsc_call('VecSet', [vec, 0.0])
1115+
1116+
1117+
def get_user_struct_fields(iet):
1118+
fields = [f.function for f in FindSymbols('basics').visit(iet)]
1119+
from devito.types.basic import LocalType
1120+
avoid = (Temp, TempArray, LocalType)
1121+
fields = [f for f in fields if not isinstance(f.function, avoid)]
1122+
fields = [
1123+
f for f in fields if not (f.is_Dimension and not (f.is_Time or f.is_Modulo))
1124+
]
1125+
return fields
1126+
1127+
1128+
insert_values = InsertMode.insert_values
1129+
add_values = InsertMode.add_values

devito/petsc/iet/logging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from devito.logger import PERF
66
from devito.tools import frozendict
77

8-
from devito.petsc.iet.utils import petsc_call
8+
from devito.petsc.iet.nodes import petsc_call
99
from devito.petsc.logging import petsc_return_variable_dict, PetscInfo
1010

1111

devito/petsc/iet/nodes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ class PETScCall(Call):
3131
pass
3232

3333

34+
def petsc_call(specific_call, call_args):
35+
return PETScCall('PetscCall', [PETScCall(specific_call, arguments=call_args)])
36+
37+
3438
# Mapping special Eq operations to their corresponding IET Expression subclass types.
3539
# These operations correspond to subclasses of Eq utilised within PETScSolve.
3640
petsc_iet_mapper = {OpPetsc: PetscMetaData}

devito/petsc/iet/passes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,16 @@
1616
CallbackUserStruct
1717
)
1818
from devito.petsc.types.macros import petsc_func_begin_user
19-
from devito.petsc.iet.nodes import PetscMetaData
19+
from devito.petsc.iet.nodes import PetscMetaData, petsc_call
2020
from devito.petsc.utils import core_metadata, petsc_languages
2121
from devito.petsc.iet.callback_builder import (
22-
BaseCallback, CoupledCallback, populate_matrix_context
22+
BaseCallback, CoupledCallback, populate_matrix_context, get_user_struct_fields
2323
)
2424
from devito.petsc.iet.object_builder import BaseObjectBuilder, CoupledObjectBuilder, objs
2525
from devito.petsc.iet.pre_solver import BaseSetup, CoupledSetup, make_core_petsc_calls
2626
from devito.petsc.iet.run_solver import Solver, CoupledSolver
2727
from devito.petsc.iet.time_dependence import TimeDependent, TimeIndependent
2828
from devito.petsc.iet.logging import PetscLogger
29-
from devito.petsc.iet.utils import petsc_call, get_user_struct_fields
3029

3130

3231
@iet_pass

devito/petsc/iet/pre_solver.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from devito.symbolics import (Byref, FieldFromPointer, VOID,
55
FieldFromComposite, Null)
66

7-
from devito.petsc.iet.nodes import FormFunctionCallback, MatShellSetOp, PETScCall
8-
from devito.petsc.iet.utils import petsc_call
7+
from devito.petsc.iet.nodes import (
8+
FormFunctionCallback, MatShellSetOp, PETScCall, petsc_call
9+
)
910

1011

1112
def make_core_petsc_calls(objs, comm):
@@ -82,13 +83,11 @@ def _setup(self):
8283
[sobjs['snes'], Byref(sobjs['ksp'])])
8384

8485
matvec = self.cbbuilder.main_matvec_callback
85-
matvec_operation = petsc_call(
86-
'MatShellSetOperation',
86+
matvec_operation = petsc_call('MatShellSetOperation',
8787
[sobjs['Jac'], 'MATOP_MULT', MatShellSetOp(matvec.name, void, void)]
8888
)
8989
formfunc = self.cbbuilder._F_efunc
90-
formfunc_operation = petsc_call(
91-
'SNESSetFunction',
90+
formfunc_operation = petsc_call('SNESSetFunction',
9291
[sobjs['snes'], Null, FormFunctionCallback(formfunc.name, void, void),
9392
self.snes_ctx]
9493
)
@@ -337,4 +336,4 @@ def petsc_call_mpi(specific_call, call_args):
337336
return PETScCall('PetscCallMPI', [PETScCall(specific_call, arguments=call_args)])
338337

339338

340-
void = 'void'
339+
void = VOID._dtype

devito/petsc/iet/run_solver.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
)
66
from devito.symbolics import (Byref, Null)
77

8-
from devito.petsc.iet.nodes import PetscMetaData
9-
from devito.petsc.types.strings import InsertMode, ScatterMode
10-
from devito.petsc.iet.utils import petsc_call
8+
from devito.petsc.iet.nodes import PetscMetaData, petsc_call
9+
from devito.petsc.types.modes import InsertMode, ScatterMode
1110

1211

1312
class Solver:
@@ -45,7 +44,7 @@ def _execute_solve(self):
4544
initguess_call = None
4645

4746
dm_local_to_global_x = petsc_call(
48-
'DMLocalToGlobal', [dmda, sobjs['xlocal'], InsertMode.insert_values,
47+
'DMLocalToGlobal', [dmda, sobjs['xlocal'], insert_values,
4948
sobjs['xglobal']]
5049
)
5150

@@ -54,7 +53,7 @@ def _execute_solve(self):
5453
)
5554

5655
dm_global_to_local_x = petsc_call('DMGlobalToLocal', [
57-
dmda, sobjs['xglobal'], InsertMode.insert_values, sobjs['xlocal']]
56+
dmda, sobjs['xglobal'], insert_values, sobjs['xlocal']]
5857
)
5958

6059
vec_reset_array = self.time_dependence.reset_array(target)
@@ -114,38 +113,44 @@ def _execute_solve(self):
114113
self.time_dependence.place_array(t),
115114
petsc_call(
116115
'DMLocalToGlobal',
117-
[dm, target_xloc, InsertMode.insert_values, target_xglob]
116+
[dm, target_xloc, insert_values, target_xglob]
118117
),
119118
petsc_call(
120119
'VecScatterCreate',
121120
[xglob, field, target_xglob, Null, Byref(s)]
122121
),
123122
petsc_call(
124123
'VecScatterBegin',
125-
[s, target_xglob, xglob, InsertMode.insert_values, ScatterMode.scatter_reverse]
124+
[s, target_xglob, xglob, insert_values, scatter_reverse]
126125
),
127126
petsc_call(
128127
'VecScatterEnd',
129-
[s, target_xglob, xglob, InsertMode.insert_values, ScatterMode.scatter_reverse]
128+
[s, target_xglob, xglob, insert_values, scatter_reverse]
130129
),
131130
BlankLine,
132131
)
133132

134133
post_solve += (
135134
petsc_call(
136135
'VecScatterBegin',
137-
[s, xglob, target_xglob, InsertMode.insert_values, ScatterMode.scatter_forward]
136+
[s, xglob, target_xglob, insert_values, scatter_forward]
138137
),
139138
petsc_call(
140139
'VecScatterEnd',
141-
[s, xglob, target_xglob, InsertMode.insert_values, ScatterMode.scatter_forward]
140+
[s, xglob, target_xglob, insert_values, scatter_forward]
142141
),
143142
petsc_call(
144143
'DMGlobalToLocal',
145-
[dm, target_xglob, InsertMode.insert_values, target_xloc]
144+
[dm, target_xglob, insert_values, target_xloc]
146145
)
147146
)
148147

149148
snes_solve = (petsc_call('SNESSolve', [sobjs['snes'], Null, xglob]),)
150149

151150
return (struct_assignment,) + pre_solve + snes_solve + post_solve + (BlankLine,)
151+
152+
153+
insert_values = InsertMode.insert_values
154+
add_values = InsertMode.add_values
155+
scatter_reverse = ScatterMode.scatter_reverse
156+
scatter_forward = ScatterMode.scatter_forward

devito/petsc/iet/time_dependence.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from devito.ir.iet import Uxreplace, DummyExpr
44
from devito.symbolics import FieldFromPointer, cast, FieldFromComposite
55
from devito.symbolics.unevaluation import Mul
6-
from devito.petsc.iet.utils import petsc_call
6+
7+
from devito.petsc.iet.nodes import petsc_call
78

89

910
class TimeIndependent:

devito/petsc/iet/utils.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

devito/petsc/types/modes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ class ScatterMode:
1414
Reference - https://petsc.org/release/manualpages/Vec/ScatterMode/
1515
"""
1616
scatter_reverse = 'SCATTER_REVERSE'
17-
scatter_forward = 'SCATTER_FORWARD'
17+
scatter_forward = 'SCATTER_FORWARD'

devito/petsc/types/object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from devito.symbolics import Byref, cast
77
from devito.types.basic import DataSymbol, LocalType
88

9-
from devito.petsc.iet.utils import petsc_call
9+
from devito.petsc.iet.nodes import petsc_call
1010

1111

1212
class PetscMixin:

0 commit comments

Comments
 (0)