Skip to content

Commit 52250e1

Browse files
committed
dsl/misc/compiler: Clean up, fix __indices_setup__ in PETScArray and allow constraining individual EssentialBC equations
1 parent 5fc3bcd commit 52250e1

4 files changed

Lines changed: 55 additions & 20 deletions

File tree

devito/petsc/solve.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,12 @@ def petscsolve(target_exprs, target=None, solver_parameters=None,
7979
'kspgettype', 'kspgetnormtype', 'snesgetiterationnumber']
8080
8181
constrain_bcs : bool, optional
82-
If `True`, essential boundary conditions specifed by `EssentialBC` equations
83-
are constrained through a `PetscSection`. As a result, the corresponding degrees
84-
of freedom are excluded from the global solver and are not imposed using
85-
trivial equations.
82+
If `True`, ALL `EssentialBC` equations are constrained via a `PetscSection`,
83+
excluding the corresponding degrees of freedom from the global solver.
84+
Individual `EssentialBC`s can also be constrained independently by passing
85+
``constrain=True`` to their constructor (e.g.
86+
``EssentialBC(lhs, rhs, subdomain=..., constrain=True)``), regardless of
87+
this flag.
8688
8789
Returns
8890
-------
@@ -142,11 +144,8 @@ def linear_solve_args(self):
142144
residual = Residual(target, exprs, arrays, self.time_mapper, jacobian.scdiag)
143145
initial_guess = InitialGuess(target, exprs, arrays, self.time_mapper)
144146

145-
constrain_bc = (
146-
ConstrainBC(target, exprs, arrays)
147-
if self.constrain_bcs
148-
else None
149-
)
147+
constrain_exprs = self._get_constrain_exprs(exprs)
148+
constrain_bc = ConstrainBC(target, constrain_exprs, arrays) if constrain_exprs else None
150149

151150
field_data = FieldData(
152151
target=target,
@@ -159,6 +158,16 @@ def linear_solve_args(self):
159158

160159
return target, funcs, field_data
161160

161+
def _get_constrain_exprs(self, exprs):
162+
"""
163+
Return the subset of `exprs` to be constrained via PetscSection.
164+
If `constrain_bcs=True`, all `EssentialBC` exprs are returned.
165+
Otherwise, only those individually marked with ``constrain=True``.
166+
"""
167+
if self.constrain_bcs:
168+
return tuple(e for e in exprs if isinstance(e, EssentialBC))
169+
return tuple(e for e in exprs if isinstance(e, EssentialBC) and e.constrain)
170+
162171
def generate_arrays(self, *targets):
163172
return {
164173
t: {
@@ -197,9 +206,11 @@ def linear_solve_args(self):
197206
)
198207

199208
constrain_bc = {
200-
t: ConstrainBC(t, as_tuple(self.target_exprs[t]), arrays)
209+
t: ConstrainBC(t, self._get_constrain_exprs(as_tuple(self.target_exprs[t])), arrays)
201210
for t in targets
202-
} if self.constrain_bcs else None
211+
if self._get_constrain_exprs(as_tuple(self.target_exprs[t]))
212+
}
213+
constrain_bc = constrain_bc if constrain_bc else None
203214

204215
all_data = MultipleFieldData(
205216
targets=targets,

devito/petsc/types/array.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,12 @@ def __dtype_setup__(cls, **kwargs):
5151
@classmethod
5252
def __indices_setup__(cls, *args, **kwargs):
5353
target = kwargs['target']
54-
dimensions = tuple(target.indices[d] for d in target.space_dimensions)
54+
dimensions = tuple(target.space_dimensions)
55+
target_indices = tuple(target.indices[d] for d in target.space_dimensions)
5556
if args:
5657
indices = args
5758
else:
58-
indices = dimensions
59+
indices = target_indices
5960
return as_tuple(dimensions), as_tuple(indices)
6061

6162
def __halo_setup__(self, **kwargs):
@@ -123,6 +124,19 @@ def symbolic_shape(self):
123124
# Reverse it since DMDA is setup backwards to Devito dimensions.
124125
return DimensionTuple(*field_from_composites[::-1], getters=self.dimensions)
125126

127+
# TODO: Is this necessary? Taken directly from `Function`.
128+
def _eval_at(self, func):
129+
if self.staggered == func.staggered:
130+
return self
131+
mapper = {}
132+
for d in self.dimensions:
133+
try:
134+
if self.indices_ref[d] is not func.indices_ref[d]:
135+
mapper[self.indices_ref[d]] = func.indices_ref[d]
136+
except KeyError:
137+
pass
138+
return self.subs(mapper)
139+
126140

127141
class PetscBundle(Bundle):
128142
"""

devito/petsc/types/equation.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,40 @@ class EssentialBC(Eq):
99
Represents an essential boundary condition for use with `petscsolve`.
1010
1111
The compiler will automatically zero the corresponding rows/columns in the Jacobian
12-
and lift the boundary terms into the residual RHS, unless the user
13-
specifies `constrain_bcs=True` to `petscsolve`.
12+
and lift the boundary terms into the residual RHS, unless the BC is constrained
13+
via a `PetscSection`. Constraining can be enabled in two ways:
14+
15+
- Globally: pass `constrain_bcs=True` to `petscsolve` to constrain all
16+
`EssentialBC`s in the solve.
17+
- Individually: pass `constrain=True` to a specific `EssentialBC` constructor,
18+
e.g. ``EssentialBC(lhs, rhs, subdomain=..., constrain=True)``.
1419
1520
Note:
1621
- To define an essential boundary condition, use:
1722
Eq(target, boundary_value, subdomain=...),
1823
where `target` is the Function-like object passed to `petscsolve`.
1924
- SubDomains used for multiple `EssentialBC`s must not overlap.
2025
"""
21-
__rkwargs__ = Eq.__rkwargs__ + ("target",)
26+
__rkwargs__ = Eq.__rkwargs__ + ("target", "constrain")
2227

23-
def __new__(cls, *args, target=None, **kwargs):
28+
def __new__(cls, *args, target=None, constrain=False, **kwargs):
2429
obj = super().__new__(cls, *args, **kwargs)
2530

2631
if target is None:
2732
target = obj.lhs.function
2833

2934
obj._target = target
35+
obj._constrain = constrain
3036
return obj
3137

3238
@property
3339
def target(self):
3440
return self._target
3541

42+
@property
43+
def constrain(self):
44+
return self._constrain
45+
3646

3747
class ZeroRow(EssentialBC):
3848
"""

tests/test_petsc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,9 +1481,9 @@ class TestMPI:
14811481
def test_laplacian_1d(self, nx, unorm, mode):
14821482
"""
14831483
"""
1484-
configuration['compiler'] = 'custom'
1485-
os.environ['CC'] = 'mpicc'
1486-
PetscInitialize()
1484+
# configuration['compiler'] = 'custom'
1485+
# os.environ['CC'] = 'mpicc'
1486+
# PetscInitialize()
14871487

14881488
class SubSide(SubDomain):
14891489
def __init__(self, side='left', grid=None):

0 commit comments

Comments
 (0)