Skip to content

Commit 312c826

Browse files
committed
misc: Add BaseJacobian
1 parent a4a5660 commit 312c826

1 file changed

Lines changed: 131 additions & 85 deletions

File tree

devito/petsc/types/types.py

Lines changed: 131 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,26 @@ def eval(cls, *args):
137137

138138

139139
class FieldData:
140+
"""
141+
Metadata class passed to `LinearSolveExpr`. Encapsulates metadata for a single
142+
`target` field needed to interface with PETSc SNES solvers.
143+
144+
Parameters
145+
----------
146+
147+
target : Function-like
148+
The target field to solve into, which is a Function-like object.
149+
jacobian : Jacobian
150+
Defines the matrix-vector product for the linear system, where the vector is
151+
the PETScArray representing the `target`.
152+
residual : Residual
153+
Defines the nonlinear residual function F(target) = 0.
154+
initialguess : InitialGuess
155+
Defines the initial guess for the solution, which satisfies
156+
essential boundary conditions.
157+
arrays : dict
158+
A dictionary mapping `target` to its corresponding PETScArrays.
159+
"""
140160
def __init__(self, target=None, jacobian=None, residual=None,
141161
initialguess=None, arrays=None, **kwargs):
142162
self._target = target
@@ -190,6 +210,24 @@ def targets(self):
190210

191211

192212
class MultipleFieldData(FieldData):
213+
"""
214+
Metadata class passed to `LinearSolveExpr`, for mixed-field problems,
215+
where the solution vector spans multiple `targets`.
216+
217+
Parameters
218+
----------
219+
targets : list of Function-like
220+
The fields to solve into, each represented by a Function-like object.
221+
jacobian : MixedJacobian
222+
Defines the matrix-vector products for the full system Jacobian.
223+
residual : MixedResidual
224+
Defines the nonlinear residual function F(targets) = 0.
225+
initialguess : InitialGuess
226+
Defines the initial guess metadata, which satisfies
227+
essential boundary conditions.
228+
arrays : dict
229+
A dictionary mapping the `targets` to their corresponding PETScArrays.
230+
"""
193231
def __init__(self, targets, arrays, jacobian=None, residual=None):
194232
self._targets = as_tuple(targets)
195233
self._arrays = arrays
@@ -208,6 +246,7 @@ def space_dimensions(self):
208246

