Skip to content

Commit 3a8c97a

Browse files
committed
compiler: refine distributed dimension check for smarter halotouch
1 parent b4d0594 commit 3a8c97a

4 files changed

Lines changed: 19 additions & 12 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -472,14 +472,14 @@ def callback(self, clusters, prefix, seen=None):
472472
# the args is important because that's what search functions honor!
473473
points = sorted(points, key=str)
474474

475-
# Construct the HaloTouch Cluster
476-
expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))
477-
478475
key0 = lambda i: i in prefix[:-1] or i in hs.loc_indices # noqa: B023
479476
key1 = lambda i: i not in hs.distributed_defined # noqa: B023
480477
key = lambda i: key0(i) and key1(i) # noqa: B023
481478
ispace = c.ispace.project(key)
482479

480+
# Construct the HaloTouch Cluster
481+
expr = Eq(self.B, HaloTouch(*points, halo_scheme=hs))
482+
483483
properties = c.properties.sequentialize()
484484

485485
halo_touch = c.rebuild(exprs=expr, ispace=ispace, properties=properties)
@@ -792,7 +792,10 @@ def halo_write(c, hs):
792792

793793
for f in hs.fmapper:
794794
for a in c.scope.getwrites(f):
795-
if set(a.access.indices) & loc_vals:
795+
is_write = set(a.access.indices) & loc_vals
796+
is_dist = any(c.grid.distributor.topology.get(d, 1) > 1
797+
for d in hs.dimensions)
798+
if is_write and is_dist:
796799
return True
797800

798801
return False

devito/mpi/distributed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def nprocs_local(self):
261261

262262
@property
263263
def topology(self):
264-
return self._topology
264+
return DimensionTuple(*self._topology, getters=self.dimensions)
265265

266266
@property
267267
def topology_logical(self):

devito/mpi/halo_scheme.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,10 @@ def classify(exprs, ispace):
619619
dims.add(ai.parent)
620620
else:
621621
dims.add(ai)
622+
elif ai is not None and not ai._defines & set(f.dimensions):
623+
# Indirect dimension access, e.g., `u[t, rrecx]`.
624+
# Cannot assume it isn't distributed.
625+
dims.add(ai)
622626

623627
if not halo_labels:
624628
continue
@@ -642,12 +646,11 @@ def classify(exprs, ispace):
642646
f"scheme for `{f}` along Dimension `{d}`")
643647
elif hl.pop() is STENCIL:
644648
halos.append(Halo(d, s))
645-
else:
649+
elif d._defines & set(ispace.itdims):
646650
raw_loc_indices[d].append(s)
647651

648652
loc_indices, loc_dirs = process_loc_indices(raw_loc_indices,
649653
ispace.directions)
650-
651654
mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims)
652655

653656
return mapper

tests/test_mpi.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import numpy as np
44
import pytest
5-
from test_dse import TestTTI
65

76
from conftest import _R, assert_blocking, assert_structure, body0
87
from devito import (
@@ -2216,6 +2215,8 @@ def test_halo_inner_dim(self, mode):
22162215
eq = [Eq(v.forward, v + 1), Eq(e, v.forward.dydz)]
22172216

22182217
op = Operator(eq, opt=('advanced', {'blocklevels': 0}))
2218+
if grid.distributor.comm.rank == 0:
2219+
print(op)
22192220
assert_structure(op, ['txyz', 't', 'txyz', 'txyz'], 'txyzxyzz')
22202221
op(time=100)
22212222

@@ -2754,7 +2755,7 @@ def test_haloupdate_same_timestep_v2(self, mode):
27542755
assert titer.dim is grid.time_dim
27552756
assert titer.nodes[0].body[0].body[0].is_List
27562757
assert len(titer.nodes[0].body[0].body[0].body[0].body) == 1
2757-
assert not titer.nodes[0].body[0].body[0].body[0].body[0].is_Call
2758+
assert titer.nodes[0].body[0].body[0].body[0].body[0].is_Call
27582759

27592760
op.apply(time=0)
27602761

@@ -3156,8 +3157,8 @@ def test_fission_due_to_antidep(self, mode):
31563157
# First, check the generated code
31573158
assert_structure(op1, ['t',
31583159
't,x0_blk0,y0_blk0,x,y,z',
3159-
't,x1_blk0,y1_blk0,x,y,z'],
3160-
'tx0_blk0y0_blk0xyzx1_blk0y1_blk0xyz')
3160+
't,x0_blk0,y0_blk0,x,y,z'],
3161+
'tx0_blk0y0_blk0xyzz')
31613162

31623163
def init(f, v=1):
31633164
f.data[:] = np.indices(grid.shape).sum(axis=0) % (.004*v) + .01
@@ -3531,9 +3532,9 @@ def test_issue_2448_backward(self, mode):
35313532

35323533
class TestTTIOp:
35333534

3534-
@pytest.mark.skipif(TestTTI is None, reason="Requires installing the tests")
35353535
@pytest.mark.parallel(mode=1)
35363536
def test_halo_structure(self, mode):
3537+
from test_dse import TestTTI
35373538
solver = TestTTI().tti_operator(opt='advanced', space_order=8)
35383539
op = solver.op_fwd(save=False)
35393540

0 commit comments

Comments
 (0)