Skip to content

Commit d6bd4f4

Browse files
committed
misc: Address comments
1 parent 01d153a commit d6bd4f4

8 files changed

Lines changed: 100 additions & 56 deletions

File tree

devito/passes/iet/engine.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,6 @@ def _(i, mapper, sregistry):
530530
})
531531

532532

533-
@abstract_object.register(Array)
534533
@abstract_object.register(ArrayBasic)
535534
def _(i, mapper, sregistry):
536535
if isinstance(i, Lock):

devito/petsc/iet/routines.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ def fielddata(self):
114114
def arrays(self):
115115
return self.fielddata.arrays
116116

117+
@property
118+
def target(self):
119+
return self.fielddata.target
120+
117121
def _make_core(self):
118122
self._make_matvec(self.fielddata.jacobian)
119123
self._make_formfunc()
@@ -163,9 +167,7 @@ def _create_matvec_body(self, body, jacobian):
163167
'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)]
164168
)
165169

166-
zero_y_memory = petsc_call(
167-
'VecSet', [objs['Y'], 0.0]
168-
) if self.zero_memory else None
170+
zero_y_memory = self.zero_vector(objs['Y'])
169171

170172
dm_get_local_xvec = petsc_call(
171173
'DMGetLocalVector', [dmda, Byref(xlocal)]
@@ -297,7 +299,8 @@ def _create_formfunc_body(self, body):
297299
linsolve_expr = self.injectsolve.expr.rhs
298300
objs = self.objs
299301
sobjs = self.solver_objs
300-
target = self.fielddata.target
302+
arrays = self.arrays
303+
target = self.target
301304

302305
dmda = sobjs['callbackdm']
303306
ctx = objs['dummyctx']
@@ -307,18 +310,16 @@ def _create_formfunc_body(self, body):
307310
fields = self._dummy_fields(body)
308311
self._struct_params.extend(fields)
309312

310-
f_formfunc = self.fielddata.arrays[target]['f']
311-
x_formfunc = self.fielddata.arrays[target]['x']
313+
f_formfunc = arrays[target]['f']
314+
x_formfunc = arrays[target]['x']
312315

313316
dm_cast = DummyExpr(dmda, DMCast(objs['dummyptr']), init=True)
314317

315318
dm_get_app_context = petsc_call(
316319
'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)]
317320
)
318321

319-
zero_f_memory = petsc_call(
320-
'VecSet', [objs['F'], 0.0]
321-
) if self.zero_memory else None
322+
zero_f_memory = self.zero_vector(objs['F'])
322323

323324
dm_get_local_xvec = petsc_call(
324325
'DMGetLocalVector', [dmda, Byref(objs['xloc'])]
@@ -437,7 +438,7 @@ def _create_form_rhs_body(self, body):
437438
linsolve_expr = self.injectsolve.expr.rhs
438439
objs = self.objs
439440
sobjs = self.solver_objs
440-
target = self.fielddata.target
441+
target = self.target
441442

442443
dmda = sobjs['callbackdm']
443444
ctx = objs['dummyctx']
@@ -508,13 +509,15 @@ def _create_form_rhs_body(self, body):
508509
)
509510

510511
# Dereference function data in struct
511-
dereference_funcs = [Dereference(i, ctx) for i in
512-
fields if isinstance(i.function, AbstractFunction)]
512+
dereference_funcs = tuple(
513+
[Dereference(i, ctx) for i in
514+
fields if isinstance(i.function, AbstractFunction)]
515+
)
513516

514517
formrhs_body = CallableBody(
515518
List(body=[body]),
516519
init=(objs['begin_user'],),
517-
stacks=stacks+tuple(dereference_funcs),
520+
stacks=stacks+dereference_funcs,
518521
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
519522
)
520523

@@ -550,7 +553,7 @@ def _create_initial_guess_body(self, body):
550553
linsolve_expr = self.injectsolve.expr.rhs
551554
objs = self.objs
552555
sobjs = self.solver_objs
553-
target = self.fielddata.target
556+
target = self.target
554557

555558
dmda = sobjs['callbackdm']
556559
ctx = objs['dummyctx']
@@ -653,6 +656,12 @@ def _uxreplace_efuncs(self):
653656
mapper.update({k: visitor.visit(v)})
654657
return mapper
655658

659+
def zero_vector(self, vec):
660+
"""
661+
Zeros the memory of the output vector before computation
662+
"""
663+
return petsc_call('VecSet', [vec, 0.0]) if self.zero_memory else None
664+
656665

657666
class CCBBuilder(CBBuilder):
658667
def __init__(self, **kwargs):
@@ -1059,7 +1068,7 @@ def residual_bundle(self, body, bundles):
10591068
if i.base in mapper:
10601069
bundle = mapper[i.base]
10611070
index = bundles['target_indices'][i.function.target]
1062-
index = (index,)+i.indices
1071+
index = (index,) + i.indices
10631072
subs[i] = bundle.__getitem__(index)
10641073

10651074
body = Uxreplace(subs).visit(body)
@@ -1178,9 +1187,9 @@ def _extend_build(self, base_dict):
11781187
base_dict[f'{name}F'] = CallbackVec(f'{name}F')
11791188

11801189
# Bundle objects/metadata required by the coupled residual callback
1181-
f_components = []
1182-
x_components = []
1190+
f_components, x_components = [], []
11831191
bundle_mapper = {}
1192+
pname = sreg.make_name(prefix='Field')
11841193

11851194
target_indices = {t: i for i, t in enumerate(targets)}
11861195

@@ -1190,18 +1199,15 @@ def _extend_build(self, base_dict):
11901199
f_components.append(f_arr)
11911200
x_components.append(x_arr)
11921201

1193-
bundle_pname = sreg.make_name(prefix='Field')
11941202
fbundle = PetscBundle(
1195-
name='f_bundle', components=f_components, pname=bundle_pname
1203+
name='f_bundle', components=f_components, pname=pname
11961204
)
11971205
xbundle = PetscBundle(
1198-
name='x_bundle', components=x_components, pname=bundle_pname
1206+
name='x_bundle', components=x_components, pname=pname
11991207
)
12001208

12011209
# Build the bundle mapper
1202-
for i, t in enumerate(targets):
1203-
f_arr = arrays[t]['f']
1204-
x_arr = arrays[t]['x']
1210+
for f_arr, x_arr in zip(f_components, x_components):
12051211
bundle_mapper[f_arr.base] = fbundle
12061212
bundle_mapper[x_arr.base] = xbundle
12071213

devito/petsc/solve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def linear_solve_args(self):
8383
self.time_mapper = generate_time_mapper(funcs)
8484
arrays = self.generate_arrays_combined(target)
8585

86-
eqns = sorted(eqns, key=lambda e: 0 if isinstance(e, EssentialBC) else 1)
86+
eqns = sorted(eqns, key=lambda e: not isinstance(e, EssentialBC))
8787

8888
jacobian = Jacobian(target, eqns, arrays, self.time_mapper)
8989
residual = Residual(target, eqns, arrays, self.time_mapper, jacobian.scdiag)

devito/petsc/types/array.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,34 @@ def symbolic_shape(self):
126126

127127
class PetscBundle(Bundle):
128128
"""
129+
Tensor symbol representing an unrolled vector of PETScArrays.
130+
131+
This class declares a struct in the generated ccode to represent the
132+
fields defined at each node of the grid. For example:
133+
134+
typedef struct {
135+
PetscScalar u,v,omega,temperature;
136+
} Field;
137+
138+
Residual evaluations are then written using:
139+
140+
f[i][j].omega = ...
141+
142+
Reference - https://petsc.org/release/manual/vec/#sec-struct
143+
144+
Parameters
145+
----------
146+
name : str
147+
Name of the symbol.
148+
components : tuple of PETScArray
149+
The PETScArrays of the Bundle.
150+
pname : str, optional
151+
The name of the struct in the generated C code. Defaults to "Field".
152+
153+
Warnings
154+
--------
155+
PetscBundles are created and managed directly by Devito (IOW, they are not
156+
expected to be used directly in user code).
129157
"""
130158
is_Bundle = True
131159
_data_alignment = False
@@ -136,11 +164,11 @@ def __init__(self, *args, pname="Field", **kwargs):
136164
super().__init__(*args, **kwargs)
137165
self._pname = pname
138166

139-
@property
167+
@cached_property
140168
def _C_ctype(self):
141169
fields = [(i.target.name, dtype_to_ctype(i.dtype)) for i in self.components]
142170
return POINTER(type(self.pname, (Structure,), {'_fields_': fields}))
143-
171+
144172
@cached_property
145173
def symbolic_shape(self):
146174
return self.c0.symbolic_shape
@@ -176,8 +204,10 @@ def __getitem__(self, index):
176204
component_names=names
177205
)
178206
else:
179-
raise ValueError("Expected %d or %d indices, got %d instead"
180-
% (self.ndim, self.ndim + 1, len(index)))
207+
raise ValueError(
208+
f"Expected {self.ndim} or {self.ndim + 1} indices, "
209+
f"got {len(index)} instead"
210+
)
181211

