@@ -137,6 +137,26 @@ def eval(cls, *args):
137137
138138
139139class FieldData :
140+ """
141+ Metadata class passed to `LinearSolveExpr`. Encapsulates metadata for a single
142+ `target` field needed to interface with PETSc SNES solvers.
143+
144+ Parameters
145+ ----------
146+
147+ target : Function-like
148+ The target field to solve into, which is a Function-like object.
149+ jacobian : Jacobian
150+ Defines the matrix-vector product for the linear system, where the vector is
151+ the PETScArray representing the `target`.
152+ residual : Residual
153+ Defines the nonlinear residual function F(target) = 0.
154+ initialguess : InitialGuess
155+ Defines the initial guess for the solution, which satisfies
156+ essential boundary conditions.
157+ arrays : dict
158+ A dictionary mapping `target` to its corresponding PETScArrays.
159+ """
140160 def __init__ (self , target = None , jacobian = None , residual = None ,
141161 initialguess = None , arrays = None , ** kwargs ):
142162 self ._target = target
@@ -190,6 +210,24 @@ def targets(self):
190210
191211
192212class MultipleFieldData (FieldData ):
213+ """
214+ Metadata class passed to `LinearSolveExpr`, for mixed-field problems,
215+ where the solution vector spans multiple `targets`.
216+
217+ Parameters
218+ ----------
219+ targets : list of Function-like
220+ The fields to solve into, each represented by a Function-like object.
221+ jacobian : MixedJacobian
222+ Defines the matrix-vector products for the full system Jacobian.
223+ residual : MixedResidual
224+ Defines the nonlinear residual function F(targets) = 0.
225+ initialguess : InitialGuess
226+ Defines the initial guess metadata, which satisfies
227+ essential boundary conditions.
228+ arrays : dict
229+ A dictionary mapping the `targets` to their corresponding PETScArrays.
230+ """
193231 def __init__ (self , targets , arrays , jacobian = None , residual = None ):
194232 self ._targets = as_tuple (targets )
195233 self ._arrays = arrays
@@ -208,6 +246,7 @@ def space_dimensions(self):
208246
209247 @property
210248 def grid (self ):
249+ """The unique `Grid` associated with all targets."""
211250 grids = [t .grid for t in self .targets ]
212251 if len (set (grids )) > 1 :
213252 raise ValueError (
@@ -233,8 +272,70 @@ def space_order(self):
233272 def targets (self ):
234273 return self ._targets
235274
275+
276+ class BaseJacobian :
277+ def __init__ (self , arrays , target = None ):
278+ self .arrays = arrays
279+ self .target = target
280+
281+ def _scale_non_bcs (self , matvecs , target = None ):
282+ target = target or self .target
283+ vol = target .grid .symbolic_volume_cell
284+
285+ return [
286+ m if isinstance (m , EssentialBC ) else m ._rebuild (rhs = m .rhs * vol )
287+ for m in matvecs
288+ ]
289+
290+ def _compute_scdiag (self , matvecs , col_target = None ):
291+ """
292+ """
293+ x = self .arrays [col_target or self .target ]['x' ]
294+
295+ centres = {
296+ centre_stencil (m .rhs , x , as_coeff = True )
297+ for m in matvecs if not isinstance (m , EssentialBC )
298+ }
299+ return centres .pop () if len (centres ) == 1 else 1.0
300+
301+ def _scale_bcs (self , matvecs , scdiag ):
302+ """
303+ Scale the essential BCs
304+ """
305+ return [
306+ m ._rebuild (rhs = m .rhs * scdiag ) if isinstance (m , ZeroRow ) else m
307+ for m in matvecs
308+ ]
309+
310+ def _build_matvec_expr (self , expr , ** kwargs ):
311+ col_target = kwargs .get ('col_target' , self .target )
312+ row_target = kwargs .get ('row_target' , self .target )
313+
314+ _ , F_target , _ , targets = separate_eqn (expr , col_target )
315+ if F_target :
316+ return self ._make_matvec (
317+ expr , F_target , targets , col_target , row_target
318+ )
319+ else :
320+ return (None ,)
321+
322+ def _make_matvec (self , expr , F_target , targets , col_target , row_target ):
323+ y = self .arrays [row_target ]['y' ]
324+ x = self .arrays [col_target ]['x' ]
325+
326+ if isinstance (expr , EssentialBC ):
327+ # NOTE: Essential BCs are trivial equations in the solver.
328+ # See `EssentialBC` for more details.
329+ zero_row = ZeroRow (y , x , subdomain = expr .subdomain )
330+ zero_column = ZeroColumn (x , 0. , subdomain = expr .subdomain )
331+ return (zero_row , zero_column )
332+ else :
333+ rhs = F_target .subs (targets_to_arrays (x , targets ))
334+ rhs = rhs .subs (self .time_mapper )
335+ return (Eq (y , rhs , subdomain = expr .subdomain ),)
336+
236337
237- class Jacobian :
338+ class Jacobian ( BaseJacobian ) :
238339 """
239340 Represents a Jacobian matrix.
240341
@@ -246,14 +347,14 @@ class Jacobian:
246347 require explicit symbolic differentiation.
247348 """
248349 def __init__ (self , target , exprs , arrays , time_mapper ):
249- self . target = target
350+ super (). __init__ ( arrays = arrays , target = target )
250351 self .exprs = exprs
251- self .arrays = arrays
252352 self .time_mapper = time_mapper
253353 self ._build_matvecs ()
254354
255355 @property
256356 def matvecs (self ):
357+ # TODO: add shortcut explanation etc
257358 """
258359 Stores the expressions used to generate the `MatMult`
259360 callback generated at the IET level. This function is
@@ -279,7 +380,6 @@ def _build_matvecs(self):
279380 matvecs .extend (
280381 e for e in self ._build_matvec_expr (eq ) if e is not None
281382 )
282-
283383 matvecs = tuple (sorted (matvecs , key = lambda e : not isinstance (e , EssentialBC )))
284384
285385 matvecs = self ._scale_non_bcs (matvecs )
@@ -289,83 +389,8 @@ def _build_matvecs(self):
289389 self ._matvecs = matvecs
290390 self ._scdiag = scdiag
291391
292- def _build_matvec_expr (self , expr , col_target = None , row_target = None ):
293- col_target = col_target or self .target
294- row_target = row_target or self .target
295392
296- _ , F_target , _ , targets = separate_eqn (expr , col_target )
297- if F_target :
298- return self ._make_matvec (
299- expr , F_target , targets , col_target , row_target
300- )
301- else :
302- return (None ,)
303-
304- def _make_matvec (self , expr , F_target , targets , col_target , row_target ):
305- y = self .arrays [row_target ]['y' ]
306- x = self .arrays [col_target ]['x' ]
307-
308- if isinstance (expr , EssentialBC ):
309- # NOTE: Essential BCs are trivial equations in the solver.
310- # See `EssentialBC` for more details.
311- zero_row = ZeroRow (y , x , subdomain = expr .subdomain )
312- zero_column = ZeroColumn (x , 0. , subdomain = expr .subdomain )
313- return (zero_row , zero_column )
314- else :
315- rhs = F_target .subs (targets_to_arrays (x , targets ))
316- rhs = rhs .subs (self .time_mapper )
317- return (Eq (y , rhs , subdomain = expr .subdomain ),)
318-
319- def _scale_non_bcs (self , matvecs , target = None ):
320- target = target or self .target
321- vol = target .grid .symbolic_volume_cell
322-
323- return [
324- m if isinstance (m , EssentialBC ) else m ._rebuild (rhs = m .rhs * vol )
325- for m in matvecs
326- ]
327-
328- def _compute_scdiag (self , matvecs , col_target = None ):
329- """
330- """
331- x = self .arrays [col_target or self .target ]['x' ]
332-
333- centres = {
334- centre_stencil (m .rhs , x , as_coeff = True )
335- for m in matvecs if not isinstance (m , EssentialBC )
336- }
337- return centres .pop () if len (centres ) == 1 else 1.0
338-
339- def _scale_bcs (self , matvecs , scdiag ):
340- """
341- Scale the essential BCs
342- """
343- return [
344- m ._rebuild (rhs = m .rhs * scdiag ) if isinstance (m , ZeroRow ) else m
345- for m in matvecs
346- ]
347-
348-
349- class SubMatrixBlock :
350- def __init__ (self , name , matvecs , scdiag , row_target ,
351- col_target , row_idx , col_idx , linear_idx ):
352- self .name = name
353- self .matvecs = matvecs
354- self .scdiag = scdiag
355- self .row_target = row_target
356- self .col_target = col_target
357- self .row_idx = row_idx
358- self .col_idx = col_idx
359- self .linear_idx = linear_idx
360-
361- def is_diag (self ):
362- return self .row_idx == self .col_idx
363-
364- def __repr__ (self ):
365- return (f"<SubMatrixBlock { self .name } >" )
366-
367-
368- class MixedJacobian (Jacobian ):
393+ class MixedJacobian (BaseJacobian ):
369394 """
370395 Represents a Jacobian for a linear system with a solution vector
371396 composed of multiple fields (targets).
@@ -380,12 +405,12 @@ class MixedJacobian(Jacobian):
380405
381406 # TODO: pcfieldsplit support for each block
382407 """
383- def __init__ (self , target_eqns , arrays , time_mapper ):
384- self . targets = tuple ( target_eqns . keys () )
385- self .arrays = arrays
408+ def __init__ (self , target_exprs , arrays , time_mapper ):
409+ super (). __init__ ( arrays = arrays , target = None )
410+ self .targets = tuple ( target_exprs . keys ())
386411 self .time_mapper = time_mapper
387412 self ._submatrices = []
388- self ._build_blocks (target_eqns )
413+ self ._build_blocks (target_exprs )
389414
390415 @property
391416 def submatrices (self ):
@@ -427,7 +452,9 @@ def _build_blocks(self, target_exprs):
427452 matvecs = []
428453 for expr in exprs :
429454 matvecs .extend (
430- e for e in self ._build_matvec_expr (expr , col_target , row_target )
455+ e for e in self ._build_matvec_expr (
456+ expr , col_target = col_target , row_target = row_target
457+ )
431458 )
432459 matvecs = [m for m in matvecs if m is not None ]
433460
@@ -469,6 +496,25 @@ def __repr__(self):
469496 return f"<MixedJacobian with { self .n_submatrices } submatrices: [{ summary } ]>"
470497
471498
499+ class SubMatrixBlock :
500+ def __init__ (self , name , matvecs , scdiag , row_target ,
501+ col_target , row_idx , col_idx , linear_idx ):
502+ self .name = name
503+ self .matvecs = matvecs
504+ self .scdiag = scdiag
505+ self .row_target = row_target
506+ self .col_target = col_target
507+ self .row_idx = row_idx
508+ self .col_idx = col_idx
509+ self .linear_idx = linear_idx
510+
511+ def is_diag (self ):
512+ return self .row_idx == self .col_idx
513+
514+ def __repr__ (self ):
515+ return (f"<SubMatrixBlock { self .name } >" )
516+
517+
472518class Residual :
473519 """
474520 Gennerates the metadata needed to define the nonlinear residual function
0 commit comments