Skip to content

Commit 0885857

Browse files
committed
fix halospots from appearing in petsc callbacks etc
1 parent bc5629a commit 0885857

1 file changed

Lines changed: 63 additions & 4 deletions

File tree

devito/petsc/iet/passes.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from devito.passes.iet.engine import iet_pass
66
from devito.ir.iet import (
77
Transformer, MapNodes, Iteration, CallableBody, List, Call, FindNodes, Section,
8-
FindSymbols, DummyExpr, Uxreplace, Dereference
8+
FindSymbols, DummyExpr, Uxreplace, Dereference, HaloSpot
99
)
1010
from devito.symbolics import Byref, Macro, Null, FieldFromPointer
1111
from devito.types.basic import DataSymbol, LocalType
12+
from devito.types.dimension import DefaultDimension
1213
from devito.types.misc import FIndexed
1314
import devito.logger
1415
from devito.passes.iet.linearization import linearize_accesses, Tracker
@@ -18,7 +19,7 @@
1819
CallbackUserStruct
1920
)
2021
from devito.petsc.types.macros import petsc_func_begin_user
21-
from devito.petsc.iet.nodes import PetscMetaData, petsc_call
22+
from devito.petsc.iet.nodes import PetscMetaData, petsc_call, PETScCallable
2223
from devito.petsc.config import core_metadata, petsc_languages
2324
from devito.petsc.iet.callbacks import (
2425
BaseCallbackBuilder, CoupledCallbackBuilder, populate_matrix_context,
@@ -68,7 +69,20 @@ def lower_petsc(iet, **kwargs):
6869
"but multiple `Grid`s were found."
6970
)
7071
grid = unique_grids.pop()
71-
devito_mpi = kwargs['options'].get('mpi', False)
72+
73+
# Protect PETSc solve targets from being dropped by `_drop_if_unwritten`.
74+
# `lower_petsc` runs before `mpiize`, replacing `PetscMetaData` (an
75+
# `Expression` subclass whose `.write` reveals the target function) with
76+
# `Call` nodes to run the PETSc solver. Once that happens, `_drop_if_unwritten` can no
77+
# longer see the target as written and incorrectly discards its `HaloSpot`. So we
78+
# compose `dist-drop-unwritten` with a guard that always returns
79+
# False for PETSc targets.
80+
options = kwargs['options']
81+
petsc_targets = {n.write for n in data if n.write is not None}
82+
if petsc_targets:
83+
options['dist-drop-unwritten'] = lambda f: f not in petsc_targets
84+
85+
devito_mpi = options.get('mpi', False)
7286
comm = grid.distributor._obj_comm if devito_mpi else 'PETSC_COMM_WORLD'
7387

7488
# Create core PETSc calls (not specific to each `petscsolve`)
@@ -112,14 +126,54 @@ def lower_petsc(iet, **kwargs):
112126
),))
113127

114128
populate_matrix_context(efuncs)
129+
130+
# Strip HaloSpots from PETSc callback efuncs before returning them.
131+
# The callbacks are built via rcompile(..., mpi=False), so HaloSpots
132+
# survive in their IETs but are NOT converted to haloupdate calls there.
133+
# When the main mpiize pass (mpi=True) later processes these callbacks,
134+
# it would convert those HaloSpots into haloupdate calls — which is wrong,
135+
# since halo exchanges must only happen in the main kernel. Strip them here
136+
# before they reach mpiize.
137+
for name, efunc in list(efuncs.items()):
138+
if isinstance(efunc, PETScCallable):
139+
halos = FindNodes(HaloSpot).visit(efunc)
140+
if halos:
141+
mapper = {hs: hs.body for hs in halos}
142+
efuncs[name] = Transformer(mapper).visit(efunc)
143+
115144
iet = Transformer(subs).visit(iet)
116145
body = core + tuple(setup) + iet.body.body + tuple(clear_options)
146+
# from IPython import embed; embed()
117147
body = iet.body._rebuild(body=body)
118148
iet = iet._rebuild(body=body)
149+
# from IPython import embed; embed()
119150
metadata = {**core_metadata(), 'efuncs': tuple(efuncs.values())}
120151
return iet, metadata
121152

122153

154+
@iet_pass
155+
def strip_petsc_callback_halos(iet, **kwargs):
156+
"""
157+
Remove any HaloSpot nodes that `mpiize` may have injected into PETSc
158+
callback functions (FormFunction, SetPointBCs, FormRHS, etc.).
159+
160+
HaloSpots should only appear in the main kernel, never inside PETSc
161+
callbacks which run as part of the PETSc solver internals. All
162+
PETSc callbacks are instances of `PETScCallable`; the main kernel is
163+
not, so we use that to distinguish the two.
164+
"""
165+
if not isinstance(iet, PETScCallable):
166+
return iet, {}
167+
168+
halos = FindNodes(HaloSpot).visit(iet)
169+
if not halos:
170+
return iet, {}
171+
172+
# Replace each HaloSpot with its body (unwrap it)
173+
mapper = {hs: hs.body for hs in halos}
174+
return Transformer(mapper).visit(iet), {}
175+
176+
123177
def lower_petsc_symbols(iet, **kwargs):
124178
"""
125179
The `place_definitions` and `place_casts` passes may introduce new
@@ -134,7 +188,7 @@ def lower_petsc_symbols(iet, **kwargs):
134188
# Rebuild `MainUserStruct` and update iet accordingly
135189
rebuild_parent_user_struct(iet, mapper=callback_struct_mapper)
136190

137-
iet = linear_indices(iet, **kwargs)
191+
linear_indices(iet, **kwargs)
138192

139193

140194
@iet_pass
@@ -151,9 +205,14 @@ def linear_indices(iet, **kwargs):
151205

152206
tracker = Tracker('basic', dtype, kwargs['sregistry'])
153207

208+
# Exclude SubDomainSet backing functions from linearization: they must
209+
# remain as 2D array reads (border[n0][col]), not flat-indexed via a macro.
210+
# SubDomainSet subfunctions are identified by having a DefaultDimension
211+
# (sds_dim) among their dimensions.
154212
indexeds = [
155213
i for i in FindSymbols('indexeds').visit(iet)
156214
if not isinstance(i.function, LocalType)
215+
and not any(isinstance(d, DefaultDimension) for d in i.function.dimensions)
157216
]
158217
candidates = {i.function.name for i in indexeds}
159218
key = lambda f: f.name in candidates

0 commit comments

Comments
 (0)