182212
@property
183213
def pname(self):
@@ -187,7 +217,7 @@ def pname(self):
187217
class PetscComponentAccess(ComponentAccess):
188218
def __new__(cls, arg, index=0, component_names=None, **kwargs):
189219
if not arg.is_Indexed:
190-
raise ValueError("Expected Indexed, got `%s` instead" % type(arg))
220+
raise ValueError(f"Expected Indexed, got `{type(arg)}` instead")
191221
names = component_names or cls._default_component_names
192222

193223
obj = Expr.__new__(cls, arg)

devito/petsc/types/types.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from itertools import chain
44

5-
from devito.tools import Reconstructable, sympy_mutex, as_tuple
5+
from devito.tools import Reconstructable, sympy_mutex, as_tuple, frozendict
66
from devito.tools.dtypes_lowering import dtype_mapper
77
from devito.petsc.utils import petsc_variables
88
from devito.symbolics.extraction import separate_eqn, generate_targets, centre_stencil
@@ -269,11 +269,11 @@ def col_target(self):
269269
return self.target
270270

271271
def _build_matvecs(self):
272-
matvecs = [
273-
e for eq in self.eqns for e in
274-
self._build_matvec_eq(eq)
275-
if e is not None
276-
]
272+
matvecs = []
273+
for eq in self.eqns:
274+
matvecs.extend(
275+
e for e in self._build_matvec_eq(eq) if e is not None
276+
)
277277
matvecs = tuple(sorted(matvecs, key=lambda e: not isinstance(e, EssentialBC)))
278278

