Skip to content

Commit b2bc705

Browse files
committed
compiler: fix halo placement for non out dimm exchange
1 parent bab9aae commit b2bc705

2 files changed

Lines changed: 36 additions & 1 deletion

File tree

devito/ir/clusters/algorithms.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def callback(self, clusters, prefix, seen=None):
443443

444444
# Construct a representation of the halo accesses
445445
processed = []
446-
for c in clusters:
446+
for i, c in enumerate(clusters):
447447
if c.properties.is_sequential(d) or \
448448
c in seen:
449449
continue
@@ -453,6 +453,11 @@ def callback(self, clusters, prefix, seen=None):
453453
not d._defines & hs.distributed_aindices:
454454
continue
455455

456+
if any(halo_write(ci, hs) for ci in clusters[:i]) and \
457+
hs.dimensions & set(prefix.itdims):
458+
# If there's a halo write before `c`, then we cannot inject the HaloTouch
459+
continue
460+
456461
points = set()
457462
for f in hs.fmapper:
458463
for a in c.scope.getreads(f):
@@ -781,3 +786,15 @@ def normalize_reductions_sparse(cluster, sregistry):
781786
processed.append(e)
782787

783788
return cluster.rebuild(processed)
789+
790+
791+
def halo_write(c, hs):
792+
loc_vals = hs.loc_values
793+
for f in hs.fmapper:
794+
# Does `c` write to any of the halo points of `f`?
795+
w = c.scope.writes.get(f, None)
796+
if w is None:
797+
continue
798+
elif any(set(a.indices) & loc_vals for a in [wi.access for wi in w]):
799+
return True
800+
return False

tests/test_mpi.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2203,6 +2203,24 @@ def test_lift_halo_update_outside_distributed(self, mode):
22032203
halo_update = tloop.nodes[0].body[0].body[0].body[0]
22042204
assert isinstance(halo_update, HaloUpdateList)
22052205

2206+
@pytest.mark.parallel(mode=4)
2207+
def test_halo_inner_dim(self, mode):
2208+
grid = Grid((11, 11, 11))
2209+
2210+
np.random.seed(0)
2211+
v = TimeFunction(name="v", grid=grid, space_order=4,
2212+
time_order=1, save=Buffer(1))
2213+
v.data[:] = np.random.randn(*grid.shape)
2214+
e = TimeFunction(name="dummy", grid=grid, space_order=4, time_order=0)
2215+
2216+
eq = [Eq(v.forward, v + 1), Eq(e, v.forward.dydz)]
2217+
2218+
op = Operator(eq, opt=('advanced', {'blocklevels': 0}))
2219+
assert_structure(op, ['txyz', 't', 'txyz', 'txyz'], 'txyzxyzz')
2220+
op(time=100)
2221+
2222+
assert np.isclose(norm(e), 23484.863, rtol=0, atol=1e-1)
2223+
22062224

22072225
class TestOperatorAdvanced:
22082226

0 commit comments

Comments
 (0)