Skip to content

Commit 51d4d7e

Browse files
committed
misc: Address more comments and add docstrings
1 parent 525fdf6 commit 51d4d7e

3 files changed

Lines changed: 43 additions & 21 deletions

File tree

devito/petsc/iet/routines.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,12 @@ def _create_matvec_body(self, body, jacobian):
258258
)
259259

260260
# Dereference function data in struct
261-
dereference_funcs = [Dereference(i, ctx) for i in
262-
fields if isinstance(i.function, AbstractFunction)]
261+
derefs = self.dereference_funcs(ctx, fields)
263262

264263
matvec_body = CallableBody(
265264
List(body=body),
266265
init=(objs['begin_user'],),
267-
stacks=stacks+tuple(dereference_funcs),
266+
stacks=stacks+derefs,
268267
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
269268
)
270269

@@ -398,13 +397,12 @@ def _create_formfunc_body(self, body):
398397
)
399398

400399
# Dereference function data in struct
401-
dereference_funcs = [Dereference(i, ctx) for i in
402-
fields if isinstance(i.function, AbstractFunction)]
400+
derefs = self.dereference_funcs(ctx, fields)
403401

404402
formfunc_body = CallableBody(
405403
List(body=body),
406404
init=(objs['begin_user'],),
407-
stacks=stacks+tuple(dereference_funcs),
405+
stacks=stacks+derefs,
408406
retstmt=(Call('PetscFunctionReturn', arguments=[0]),))
409407

410408
# Replace non-function data with pointer to data in struct
@@ -509,15 +507,12 @@ def _create_form_rhs_body(self, body):
509507
)
510508

511509
# Dereference function data in struct
512-
dereference_funcs = tuple(
513-
[Dereference(i, ctx) for i in
514-
fields if isinstance(i.function, AbstractFunction)]
515-
)
510+
derefs = self.dereference_funcs(ctx, fields)
516511

517512
formrhs_body = CallableBody(
518513
List(body=[body]),
519514
init=(objs['begin_user'],),
520-
stacks=stacks+dereference_funcs,
515+
stacks=stacks+derefs,
521516
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
522517
)
523518

@@ -590,13 +585,12 @@ def _create_initial_guess_body(self, body):
590585
)
591586

592587
# Dereference function data in struct
593-
dereference_funcs = [Dereference(i, ctx) for i in
594-
fields if isinstance(i.function, AbstractFunction)]
588+
derefs = self.dereference_funcs(ctx, fields)
595589

596590
body = CallableBody(
597591
List(body=[body]),
598592
init=(objs['begin_user'],),
599-
stacks=stacks+tuple(dereference_funcs),
593+
stacks=stacks+derefs,
600594
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
601595
)
602596

@@ -662,6 +656,12 @@ def zero_vector(self, vec):
662656
"""
663657
return petsc_call('VecSet', [vec, 0.0]) if self.zero_memory else None
664658

659+
def dereference_funcs(self, struct, fields):
660+
return tuple(
661+
[Dereference(i, struct) for i in
662+
fields if isinstance(i.function, AbstractFunction)]
663+
)
664+
665665

666666
class CCBBuilder(CBBuilder):
667667
def __init__(self, **kwargs):
@@ -894,16 +894,15 @@ def _whole_formfunc_body(self, body):
894894
)
895895

896896
# Dereference function data in struct
897-
dereference_funcs = [Dereference(i, ctx) for i in
898-
fields if isinstance(i.function, AbstractFunction)]
897+
derefs = self.dereference_funcs(ctx, fields)
899898

900899
f_soa = PointerCast(fbundle)
901900
x_soa = PointerCast(xbundle)
902901

903902
formfunc_body = CallableBody(
904903
List(body=body),
905904
init=(objs['begin_user'],),
906-
stacks=stacks+tuple(dereference_funcs),
905+
stacks=stacks+derefs,
907906
casts=(f_soa, x_soa),
908907
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
909908
)

devito/petsc/types/types.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ class Jacobian:
241241
This Jacobian is defined implicitly via matrix-vector products
242242
derived from the symbolic equations provided in `matvecs`.
243243
244-
It assumes the problem is linear, meaning the Jacobian
244+
The class assumes the problem is linear, meaning the Jacobian
245245
corresponds to a constant coefficient matrix and does not
246246
require explicit symbolic differentiation.
247247
"""
@@ -274,6 +274,7 @@ def _build_matvecs(self):
274274
matvecs.extend(
275275
e for e in self._build_matvec_eq(eq) if e is not None
276276
)
277+
277278
matvecs = tuple(sorted(matvecs, key=lambda e: not isinstance(e, EssentialBC)))
278279

