diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 5fbe13af0f..34e78752e1 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -442,8 +442,8 @@ def callback(self, clusters, prefix, seen=None): d = prefix[-1].dim # Construct a representation of the halo accesses - processed = [] - for c in clusters: + processed = list(clusters) + for n, c in enumerate(clusters): if c.properties.is_sequential(d) or \ c in seen: continue @@ -480,10 +480,16 @@ def callback(self, clusters, prefix, seen=None): halo_touch = c.rebuild(exprs=expr, ispace=ispace, properties=properties) - processed.append(halo_touch) - seen.update({halo_touch, c}) + # Insert `halo_touch` at the top of the IterationSpace within which + # `c` is scheduled + index = 0 + for i in reversed(range(n)): + if not processed[i].ispace.is_subset(c.ispace): + index = i + 1 + break + processed.insert(index, halo_touch) - processed.extend(clusters) + seen.update({halo_touch, c}) return processed diff --git a/devito/mpi/distributed.py b/devito/mpi/distributed.py index 01edaaaa69..7daeaeadd9 100644 --- a/devito/mpi/distributed.py +++ b/devito/mpi/distributed.py @@ -261,7 +261,7 @@ def nprocs_local(self): @property def topology(self): - return self._topology + return DimensionTuple(*self._topology, getters=self.dimensions) @property def topology_logical(self): diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 91328a65e6..fc41ba7ce2 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -617,12 +617,11 @@ def classify(exprs, ispace): f"scheme for `{f}` along Dimension `{d}`") elif hl.pop() is STENCIL: halos.append(Halo(d, s)) - else: + elif d._defines & set(ispace.itdims): raw_loc_indices[d].append(s) loc_indices, loc_dirs = process_loc_indices(raw_loc_indices, ispace.directions) - mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims) return mapper diff --git a/devito/symbolics/search.py b/devito/symbolics/search.py index f0acdeca72..9c30948064 100644 --- a/devito/symbolics/search.py +++ b/devito/symbolics/search.py @@ -60,6 +60,8 @@ def __init__(self, query: Callable[[Expression], bool], deep: bool = False) -> N def _next(self, expr: Expression) -> Iterable[Expression]: if self.deep and expr.is_Indexed: return expr.indices + elif self.deep and q_dimension(expr): + return expr.bound_symbols elif q_leaf(expr): return () return expr.args diff --git a/tests/test_dle.py b/tests/test_dle.py index 5fa1c305c2..0c0532193b 100644 --- a/tests/test_dle.py +++ b/tests/test_dle.py @@ -755,7 +755,7 @@ def test_dynamic_nthreads(self): ('[Eq(f, 2*f)]', [2, 0, 0], False), ('[Eq(u, 2*u)]', [0, 2, 0, 0], False), ('[Eq(u, 2*u + f)]', [0, 3, 0, 0, 0, 0, 0], True), - ('[Eq(u, 2*u), Eq(f, u.dzr)]', [0, 2, 0, 0, 0], False) + ('[Eq(u, 2*u), Eq(f, u.dzr)]', [0, 2, 0, 0, 2, 0, 0], False) ]) def test_collapsing(self, eqns, expected, blocking): grid = Grid(shape=(3, 3, 3)) diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 9280cd7b6e..7746e06155 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -2,7 +2,6 @@ import numpy as np import pytest -from test_dse import TestTTI from conftest import _R, assert_blocking, assert_structure, body0 from devito import ( @@ -2203,6 +2202,25 @@ def test_lift_halo_update_outside_distributed(self, mode): halo_update = tloop.nodes[0].body[0].body[0].body[0] assert isinstance(halo_update, HaloUpdateList) + @pytest.mark.parallel(mode=4) + def test_halo_inner_dim(self, mode): + grid = Grid((11, 11, 11)) + + np.random.seed(0) + v = TimeFunction(name="v", grid=grid, space_order=4, + time_order=1, save=Buffer(1)) + v.data[:] = np.random.randn(*grid.shape) + e = TimeFunction(name="dummy", grid=grid, space_order=4, time_order=0) + + eq = [Eq(v.forward, v + 1), Eq(e, v.forward.dydz)] + + op = Operator(eq, opt=('advanced', {'blocklevels': 0})) + + assert_structure(op, ['txyz', 't', 'txyz', 'txyz'], 'txyzxyzz') + op(time=100) + + assert np.isclose(norm(e), 23484.863, rtol=0, atol=1e-1) + class TestOperatorAdvanced: @@ -2734,9 +2752,10 @@ def test_haloupdate_same_timestep_v2(self, mode): titer = op.body.body[-1].body[0] assert titer.dim is grid.time_dim - assert titer.nodes[0].body[0].body[0].is_List - assert len(titer.nodes[0].body[0].body[0].body[0].body) == 1 - assert titer.nodes[0].body[0].body[0].body[0].body[0].is_Call + block = titer.nodes[0].body[0].body[1] + assert block.is_List + assert len(block.body) == 3 + assert block.body[0].body[0].is_Call op.apply(time=0) @@ -3139,7 +3158,7 @@ def test_fission_due_to_antidep(self, mode): assert_structure(op1, ['t', 't,x0_blk0,y0_blk0,x,y,z', 't,x0_blk0,y0_blk0,x,y,z'], - 't,x0_blk0,y0_blk0,x,y,z,z') + 'tx0_blk0y0_blk0xyzz') def init(f, v=1): f.data[:] = np.indices(grid.shape).sum(axis=0) % (.004*v) + .01 @@ -3513,9 +3532,9 @@ def test_issue_2448_backward(self, mode): class TestTTIOp: - @pytest.mark.skipif(TestTTI is None, reason="Requires installing the tests") @pytest.mark.parallel(mode=1) def test_halo_structure(self, mode): + from test_dse import TestTTI solver = TestTTI().tti_operator(opt='advanced', space_order=8) op = solver.op_fwd(save=False) diff --git a/tests/test_operator.py b/tests/test_operator.py index e8d8b55bed..d5c9827836 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -1582,12 +1582,12 @@ def test_no_fission_as_illegal(self, exprs): (('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])', 'Eq(tv[t,x,y,z], tu[t,x,y,z+2])', 'Eq(tu[t,x,y,0], tu[t,x,y,0] + 1.)'), - '+++++', ['txyz', 'txyz', 'txy'], 'txyzz'), + '+++++++', ['txyz', 'txyz', 'txy'], 'txyzxyz'), # 7) WAR 1->2, 2->3 (('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])', 'Eq(tv[t,x,y,z], tu[t,x,y,z+2])', 'Eq(tw[t,x,y,z], tv[t,x,y,z-1] + 1.)'), - '++++++', ['txyz', 'txyz', 'txyz'], 'txyzzz'), + '++++++++', ['txyz', 'txyz', 'txyz'], 'txyzxyzz'), # 8) WAR 1->2; WAW 1->3 (('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])', 'Eq(tv[t,x,y,z], tu[t,x+2,y,z])', @@ -1597,7 +1597,7 @@ def test_no_fission_as_illegal(self, exprs): (('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])', 'Eq(tv[t,x,y,z], tu[t,x,y,z-2])', 'Eq(tw[t,x,y,z], tv[t,x,y+1,z] + 1.)'), - '+++++++', ['txyz', 'txyz', 'txyz'], 'txyzzyz'), + '+++++++++', ['txyz', 'txyz', 'txyz'], 'txyzxyzyz'), # 10) WAR 1->2; WAW 1->3 (('Eq(tu[t-1,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])', 'Eq(tv[t,x,y,z], tu[t,x,y,z+2])',