279279
matvecs = self._scale_non_bcs(matvecs)
@@ -287,12 +287,13 @@ def _build_matvec_eq(self, eq, col_target=None, row_target=None):
287287
col_target = col_target or self.target
288288
row_target = row_target or self.target
289289

290-
b, F_target, _, targets = separate_eqn(eq, col_target)
290+
_, F_target, _, targets = separate_eqn(eq, col_target)
291291
if F_target:
292292
return self._make_matvec(
293293
eq, F_target, targets, col_target, row_target
294294
)
295-
return (None,)
295+
else:
296+
return (None,)
296297

297298
def _make_matvec(self, eq, F_target, targets, col_target, row_target):
298299
y = self.arrays[row_target]['y']
@@ -307,7 +308,7 @@ def _make_matvec(self, eq, F_target, targets, col_target, row_target):
307308
else:
308309
rhs = F_target.subs(targets_to_arrays(x, targets))
309310
rhs = rhs.subs(self.time_mapper)
310-
return as_tuple(Eq(y, rhs, subdomain=eq.subdomain))
311+
return (Eq(y, rhs, subdomain=eq.subdomain),)
311312

312313
def _scale_non_bcs(self, matvecs, target=None):
313314
target = target or self.target
@@ -376,7 +377,7 @@ class MixedJacobian(Jacobian):
376377
def __init__(self, target_eqns, arrays, time_mapper):
377378
"""
378379
"""
379-
self.targets = as_tuple(target_eqns.keys())
380+
self.targets = tuple(target_eqns.keys())
380381
self.arrays = arrays
381382
self.time_mapper = time_mapper
382383
self._submatrices = []
@@ -391,7 +392,7 @@ def submatrices(self):
391392
return self._submatrices
392393

