Skip to content

Commit 68326da

Browse files
ggorman and claude committed
types: Fix SubDomain pickling for Dask compatibility
Fixes a circular-reference pickling bug in which Grid and SubDomain objects reference each other, causing failures in Dask workflows.

Changes:
- Implement lazy SubDistributor initialization via property getter
- Restore grid references in Grid.__setstate__ for legacy API
- Update examples/seismic/model.py to use new SubDomain API
- Add regression tests for SubDomain pickling (new and legacy API)
- Simplify redundant check in test_interpolate_subdomain_mpi_mfe

The issue occurred because SubDomain.__setstate__ tried to create a SubDistributor before Grid was fully unpickled, causing AttributeError. Now the distributor is created lazily on first access, after both objects are fully restored.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 3c2218a commit 68326da

4 files changed

Lines changed: 93 additions & 30 deletions

File tree

devito/types/grid.py

Lines changed: 23 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -402,6 +402,10 @@ def __setstate__(self, state):
402402
for k, v in state.items():
403403
setattr(self, k, v)
404404
self._distributor = Distributor(self.shape, self.dimensions, MPI.COMM_SELF)
405+
# Restore grid references in subdomains (needed for legacy API)
406+
for sd in self._subdomains:
407+
if hasattr(sd, '_grid'):
408+
sd._grid = self
405409

406410

407411
class AbstractSubDomain(CartesianDiscretization):
@@ -510,7 +514,22 @@ def stepping_dim(self):
510514

511515
@property
512516
def distributor(self):
513-
"""The Distributor used for MPI-decomposing the CartesianDiscretization."""
517+
"""
518+
The Distributor used for MPI-decomposing the SubDomain.
519+
520+
Lazy-initialized to handle circular reference pickling issues.
521+
"""
522+
if self._distributor is None and self._grid is not None:
523+
# Lazy initialization after unpickling or late binding
524+
# Check Grid is fully initialized before creating SubDistributor
525+
if (hasattr(self._grid, '_dimensions') and
526+
hasattr(self._grid, 'distributor') and
527+
self._grid.distributor is not None):
528+
try:
529+
self._distributor = SubDistributor(self)
530+
except AttributeError:
531+
# Grid still not ready, return None for now
532+
pass
514533
return self._distributor
515534

516535
def is_distributed(self, dim):
@@ -681,8 +700,9 @@ def __getstate__(self):
681700
def __setstate__(self, state):
682701
for k, v in state.items():
683702
setattr(self, k, v)
684-
if self.grid:
685-
self._distributor = SubDistributor(self)
703+
# Don't create distributor here - will be lazy-initialized on first access
704+
# to avoid race condition when Grid hasn't been fully unpickled yet
705+
self._distributor = None
686706

687707

688708
class MultiSubDomain(AbstractSubDomain):

examples/seismic/model.py

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -59,10 +59,10 @@ class PhysicalDomain(SubDomain):
5959

6060
name = 'physdomain'
6161

62-
def __init__(self, so, fs=False):
63-
super().__init__()
62+
def __init__(self, so, fs=False, grid=None):
6463
self.so = so
6564
self.fs = fs
65+
super().__init__(grid=grid)
6666

6767
def define(self, dimensions):
6868
map_d = {d: d for d in dimensions}
@@ -75,9 +75,9 @@ class FSDomain(SubDomain):
7575

7676
name = 'fsdomain'
7777

78-
def __init__(self, so):
79-
super().__init__()
78+
def __init__(self, so, grid=None):
8079
self.size = so
80+
super().__init__(grid=grid)
8181

