Skip to content

Commit ad3c629

Browse files
committed
mpi: make data behave as serial with a single rank
1 parent 217a3af commit ad3c629

3 files changed

Lines changed: 10 additions & 6 deletions

File tree

devito/core/operator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ def _specialize_clusters(cls, clusters, **kwargs):
318318

319319
# Fetch passes to be called
320320
passes_mapper = cls._make_clusters_passes_mapper(**kwargs)
321+
print(passes)
321322

322323
# Call passes
323324
for i in passes:

devito/data/data.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def _local(self):
154154

155155
def _global(self, glb_idx, decomposition):
156156
"""A "global" view of ``self`` over a given Decomposition."""
157-
if self._is_distributed:
157+
if self._is_mpi_distributed:
158158
raise ValueError("Cannot derive a decomposed view from a decomposed Data")
159159
if len(decomposition) != self.ndim:
160160
raise ValueError("`decomposition` should have ndim=%d entries" % self.ndim)
@@ -197,7 +197,10 @@ def wrapper(data, *args, **kwargs):
197197

198198
@property
199199
def _is_mpi_distributed(self):
200-
return self._is_distributed and configuration['mpi']
200+
is_mpi = self._is_distributed and configuration['mpi']
201+
if is_mpi:
202+
is_mpi = is_mpi and self._distributor.comm.size > 2
203+
return is_mpi
201204

202205
def __repr__(self):
203206
return super(Data, self._local).__repr__()
@@ -341,7 +344,7 @@ def __setitem__(self, glb_idx, val, comm_type):
341344
super().__setitem__(loc_idx, val)
342345
else:
343346
super().__setitem__(glb_idx, val)
344-
elif isinstance(val, Data) and val._is_distributed:
347+
elif isinstance(val, Data) and val._is_mpi_distributed:
345348
if comm_type is index_by_index:
346349
glb_idx, val = self._process_args(glb_idx, val)
347350
val_idx = as_tuple([slice(i.glb_min, i.glb_max+1, 1) for
@@ -361,14 +364,14 @@ def __setitem__(self, glb_idx, val, comm_type):
361364
or data_global[j].size == 0
362365
if not skip:
363366
self.__setitem__(idx_global[j], data_global[j])
364-
elif self._is_distributed:
367+
elif self._is_mpi_distributed:
365368
# `val` is decomposed, `self` is decomposed -> local set
366369
super().__setitem__(glb_idx, val)
367370
else:
368371
# `val` is decomposed, `self` is replicated -> gatherall-like
369372
raise NotImplementedError
370373
elif isinstance(val, np.ndarray):
371-
if self._is_distributed:
374+
if self._is_mpi_distributed:
372375
# `val` is replicated, `self` is decomposed -> `val` gets decomposed
373376
glb_idx = self._normalize_index(glb_idx)
374377
glb_idx, val = self._process_args(glb_idx, val)

devito/finite_differences/differentiable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,7 @@ def _diff2sympy(obj):
10881088
# Interpolation for finite differences
10891089
@singledispatch
10901090
def interp_for_fd(expr, x0, **kwargs):
1091-
return expr
1091+
return expr._evaluate(expand=True)
10921092

10931093

10941094
@interp_for_fd.register(sympy.Derivative)

0 commit comments

Comments
 (0)