393394
@property
394-
def no_submatrices(self):
395+
def n_submatrices(self):
395396
"""
396397
Return the number of submatrix blocks.
397398
"""
@@ -421,11 +422,13 @@ def _build_blocks(self, target_eqns):
421422
for i, row_target in enumerate(self.targets):
422423
eqns = target_eqns[row_target]
423424
for j, col_target in enumerate(self.targets):
424-
matvecs = [
425-
e for eq in eqns for e in
426-
self._build_matvec_eq(eq, col_target, row_target)
427-
]
425+
matvecs = []
426+
for eq in eqns:
427+
matvecs.extend(
428+
e for e in self._build_matvec_eq(eq, col_target, row_target)
429+
)
428430
matvecs = [m for m in matvecs if m is not None]
431+
429432
# Sort to put EssentialBC first if any
430433
matvecs = tuple(
431434
sorted(matvecs, key=lambda e: not isinstance(e, EssentialBC))
@@ -461,7 +464,7 @@ def __repr__(self):
461464
f"{sm.name} (row={sm.row_idx}, col={sm.col_idx})"
462465
for sm in self.submatrices
463466
)
464-
return f"<MixedJacobian with {self.no_submatrices} submatrices: [{summary}]>"
467+
return f"<MixedJacobian with {self.n_submatrices} submatrices: [{summary}]>"
465468

466469

467470
class Residual:
@@ -499,12 +502,13 @@ def _build_equations(self):
499502
# TODO: If b is zero then don't need a rhs vector+callback
500503
rhs.extend(self._make_b(eq, b))
501504

502-
self._formfuncs = [self._scale_bcs(eq) for eq in funcs]
503-
self._formrhs = rhs
505+
self._formfuncs = tuple([self._scale_bcs(eq) for eq in funcs])
506+
self._formrhs = tuple(rhs)
504507

505508
def _make_F_target(self, eq, F_target, targets):
506509
arrays = self.arrays[self.target]
507510
volume = self.target.grid.symbolic_volume_cell
511+
508512
if isinstance(eq, EssentialBC):
509513
# The initial guess satisfies the essential BCs, so this term is zero.
510514
# Still included to support Jacobian testing via finite differences.
@@ -513,19 +517,20 @@ def _make_F_target(self, eq, F_target, targets):
513517
# Move essential boundary condition to the right-hand side
514518
zero_col = ZeroColumn(arrays['x'], eq.rhs, subdomain=eq.subdomain)
515519
return (zero_row, zero_col)
520+
516521
else:
517522
if isinstance(F_target, (int, float)):
518523
rhs = F_target * volume
519524
else:
520525
rhs = F_target.subs(targets_to_arrays(arrays['x'], targets))
521526
rhs = rhs.subs(self.time_mapper) * volume
522-
return as_tuple(Eq(arrays['f'], rhs, subdomain=eq.subdomain))
527+
return (Eq(arrays['f'], rhs, subdomain=eq.subdomain),)
523528

524529
def _make_b(self, eq, b):
525530
b_arr = self.arrays[self.target]['b']
526531
rhs = 0. if isinstance(eq, EssentialBC) else b.subs(self.time_mapper)
527532
rhs = rhs * self.target.grid.symbolic_volume_cell
528-
return as_tuple(Eq(b_arr, rhs, subdomain=eq.subdomain))
533+
return (Eq(b_arr, rhs, subdomain=eq.subdomain),)
529534

530535
def _scale_bcs(self, eq, scdiag=None):
531536
"""
@@ -617,11 +622,11 @@ def _build_equations(self):
617622
"""
618623
Return a list of initial guess equations.
619624
"""
620-
self._eqs = [
625+
self._eqs = tuple([
621626
eq for eq in
622627
(self._make_initial_guess(e) for e in self.eqns)
623628
if eq is not None
624-
]
629+
])
625630

626631
def _make_initial_guess(self, eq):
627632
if isinstance(eq, EssentialBC):
@@ -655,4 +660,4 @@ def targets_to_arrays(array, targets):
655660
array_targets = [
656661
array.subs(dict(zip(array.indices, i))) for i in space_indices
657662
]
658-
return dict(zip(targets, array_targets))
663+
return frozendict(zip(targets, array_targets))

0 commit comments

Comments
 (0)