209247
@property
210248
def grid(self):
249+
"""The unique `Grid` associated with all targets."""
211250
grids = [t.grid for t in self.targets]
212251
if len(set(grids)) > 1:
213252
raise ValueError(
@@ -233,8 +272,70 @@ def space_order(self):
233272
def targets(self):
234273
return self._targets
235274

275+
276+
class BaseJacobian:
277+
def __init__(self, arrays, target=None):
278+
self.arrays = arrays
279+
self.target = target
280+
281+
def _scale_non_bcs(self, matvecs, target=None):
282+
target = target or self.target
283+
vol = target.grid.symbolic_volume_cell
284+
285+
return [
286+
m if isinstance(m, EssentialBC) else m._rebuild(rhs=m.rhs * vol)
287+
for m in matvecs
288+
]
289+
290+
def _compute_scdiag(self, matvecs, col_target=None):
291+
"""
292+
"""
293+
x = self.arrays[col_target or self.target]['x']
294+
295+
centres = {
296+
centre_stencil(m.rhs, x, as_coeff=True)
297+
for m in matvecs if not isinstance(m, EssentialBC)
298+
}
299+
return centres.pop() if len(centres) == 1 else 1.0
300+
301+
def _scale_bcs(self, matvecs, scdiag):
302+
"""
303+
Scale the essential BCs
304+
"""
305+
return [
306+
m._rebuild(rhs=m.rhs * scdiag) if isinstance(m, ZeroRow) else m
307+
for m in matvecs
308+
]
309+
310+
def _build_matvec_expr(self, expr, **kwargs):
311+
col_target = kwargs.get('col_target', self.target)
312+
row_target = kwargs.get('row_target', self.target)
313+
314+
_, F_target, _, targets = separate_eqn(expr, col_target)
315+
if F_target:
316+
return self._make_matvec(
317+
expr, F_target, targets, col_target, row_target
318+
)
319+
else:
320+
return (None,)
321+
322+
def _make_matvec(self, expr, F_target, targets, col_target, row_target):
323+
y = self.arrays[row_target]['y']
324+
x = self.arrays[col_target]['x']
325+
326+
if isinstance(expr, EssentialBC):
327+
# NOTE: Essential BCs are trivial equations in the solver.
328+
# See `EssentialBC` for more details.
329+
zero_row = ZeroRow(y, x, subdomain=expr.subdomain)
330+
zero_column = ZeroColumn(x, 0., subdomain=expr.subdomain)
331+
return (zero_row, zero_column)
332+
else:
333+
rhs = F_target.subs(targets_to_arrays(x, targets))
334+
rhs = rhs.subs(self.time_mapper)
335+
return (Eq(y, rhs, subdomain=expr.subdomain),)
336+
236337

237-
class Jacobian:
338+
class Jacobian(BaseJacobian):
238339
"""
239340
Represents a Jacobian matrix.
240341
@@ -246,14 +347,14 @@ class Jacobian:
246347
require explicit symbolic differentiation.
247348
"""
248349
def __init__(self, target, exprs, arrays, time_mapper):
249-
self.target = target
350+
super().__init__(arrays=arrays, target=target)
250351
self.exprs = exprs
251-
self.arrays = arrays
252352
self.time_mapper = time_mapper
253353
self._build_matvecs()
254354

255355
@property
256356
def matvecs(self):
357+
# TODO: add shortcut explanation etc
257358
"""
258359
Stores the expressions used to generate the `MatMult`
259360
callback generated at the IET level. This function is
@@ -279,7 +380,6 @@ def _build_matvecs(self):
279380
matvecs.extend(
280381
e for e in self._build_matvec_expr(eq) if e is not None
281382
)
282-
283383
matvecs = tuple(sorted(matvecs, key=lambda e: not isinstance(e, EssentialBC)))
284384

285385
matvecs = self._scale_non_bcs(matvecs)
@@ -289,83 +389,8 @@ def _build_matvecs(self):
289389
self._matvecs = matvecs
290390
self._scdiag = scdiag
291391

292-
def _build_matvec_expr(self, expr, col_target=None, row_target=None):
293-
col_target = col_target or self.target
294-
row_target = row_target or self.target
295392

296-
_, F_target, _, targets = separate_eqn(expr, col_target)
297-
if F_target:
298-
return self._make_matvec(
299-
expr, F_target, targets, col_target, row_target
300-
)
301-
else:
302-
return (None,)
303-
304-
def _make_matvec(self, expr, F_target, targets, col_target, row_target):
305-
y = self.arrays[row_target]['y']
306-
x = self.arrays[col_target]['x']
307-
308-
if isinstance(expr, EssentialBC):
309-
# NOTE: Essential BCs are trivial equations in the solver.
310-
# See `EssentialBC` for more details.
311-
zero_row = ZeroRow(y, x, subdomain=expr.subdomain)
312-
zero_column = ZeroColumn(x, 0., subdomain=expr.subdomain)
313-
return (zero_row, zero_column)
314-
else:
315-
rhs = F_target.subs(targets_to_arrays(x, targets))
316-
rhs = rhs.subs(self.time_mapper)
317-
return (Eq(y, rhs, subdomain=expr.subdomain),)
318-
319-
def _scale_non_bcs(self, matvecs, target=None):
320-
target = target or self.target
321-
vol = target.grid.symbolic_volume_cell
322-
323-
return [
324-
m if isinstance(m, EssentialBC) else m._rebuild(rhs=m.rhs * vol)
325-
for m in matvecs
326-
]
327-
328-
def _compute_scdiag(self, matvecs, col_target=None):
329-
"""
330-
"""
331-
x = self.arrays[col_target or self.target]['x']
332-
333-
centres = {
334-
centre_stencil(m.rhs, x, as_coeff=True)
335-
for m in matvecs if not isinstance(m, EssentialBC)
336-
}
337-
return centres.pop() if len(centres) == 1 else 1.0
338-
339-
def _scale_bcs(self, matvecs, scdiag):
340-
"""
341-
Scale the essential BCs
342-
"""
343-
return [
344-
m._rebuild(rhs=m.rhs * scdiag) if isinstance(m, ZeroRow) else m
345-
for m in matvecs
346-
]
347-
348-
349-
class SubMatrixBlock:
350-
def __init__(self, name, matvecs, scdiag, row_target,
351-
col_target, row_idx, col_idx, linear_idx):
352-
self.name = name
353-
self.matvecs = matvecs
354-
self.scdiag = scdiag
355-
self.row_target = row_target
356-
self.col_target = col_target
357-
self.row_idx = row_idx
358-
self.col_idx = col_idx
359-
self.linear_idx = linear_idx
360-
361-
def is_diag(self):
362-
return self.row_idx == self.col_idx
363-
364-
def __repr__(self):
365-
return (f"<SubMatrixBlock {self.name}>")
366-
367-
368-
class MixedJacobian(Jacobian):
393+
class MixedJacobian(BaseJacobian):
369394
"""
370395
Represents a Jacobian for a linear system with a solution vector
371396
composed of multiple fields (targets).
@@ -380,12 +405,12 @@ class MixedJacobian(Jacobian):
380405
381406
# TODO: pcfieldsplit support for each block
382407
"""
383-
def __init__(self, target_eqns, arrays, time_mapper):
384-
self.targets = tuple(target_eqns.keys())
385-
self.arrays = arrays
408+
def __init__(self, target_exprs, arrays, time_mapper):
409+
super().__init__(arrays=arrays, target=None)
410+
self.targets = tuple(target_exprs.keys())
386411
self.time_mapper = time_mapper
387412
self._submatrices = []
388-
self._build_blocks(target_eqns)
413+
self._build_blocks(target_exprs)
389414

390415
@property
391416
def submatrices(self):
@@ -427,7 +452,9 @@ def _build_blocks(self, target_exprs):
427452
matvecs = []
428453
for expr in exprs:
429454
matvecs.extend(
430-
e for e in self._build_matvec_expr(expr, col_target, row_target)
455+
e for e in self._build_matvec_expr(
456+
expr, col_target=col_target, row_target=row_target
457+
)
431458
)
432459
matvecs = [m for m in matvecs if m is not None]
433460

@@ -469,6 +496,25 @@ def __repr__(self):
469496
return f"<MixedJacobian with {self.n_submatrices} submatrices: [{summary}]>"
470497

471498

499+
class SubMatrixBlock:
500+
def __init__(self, name, matvecs, scdiag, row_target,
501+
col_target, row_idx, col_idx, linear_idx):
502+
self.name = name
503+
self.matvecs = matvecs
504+
self.scdiag = scdiag
505+
self.row_target = row_target
506+
self.col_target = col_target
507+
self.row_idx = row_idx
508+
self.col_idx = col_idx
509+
self.linear_idx = linear_idx
510+
511+
def is_diag(self):
512+
return self.row_idx == self.col_idx
513+
514+
def __repr__(self):
515+
return (f"<SubMatrixBlock {self.name}>")
516+
517+
472518
class Residual:
473519
"""
474520
Gennerates the metadata needed to define the nonlinear residual function

0 commit comments

Comments
 (0)