279280
matvecs = self._scale_non_bcs(matvecs)
@@ -469,6 +470,20 @@ def __repr__(self):
469470

470471
class Residual:
471472
"""
473+
Gennerates the metadata needed to define the nonlinear residual function
474+
F(target) = 0 for use with PETSc's SNES interface.
475+
476+
PETSc's SNES interface includes methods for solving nonlinear systems of
477+
equations using Newton-type methods. For linear problems, `SNESKSPONLY`
478+
is used to perform a single Newton iteration, unifying the
479+
interface for both linear and nonlinear problems.
480+
481+
This class encapsulates the symbolic equations used to construct the
482+
residual function F(target) = F_(target) - b, where b contains all
483+
terms independent of the solution `target`.
484+
485+
References:
486+
- https://petsc.org/main/manual/snes/
472487
"""
473488
def __init__(self, target, eqns, arrays, time_mapper, scdiag):
474489
self.target = target
@@ -481,12 +496,18 @@ def __init__(self, target, eqns, arrays, time_mapper, scdiag):
481496
@property
482497
def formfuncs(self):
483498
"""
499+
Stores the equations used to build the `FormFunction`
500+
callback generated at the IET level. This function is
501+
passed to PETSc via `SNESSetFunction(..., FormFunction, ...)`.
484502
"""
485503
return self._formfuncs
486504

487505
@property
488506
def formrhs(self):
489507
"""
508+
Stores the equations used to generate the RHS
509+
vector `b` through the `FormRHS` callback generated at the IET level.
510+
The SNES solver is then called via `SNESSolve(..., b, target)`.
490511
"""
491512
return self._formrhs
492513

@@ -544,7 +565,7 @@ class MixedResidual(Residual):
544565
"""
545566
"""
546567
def __init__(self, target_eqns, arrays, time_mapper, scdiag):
547-
self.targets = as_tuple(target_eqns.keys())
568+
self.targets = tuple(target_eqns.keys())
548569
self.arrays = arrays
549570
self.time_mapper = time_mapper
550571
self.scdiag = scdiag
@@ -592,14 +613,15 @@ def _build_function_eq(self, eq, target):
592613
self.arrays[target]['x'], eq.rhs, subdomain=eq.subdomain
593614
)
594615
return (zero_row, zero_col)
616+
595617
else:
596618
if isinstance(zeroed, (int, float)):
597619
rhs = zeroed * volume
598620
else:
599621
rhs = zeroed.subs(mapper)
600622
rhs = rhs.subs(self.time_mapper)*volume
601623

602-
return as_tuple(Eq(self.arrays[target]['f'], rhs, subdomain=eq.subdomain))
624+
return (Eq(self.arrays[target]['f'], rhs, subdomain=eq.subdomain),)
603625

604626

605627
class InitialGuess:

devito/symbolics/extraction.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ def separate_eqn(eqn, target):
2020

2121
from devito.operations.solve import eval_time_derivatives
2222
zeroed_eqn = eval_time_derivatives(zeroed_eqn.lhs)
23-
target_funcs = set(generate_targets(zeroed_eqn, target))
2423

24+
target_funcs = set(generate_targets(zeroed_eqn, target))
2525
b, F_target = remove_targets(zeroed_eqn, target_funcs)
26+
2627
return -b, F_target, zeroed_eqn, target_funcs
2728

2829

0 commit comments

Comments
 (0)