Skip to content

Commit 5bf762a

Browse files
committed
compiler: refine distributed dimension check for smarter halotouch
1 parent b4d0594 commit 5bf762a

5 files changed

Lines changed: 22 additions & 13 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -453,7 +453,7 @@ def callback(self, clusters, prefix, seen=None):
453453
not d._defines & hs.distributed_aindices:
454454
continue
455455

456-
if any(halo_write(ci, hs) for ci in clusters[:i]):
456+
if any(halo_write(ci, hs, prefix) for ci in clusters[:i]):
457457
# If there's a halo write before `c`, then we cannot inject the HaloTouch
458458
continue
459459

@@ -472,14 +472,17 @@ def callback(self, clusters, prefix, seen=None):
472472
# the args is important because that's what search functions honor!
473473
points = sorted(points, key=str)
474474

475-
# Construct the HaloTouch Cluster
476-
expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))
477-
478475
key0 = lambda i: i in prefix[:-1] or i in hs.loc_indices # noqa: B023
479476
key1 = lambda i: i not in hs.distributed_defined # noqa: B023
480477
key = lambda i: key0(i) and key1(i) # noqa: B023
481478
ispace = c.ispace.project(key)
482479

480+
# Reconstruct the HaloScheme with the new IterationSpace
481+
hs = HaloScheme(c.exprs, ispace)
482+
483+
# Construct the HaloTouch Cluster
484+
expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))
485+
483486
properties = c.properties.sequentialize()
484487

485488
halo_touch = c.rebuild(exprs=expr, ispace=ispace, properties=properties)
@@ -787,12 +790,15 @@ def normalize_reductions_sparse(cluster, sregistry):
787790
return cluster.rebuild(processed)
788791

789792

790-
def halo_write(c, hs):
793+
def halo_write(c, hs, prefix):
791794
loc_vals = hs.loc_values
795+
hsdims = hs.dimensions & set(prefix.itdims)
792796

793797
for f in hs.fmapper:
794798
for a in c.scope.getwrites(f):
795-
if set(a.access.indices) & loc_vals:
799+
is_write = set(a.access.indices) & loc_vals
800+
is_dist = any(c.grid.is_distributed(d) for d in hsdims)
801+
if is_write and is_dist:
796802
return True
797803

798804
return False

devito/mpi/distributed.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -261,7 +261,7 @@ def nprocs_local(self):
261261

262262
@property
263263
def topology(self):
264-
return self._topology
264+
return DimensionTuple(*self._topology, getters=self.dimensions)
265265

266266
@property
267267
def topology_logical(self):

devito/mpi/halo_scheme.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -619,6 +619,10 @@ def classify(exprs, ispace):
619619
dims.add(ai.parent)
620620
else:
621621
dims.add(ai)
622+
elif ai is not None and not ai._defines & set(f.dimensions):
623+
# Indirect dimension access, e.g., `u[t, rrecx]`.
624+
# Cannot assume it isn't distributed.
625+
dims.add(ai)
622626

623627
if not halo_labels:
624628
continue
@@ -642,12 +646,11 @@ def classify(exprs, ispace):
642646
f"scheme for `{f}` along Dimension `{d}`")
643647
elif hl.pop() is STENCIL:
644648
halos.append(Halo(d, s))
645-
else:
649+
elif d._defines & set(ispace.itdims):
646650
raw_loc_indices[d].append(s)
647651

648652
loc_indices, loc_dirs = process_loc_indices(raw_loc_indices,
649653
ispace.directions)
650-
651654
mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims)
652655

653656
return mapper

devito/types/grid.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -358,7 +358,7 @@ def is_distributed(self, dim):
358358
True if `dim` is a distributed Dimension for this CartesianDiscretization,
359359
False otherwise.
360360
"""
361-
return any(dim is d for d in self.distributor.dimensions)
361+
return self.distributor.topology.get(dim, 1) > 1
362362

363363
@cached_property
364364
def _arg_names(self):
@@ -530,7 +530,7 @@ def is_distributed(self, dim):
530530
False otherwise.
531531
"""
532532
if self.grid:
533-
return any(dim is d for d in self.distributor.dimensions)
533+
return self.distributor.topology.get(dim, 1) > 1
534534
return False
535535

536536
@property

tests/test_mpi.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3156,8 +3156,8 @@ def test_fission_due_to_antidep(self, mode):
31563156
# First, check the generated code
31573157
assert_structure(op1, ['t',
31583158
't,x0_blk0,y0_blk0,x,y,z',
3159-
't,x1_blk0,y1_blk0,x,y,z'],
3160-
'tx0_blk0y0_blk0xyzx1_blk0y1_blk0xyz')
3159+
't,x0_blk0,y0_blk0,x,y,z'],
3160+
'tx0_blk0y0_blk0xyzz')
31613161

31623162
def init(f, v=1):
31633163
f.data[:] = np.indices(grid.shape).sum(axis=0) % (.004*v) + .01

0 commit comments

Comments
 (0)