Skip to content

Commit 53b58a9

Browse files
committed
misc: Clean up
1 parent 8d10b87 commit 53b58a9

2 files changed

Lines changed: 59 additions & 33 deletions

File tree

devito/petsc/iet/routines.py

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from devito.petsc.types import PETScArray, PetscBundle
1717
from devito.petsc.iet.nodes import (PETScCallable, FormFunctionCallback,
1818
MatShellSetOp, PetscMetaData)
19-
from devito.petsc.iet.utils import petsc_call, petsc_struct, zero_vector
19+
from devito.petsc.iet.utils import (petsc_call, petsc_struct, zero_vector,
20+
dereference_funcs, residual_bundle)
2021
from devito.petsc.utils import solver_mapper
2122
from devito.petsc.types import (DM, Mat, CallbackVec, Vec, KSP, PC, SNES,
2223
PetscInt, StartPtr, PointerIS, PointerDM, VecScatter,
@@ -120,7 +121,7 @@ def _make_core(self):
120121
self._make_user_struct_callback()
121122

122123
def _make_matvec(self, jacobian, prefix='MatMult'):
123-
# Compile matvec `eqns` into an IET via recursive compilation
124+
# Compile `matvecs` into an IET via recursive compilation
124125
matvecs = jacobian.matvecs
125126
irs, _ = self.rcompile(
126127
matvecs, options={'mpi': False}, sregistry=self.sregistry,
@@ -251,7 +252,7 @@ def _create_matvec_body(self, body, jacobian):
251252
)
252253

253254
# Dereference function data in struct
254-
derefs = self.dereference_funcs(ctx, fields)
255+
derefs = dereference_funcs(ctx, fields)
255256

256257
body = CallableBody(
257258
List(body=body),
@@ -390,7 +391,7 @@ def _create_formfunc_body(self, body):
390391
)
391392

392393
# Dereference function data in struct
393-
derefs = self.dereference_funcs(ctx, fields)
394+
derefs = dereference_funcs(ctx, fields)
394395

395396
body = CallableBody(
396397
List(body=body),
@@ -500,7 +501,7 @@ def _create_form_rhs_body(self, body):
500501
)
501502

502503
# Dereference function data in struct
503-
derefs = self.dereference_funcs(ctx, fields)
504+
derefs = dereference_funcs(ctx, fields)
504505

505506
body = CallableBody(
506507
List(body=[body]),
@@ -578,7 +579,7 @@ def _create_initial_guess_body(self, body):
578579
)
579580

580581
# Dereference function data in struct
581-
derefs = self.dereference_funcs(ctx, fields)
582+
derefs = dereference_funcs(ctx, fields)
582583

583584
body = CallableBody(
584585
List(body=[body]),
@@ -643,12 +644,6 @@ def _uxreplace_efuncs(self):
643644
mapper.update({k: visitor.visit(v)})
644645
return mapper
645646

646-
def dereference_funcs(self, struct, fields):
647-
return tuple(
648-
[Dereference(i, struct) for i in
649-
fields if isinstance(i.function, AbstractFunction)]
650-
)
651-
652647

653648
class CCBBuilder(CBBuilder):
654649
def __init__(self, **kwargs):
@@ -749,17 +744,17 @@ def _whole_matvec_body(self):
749744

750745
def _make_whole_formfunc(self):
751746
F_exprs = self.fielddata.residual.F_exprs
752-
# Compile formfunc `eqns` into an IET via recursive compilation
753-
irs_formfunc, _ = self.rcompile(
747+
# Compile `F_exprs` into an IET via recursive compilation
748+
irs, _ = self.rcompile(
754749
F_exprs, options={'mpi': False}, sregistry=self.sregistry,
755750
concretize_mapper=self.concretize_mapper
756751
)
757-
body_formfunc = self._whole_formfunc_body(List(body=irs_formfunc.uiet.body))
752+
body = self._whole_formfunc_body(List(body=irs.uiet.body))
758753

759754
objs = self.objs
760755
cb = PETScCallable(
761756
self.sregistry.make_name(prefix='WholeFormFunc'),
762-
body_formfunc,
757+
body,
763758
retval=objs['err'],
764759
parameters=(objs['snes'], objs['X'], objs['F'], objs['dummyptr'])
765760
)
@@ -783,7 +778,8 @@ def _whole_formfunc_body(self, body):
783778
bundles = sobjs['bundles']
784779
fbundle = bundles['f']
785780
xbundle = bundles['x']
786-
body = self.residual_bundle(body, bundles)
781+
782+
body = residual_bundle(body, bundles)
787783

788784
dm_cast = DummyExpr(dmda, DMCast(objs['dummyptr']), init=True)
789785

@@ -870,7 +866,7 @@ def _whole_formfunc_body(self, body):
870866
)
871867

872868
# Dereference function data in struct
873-
derefs = self.dereference_funcs(ctx, fields)
869+
derefs = dereference_funcs(ctx, fields)
874870

875871
f_soa = PointerCast(fbundle)
876872
x_soa = PointerCast(xbundle)
@@ -1034,21 +1030,6 @@ def _submat_callback_body(self):
10341030
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
10351031
)
10361032

1037-
def residual_bundle(self, body, bundles):
1038-
mapper = bundles['bundle_mapper']
1039-
indexeds = FindSymbols('indexeds').visit(body)
1040-
subs = {}
1041-
1042-
for i in indexeds:
1043-
if i.base in mapper:
1044-
bundle = mapper[i.base]
1045-
index = bundles['target_indices'][i.function.target]
1046-
index = (index,) + i.indices
1047-
subs[i] = bundle.__getitem__(index)
1048-
1049-
body = Uxreplace(subs).visit(body)
1050-
return body
1051-
10521033

10531034
class BaseObjectBuilder:
10541035
"""

devito/petsc/iet/utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from devito.petsc.iet.nodes import PetscMetaData, PETScCall
22
from devito.ir.equations import OpPetsc
3+
from devito.ir.iet import Dereference, FindSymbols, Uxreplace
4+
from devito.types.basic import AbstractFunction
35

46

57
def petsc_call(specific_call, call_args):
@@ -19,9 +21,52 @@ def petsc_struct(name, fields, pname, liveness='lazy', modifier=None):
1921

2022

2123
def zero_vector(vec):
24+
"""
25+
Set all entries of a PETSc vector to zero.
26+
"""
2227
return petsc_call('VecSet', [vec, 0.0])
2328

2429

30+
def dereference_funcs(struct, fields):
31+
"""
32+
Dereference AbstractFunctions from a struct.
33+
"""
34+
return tuple(
35+
[Dereference(i, struct) for i in
36+
fields if isinstance(i.function, AbstractFunction)]
37+
)
38+
39+
40+
def residual_bundle(body, bundles):
41+
"""
42+
Replaces PetscArrays in `body` with PetscBundle struct field accesses
43+
(e.g., f_v[ix][iy] -> f_bundle[ix][iy].v).
44+
45+
Example:
46+
f_v[ix][iy] = x_v[ix][iy];
47+
f_u[ix][iy] = x_u[ix][iy];
48+
becomes:
49+
f_bundle[ix][iy].v = x_bundle[ix][iy].v;
50+
f_bundle[ix][iy].u = x_bundle[ix][iy].u;
51+
52+
NOTE: This is used because the data is interleaved for
53+
multi-component DMDAs in PETSc.
54+
"""
55+
mapper = bundles['bundle_mapper']
56+
indexeds = FindSymbols('indexeds').visit(body)
57+
subs = {}
58+
59+
for i in indexeds:
60+
if i.base in mapper:
61+
bundle = mapper[i.base]
62+
index = bundles['target_indices'][i.function.target]
63+
index = (index,) + i.indices
64+
subs[i] = bundle.__getitem__(index)
65+
66+
body = Uxreplace(subs).visit(body)
67+
return body
68+
69+
2570
# Mapping special Eq operations to their corresponding IET Expression subclass types.
2671
# These operations correspond to subclasses of Eq utilised within PETScSolve.
2772
petsc_iet_mapper = {OpPetsc: PetscMetaData}

0 commit comments

Comments
 (0)