55from devito .passes .iet .engine import iet_pass
66from devito .ir .iet import (
77 Transformer , MapNodes , Iteration , CallableBody , List , Call , FindNodes , Section ,
8- FindSymbols , DummyExpr , Uxreplace , Dereference
8+ FindSymbols , DummyExpr , Uxreplace , Dereference , HaloSpot
99)
1010from devito .symbolics import Byref , Macro , Null , FieldFromPointer
1111from devito .types .basic import DataSymbol , LocalType
12+ from devito .types .dimension import DefaultDimension
1213from devito .types .misc import FIndexed
1314import devito .logger
1415from devito .passes .iet .linearization import linearize_accesses , Tracker
1819 CallbackUserStruct
1920)
2021from devito .petsc .types .macros import petsc_func_begin_user
21- from devito .petsc .iet .nodes import PetscMetaData , petsc_call
22+ from devito .petsc .iet .nodes import PetscMetaData , petsc_call , PETScCallable
2223from devito .petsc .config import core_metadata , petsc_languages
2324from devito .petsc .iet .callbacks import (
2425 BaseCallbackBuilder , CoupledCallbackBuilder , populate_matrix_context ,
@@ -68,7 +69,20 @@ def lower_petsc(iet, **kwargs):
6869 "but multiple `Grid`s were found."
6970 )
7071 grid = unique_grids .pop ()
71- devito_mpi = kwargs ['options' ].get ('mpi' , False )
72+
73+ # Keep the `HaloSpot` of each PETSc solve target from being dropped by
74+ # `_drop_if_unwritten`. `lower_petsc` runs before `mpiize`, replacing
75+ # `PetscMetaData` (an `Expression` subclass whose `.write` reveals the target
76+ # function) with `Call` nodes that run the PETSc solver. Once that happens,
77+ # `_drop_if_unwritten` can no longer see the target as written and incorrectly
78+ # discards its `HaloSpot`. So we set `dist-drop-unwritten` to a guard that returns
79+ # False for PETSc targets and True otherwise (this overwrites any previous value).
80+ options = kwargs ['options' ]
81+ petsc_targets = {n .write for n in data if n .write is not None }
82+ if petsc_targets :
83+ options ['dist-drop-unwritten' ] = lambda f : f not in petsc_targets
84+
85+ devito_mpi = options .get ('mpi' , False )
7286 comm = grid .distributor ._obj_comm if devito_mpi else 'PETSC_COMM_WORLD'
7387
7488 # Create core PETSc calls (not specific to each `petscsolve`)
@@ -112,14 +126,54 @@ def lower_petsc(iet, **kwargs):
112126 ),))
113127
114128 populate_matrix_context (efuncs )
129+
130+ # Strip HaloSpots from PETSc callback efuncs before returning them.
131+ # The callbacks are built via rcompile(..., mpi=False), so HaloSpots
132+ # survive in their IETs but are NOT converted to haloupdate calls there.
133+ # When the main mpiize pass (mpi=True) later processes these callbacks,
134+ # it would convert those HaloSpots into haloupdate calls — which is wrong,
135+ # since halo exchanges must only happen in the main kernel. Strip them here
136+ # before they reach mpiize.
137+ for name , efunc in list (efuncs .items ()):
138+ if isinstance (efunc , PETScCallable ):
139+ halos = FindNodes (HaloSpot ).visit (efunc )
140+ if halos :
141+ mapper = {hs : hs .body for hs in halos }
142+ efuncs [name ] = Transformer (mapper ).visit (efunc )
143+
115144 iet = Transformer (subs ).visit (iet )
116145 body = core + tuple (setup ) + iet .body .body + tuple (clear_options )
146+ # from IPython import embed; embed()
117147 body = iet .body ._rebuild (body = body )
118148 iet = iet ._rebuild (body = body )
149+ # from IPython import embed; embed()
119150 metadata = {** core_metadata (), 'efuncs' : tuple (efuncs .values ())}
120151 return iet , metadata
121152
122153
@iet_pass
def strip_petsc_callback_halos(iet, **kwargs):
    """
    Unwrap every `HaloSpot` found inside a PETSc callback.

    PETSc callbacks (FormFunction, SetPointBCs, FormRHS, etc.) execute as
    part of the PETSc solver internals, so halo exchanges must not appear
    in them; only the main kernel may carry `HaloSpot` nodes. Callbacks are
    recognised by being instances of `PETScCallable`, which the main kernel
    is not.

    Parameters
    ----------
    iet : Node
        The IET handed over by the pass engine.
    **kwargs
        Accepted for pass-signature compatibility; not used here.

    Returns
    -------
    tuple
        `(iet, {})`. The IET is returned as-is when it is not a
        `PETScCallable` or contains no `HaloSpot`; otherwise it is the
        transformed IET in which each `HaloSpot` is replaced by its body.
    """
    stripped = iet

    if isinstance(iet, PETScCallable):
        spots = FindNodes(HaloSpot).visit(iet)
        if spots:
            # Map each HaloSpot onto the body it wraps, so that the
            # Transformer substitutes the wrapper with its content
            unwrap = {}
            for spot in spots:
                unwrap[spot] = spot.body
            stripped = Transformer(unwrap).visit(iet)

    return stripped, {}
175+
176+
123177def lower_petsc_symbols (iet , ** kwargs ):
124178 """
125179 The `place_definitions` and `place_casts` passes may introduce new
@@ -134,7 +188,7 @@ def lower_petsc_symbols(iet, **kwargs):
134188 # Rebuild `MainUserStruct` and update iet accordingly
135189 rebuild_parent_user_struct (iet , mapper = callback_struct_mapper )
136190
137- iet = linear_indices (iet , ** kwargs )
191+ linear_indices (iet , ** kwargs )
138192
139193
140194@iet_pass
@@ -151,9 +205,14 @@ def linear_indices(iet, **kwargs):
151205
152206 tracker = Tracker ('basic' , dtype , kwargs ['sregistry' ])
153207
208+ # Exclude SubDomainSet backing functions from linearization: they must
209+ # remain as 2D array reads (border[n0][col]), not flat-indexed via a macro.
210+ # SubDomainSet subfunctions are identified by having a DefaultDimension
211+ # (sds_dim) among their dimensions.
154212 indexeds = [
155213 i for i in FindSymbols ('indexeds' ).visit (iet )
156214 if not isinstance (i .function , LocalType )
215+ and not any (isinstance (d , DefaultDimension ) for d in i .function .dimensions )
157216 ]
158217 candidates = {i .function .name for i in indexeds }
159218 key = lambda f : f .name in candidates
0 commit comments