8282
def define(self, dimensions):
8383
"""
@@ -104,12 +104,8 @@ def __init__(self, origin, spacing, shape, space_order, nbl=20,
104104
origin_pml = [dtype(o - s*nbl) for o, s in zip(origin, spacing)]
105105
shape_pml = np.array(shape) + 2 * self.nbl
106106

107-
# Model size depending on freesurface
108-
physdomain = PhysicalDomain(space_order, fs=fs)
109-
subdomains = subdomains + (physdomain,)
107+
# Free surface adjustments
110108
if fs:
111-
fsdomain = FSDomain(space_order)
112-
subdomains = subdomains + (fsdomain,)
113109
origin_pml[-1] = origin[-1]
114110
shape_pml[-1] -= self.nbl
115111

@@ -118,8 +114,19 @@ def __init__(self, origin, spacing, shape, space_order, nbl=20,
118114
if grid is None:
119115
# Physical extent is calculated per cell, so shape - 1
120116
extent = tuple(np.array(spacing) * (shape_pml - 1))
117+
# Create grid first (new API - no subdomains parameter)
121118
self.grid = Grid(extent=extent, shape=shape_pml, origin=origin_pml,
122-
dtype=dtype, subdomains=subdomains, topology=topology)
119+
dtype=dtype, topology=topology)
120+
121+
# Create subdomains with grid parameter (new API)
122+
physdomain = PhysicalDomain(space_order, fs=fs, grid=self.grid)
123+
all_subdomains = list(subdomains) + [physdomain]
124+
if fs:
125+
fsdomain = FSDomain(space_order, grid=self.grid)
126+
all_subdomains.append(fsdomain)
127+
128+
# Add subdomains to grid
129+
self.grid._subdomains = self.grid._subdomains + tuple(all_subdomains)
123130
else:
124131
self.grid = grid
125132

tests/test_interpolation.py

Lines changed: 5 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1150,23 +1150,11 @@ def test_interpolate_subdomain_mpi_mfe(self, mode):
11501150
# when interpolating from points outside the subdomain
11511151
rank = grid.distributor.myrank
11521152

1153-
# Check all values on this rank are finite
1154-
if len(sr0.data) > 0:
1155-
print(f"Rank {rank} has {len(sr0.data)} values: {sr0.data}")
1156-
# Convert to numpy array to check for non-finite values
1157-
vals = np.asarray(sr0.data)
1158-
if not np.all(np.isfinite(vals)):
1159-
bad_indices = np.where(~np.isfinite(vals))[0]
1160-
raise AssertionError(
1161-
f"BUG REPRODUCED! Rank {rank}, points {bad_indices}: "
1162-
f"non-finite values {vals[bad_indices]} detected. "
1163-
f"Full data: {vals}. "
1164-
f"This occurs when interpolating from a Function on a SubDomain "
1165-
f"for coordinates outside the subdomain."
1166-
)
1167-
# Verify all values are reasonable (not huge garbage)
1168-
assert np.all(np.abs(vals) < 1000), \
1169-
f"Rank {rank}: Suspicious large values: {vals}"
1153+
# Check all values are finite and reasonable (not huge garbage)
1154+
assert np.all(np.isfinite(sr0.data)), \
1155+
f"Rank {rank}: sr0 has non-finite values: {sr0.data}"
1156+
assert np.all(np.abs(sr0.data) < 1000), \
1157+
f"Rank {rank}: Suspicious large values: {sr0.data}"
11701158

11711159
@pytest.mark.parallel(mode=4)
11721160
def test_interpolate_multiple_subdomains_unique_temps(self, mode):

tests/test_pickle.py

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -127,6 +127,54 @@ def test_function(self, pickle, on_sd):
127127
assert f.dtype == new_f.dtype
128128
assert f.shape == new_f.shape
129129

130+
@pytest.mark.parametrize('use_new_api', [False, True])
131+
def test_grid_with_subdomain(self, pickle, use_new_api):
132+
"""
133+
Test pickling Grid with SubDomains (both old and new API).
134+
135+
Regression test for circular reference pickling bug where Grid
136+
contains SubDomains that reference back to the Grid.
137+
"""
138+
if use_new_api:
139+
# New API: create SubDomain with grid parameter
140+
grid = Grid(shape=(10, 10, 10))
141+
sd = SD(grid=grid)
142+
grid._subdomains = grid._subdomains + (sd,)
143+
else:
144+
# Old API: pass subdomains to Grid constructor
145+
sd = SD()
146+
grid = Grid(shape=(10, 10, 10), subdomains=(sd,))
147+
148+
# Pickle the grid
149+
pkl_grid = pickle.dumps(grid)
150+
new_grid = pickle.loads(pkl_grid)
151+
152+
# Verify subdomains were pickled correctly
153+
assert len(new_grid._subdomains) == len(grid._subdomains)
154+
155+
# Verify subdomain distributor is accessible (lazy-initialized)
156+
new_sd = [s for s in new_grid._subdomains if s.name == 'sd'][0]
157+
assert new_sd.distributor is not None
158+
assert new_sd.grid is new_grid
159+
160+
def test_subdomain_with_grid_circular_ref(self, pickle):
161+
"""
162+
Test pickling SubDomain with grid reference (circular reference).
163+
164+
This tests the lazy distributor initialization fix for unpickling
165+
SubDomains that have circular references to their Grid.
166+
"""
167+
grid = Grid(shape=(10, 10, 10))
168+
sd = SD(grid=grid)
169+
170+
# Pickle the subdomain (which references the grid)
171+
pkl_sd = pickle.dumps(sd)
172+
new_sd = pickle.loads(pkl_sd)
173+
174+
# Verify grid reference and distributor
175+
assert new_sd.grid is not None
176+
assert new_sd.distributor is not None
177+
130178
@pytest.mark.parametrize('interp', ['linear', 'sinc'])
131179
def test_sparse_function(self, pickle, interp):
132180
grid = Grid(shape=(3,))

0 commit comments

Comments
 (0)