From 7633a1a2f07198e6926bce0680d688212ad37e04 Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 21 Apr 2026 16:16:31 -0400 Subject: [PATCH 1/4] compiler: fix halo placement for non out dimm exchange --- devito/ir/clusters/algorithms.py | 17 ++++++++++++++++- devito/mpi/halo_scheme.py | 27 ++++++++++++++++++++++++++- devito/passes/iet/mpi.py | 30 ++---------------------------- tests/test_mpi.py | 24 +++++++++++++++++++++--- 4 files changed, 65 insertions(+), 33 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 5fbe13af0f..25a1a57ba3 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -443,7 +443,7 @@ def callback(self, clusters, prefix, seen=None): # Construct a representation of the halo accesses processed = [] - for c in clusters: + for i, c in enumerate(clusters): if c.properties.is_sequential(d) or \ c in seen: continue @@ -453,6 +453,10 @@ def callback(self, clusters, prefix, seen=None): not d._defines & hs.distributed_aindices: continue + if any(halo_write(ci, hs) for ci in clusters[:i]): + # If there's a halo write before `c`, then we cannot inject the HaloTouch + continue + points = set() for f in hs.fmapper: for a in c.scope.getreads(f): @@ -781,3 +785,14 @@ def normalize_reductions_sparse(cluster, sregistry): processed.append(e) return cluster.rebuild(processed) + + +def halo_write(c, hs): + loc_vals = hs.loc_values + + for f in hs.fmapper: + for a in c.scope.getwrites(f): + if set(a.access.indices) & loc_vals: + return True + + return False diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 91328a65e6..68d27b0f61 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -5,7 +5,7 @@ from operator import attrgetter import sympy -from sympy import Max, Min +from sympy import Max, Min, S from devito import configuration from devito.data import CENTER, CORE, LEFT, OWNED, RIGHT @@ -514,6 +514,31 @@ def merge(self, hs): fmapper[f] = fmapper.get(f, hse).merge(hse) return HaloScheme.build(fmapper, self.honored) + def _is_iter_carried(self, scope): + """ + True if the HaloScheme is iteration-carried, i.e., it induces + a halo exchange that requires values from the previous iteration(s); False + otherwise. + """ + + def rule0(dep): + # E.g., `dep=W -> R`, `d=t` => OK + return not any(dep.distance_mapper[d] is S.Infinity for d in dep.cause) + + def rule1(dep, loc_indices): + # E.g., `dep=W -> R`, `loc_indices={t: t0}` => OK + return any(dep.distance_mapper[d] == 0 and + dep.source[d] is not v and + dep.sink[d] is not v + for d, v in loc_indices.items()) + + for f, v in self.fmapper.items(): + for dep in scope.d_flow.project(f): + if not rule0(dep) and not rule1(dep, v.loc_indices): + return False + + return True + def classify(exprs, ispace): """ diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 3a3d354905..98a2d427b7 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -206,7 +206,7 @@ def _hoist_invariant(iet): # Ensure there's another HaloScheme that could cover for # us should we get hoisted while still satisfying the # data dependences - if hsf1.issubset(hsf0) and _is_iter_carried(hsf1, scope): + if hsf1.issubset(hsf0) and hsf1._is_iter_carried(scope): hs, hsf = hs1, hsf1 elif hsf0.issubset(hsf1) and hs0 is halo_spots[0]: # Special case @@ -474,32 +474,6 @@ def _check_control_flow(hs0, hs1, cond_mapper): return cond0 != cond1 -def _is_iter_carried(hsf, scope): - """ - True if the provided HaloScheme `hsf` is iteration-carried, i.e., it induces - a halo exchange that requires values from the previous iteration(s); False - otherwise. - """ - - def rule0(dep): - # E.g., `dep=W -> R`, `d=t` => OK - return not any(dep.distance_mapper[d] is S.Infinity for d in dep.cause) - - def rule1(dep, loc_indices): - # E.g., `dep=W -> R`, `loc_indices={t: t0}` => OK - return any(dep.distance_mapper[d] == 0 and - dep.source[d] is not v and - dep.sink[d] is not v - for d, v in loc_indices.items()) - - for f, v in hsf.fmapper.items(): - for dep in scope.d_flow.project(f): - if not rule0(dep) and not rule1(dep, v.loc_indices): - return False - - return True - - def _is_mergeable(hsf0, hsf1, scope): """ True if `hsf1` can be merged into `hsf0`, i.e., if they are compatible @@ -515,7 +489,7 @@ def _is_mergeable(hsf0, hsf1, scope): return False # Finally, check the data dependences would be satisfied - return _is_iter_carried(hsf1, scope) + return hsf1._is_iter_carried(scope) def _semantical_eq_loc_indices(hsf0, hsf1): diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 9280cd7b6e..f724e0332e 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -2203,6 +2203,24 @@ def test_lift_halo_update_outside_distributed(self, mode): halo_update = tloop.nodes[0].body[0].body[0].body[0] assert isinstance(halo_update, HaloUpdateList) + @pytest.mark.parallel(mode=4) + def test_halo_inner_dim(self, mode): + grid = Grid((11, 11, 11)) + + np.random.seed(0) + v = TimeFunction(name="v", grid=grid, space_order=4, + time_order=1, save=Buffer(1)) + v.data[:] = np.random.randn(*grid.shape) + e = TimeFunction(name="dummy", grid=grid, space_order=4, time_order=0) + + eq = [Eq(v.forward, v + 1), Eq(e, v.forward.dydz)] + + op = Operator(eq, opt=('advanced', {'blocklevels': 0})) + assert_structure(op, ['txyz', 't', 'txyz', 'txyz'], 'txyzxyzz') + op(time=100) + + assert np.isclose(norm(e), 23484.863, rtol=0, atol=1e-1) + class TestOperatorAdvanced: @@ -2736,7 +2754,7 @@ def test_haloupdate_same_timestep_v2(self, mode): assert titer.dim is grid.time_dim assert titer.nodes[0].body[0].body[0].is_List assert len(titer.nodes[0].body[0].body[0].body[0].body) == 1 - assert titer.nodes[0].body[0].body[0].body[0].body[0].is_Call + assert not titer.nodes[0].body[0].body[0].body[0].body[0].is_Call op.apply(time=0) @@ -3138,8 +3156,8 @@ def test_fission_due_to_antidep(self, mode): # First, check the generated code assert_structure(op1, ['t', 't,x0_blk0,y0_blk0,x,y,z', - 't,x0_blk0,y0_blk0,x,y,z'], - 't,x0_blk0,y0_blk0,x,y,z,z') + 't,x1_blk0,y1_blk0,x,y,z'], + 'tx0_blk0y0_blk0xyzx1_blk0y1_blk0xyz') def init(f, v=1): f.data[:] = np.indices(grid.shape).sum(axis=0) % (.004*v) + .01 From 86fa36f6c15a8a2c3c0f5485352873d8098eca5b Mon Sep 17 00:00:00 2001 From: mloubout Date: Wed, 22 Apr 2026 09:22:42 -0400 Subject: [PATCH 2/4] compiler: refine distributed dimension check for smarter halotouch --- devito/ir/clusters/algorithms.py | 21 ++++++++++++++------- devito/mpi/distributed.py | 2 +- devito/mpi/halo_scheme.py | 3 +-- tests/test_mpi.py | 10 +++++----- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 25a1a57ba3..92354cf3d4 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -453,7 +453,7 @@ def callback(self, clusters, prefix, seen=None): not d._defines & hs.distributed_aindices: continue - if any(halo_write(ci, hs) for ci in clusters[:i]): + if any(_halo_write(ci, hs) for ci in clusters[:i]): # If there's a halo write before `c`, then we cannot inject the HaloTouch continue @@ -787,12 +787,19 @@ def normalize_reductions_sparse(cluster, sregistry): return cluster.rebuild(processed) -def halo_write(c, hs): - loc_vals = hs.loc_values - +def _halo_write(c, hs): + """ + Check if the cluster `c` writes into any of the local values read by `hs`. + """ for f in hs.fmapper: - for a in c.scope.getwrites(f): - if set(a.access.indices) & loc_vals: - return True + if not any(f.grid.distributor.topology.get(d, 1) > 1 + for d in hs.dimensions): + # Not distributed halo dimension, write does not impact the halo exchange + continue + + if any(set(a.access.indices) & hs.loc_values for a in c.scope.getwrites(f)): + # Writing into a local value, which is read by the halo exchange, + # creates a write dependency + return True return False diff --git a/devito/mpi/distributed.py b/devito/mpi/distributed.py index 01edaaaa69..7daeaeadd9 100644 --- a/devito/mpi/distributed.py +++ b/devito/mpi/distributed.py @@ -261,7 +261,7 @@ def nprocs_local(self): @property def topology(self): - return self._topology + return DimensionTuple(*self._topology, getters=self.dimensions) @property def topology_logical(self): diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 68d27b0f61..0feb8e74cc 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -642,12 +642,11 @@ def classify(exprs, ispace): f"scheme for `{f}` along Dimension `{d}`") elif hl.pop() is STENCIL: halos.append(Halo(d, s)) - else: + elif d._defines & set(ispace.itdims): raw_loc_indices[d].append(s) loc_indices, loc_dirs = process_loc_indices(raw_loc_indices, ispace.directions) - mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims) return mapper diff --git a/tests/test_mpi.py b/tests/test_mpi.py index f724e0332e..32f16b123b 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -2,7 +2,6 @@ import numpy as np import pytest -from test_dse import TestTTI from conftest import _R, assert_blocking, assert_structure, body0 from devito import ( @@ -2216,6 +2215,7 @@ def test_halo_inner_dim(self, mode): eq = [Eq(v.forward, v + 1), Eq(e, v.forward.dydz)] op = Operator(eq, opt=('advanced', {'blocklevels': 0})) + assert_structure(op, ['txyz', 't', 'txyz', 'txyz'], 'txyzxyzz') op(time=100) @@ -2754,7 +2754,7 @@ def test_haloupdate_same_timestep_v2(self, mode): assert titer.dim is grid.time_dim assert titer.nodes[0].body[0].body[0].is_List assert len(titer.nodes[0].body[0].body[0].body[0].body) == 1 - assert not titer.nodes[0].body[0].body[0].body[0].body[0].is_Call + assert titer.nodes[0].body[0].body[0].body[0].body[0].is_Call op.apply(time=0) @@ -3156,8 +3156,8 @@ def test_fission_due_to_antidep(self, mode): # First, check the generated code assert_structure(op1, ['t', 't,x0_blk0,y0_blk0,x,y,z', - 't,x1_blk0,y1_blk0,x,y,z'], - 'tx0_blk0y0_blk0xyzx1_blk0y1_blk0xyz') + 't,x0_blk0,y0_blk0,x,y,z'], + 'tx0_blk0y0_blk0xyzz') def init(f, v=1): f.data[:] = np.indices(grid.shape).sum(axis=0) % (.004*v) + .01 @@ -3531,9 +3531,9 @@ def test_issue_2448_backward(self, mode): class TestTTIOp: - @pytest.mark.skipif(TestTTI is None, reason="Requires installing the tests") @pytest.mark.parallel(mode=1) def test_halo_structure(self, mode): + from test_dse import TestTTI solver = TestTTI().tti_operator(opt='advanced', space_order=8) op = solver.op_fwd(save=False) From 013bb51c912b90fcb8e6b9090b830f87e41d6e5a Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 23 Apr 2026 09:36:35 -0400 Subject: [PATCH 3/4] compiler: switch to better halo placement --- devito/ir/clusters/algorithms.py | 38 +++++++++----------------------- tests/test_mpi.py | 7 +++--- 2 files changed, 15 insertions(+), 30 deletions(-) diff --git a/devito/ir/clusters/algorithms.py b/devito/ir/clusters/algorithms.py index 92354cf3d4..34e78752e1 100644 --- a/devito/ir/clusters/algorithms.py +++ b/devito/ir/clusters/algorithms.py @@ -442,8 +442,8 @@ def callback(self, clusters, prefix, seen=None): d = prefix[-1].dim # Construct a representation of the halo accesses - processed = [] - for i, c in enumerate(clusters): + processed = list(clusters) + for n, c in enumerate(clusters): if c.properties.is_sequential(d) or \ c in seen: continue @@ -453,10 +453,6 @@ def callback(self, clusters, prefix, seen=None): not d._defines & hs.distributed_aindices: continue - if any(_halo_write(ci, hs) for ci in clusters[:i]): - # If there's a halo write before `c`, then we cannot inject the HaloTouch - continue - points = set() for f in hs.fmapper: for a in c.scope.getreads(f): @@ -484,10 +480,16 @@ def callback(self, clusters, prefix, seen=None): halo_touch = c.rebuild(exprs=expr, ispace=ispace, properties=properties) - processed.append(halo_touch) - seen.update({halo_touch, c}) + # Insert `halo_touch` at the top of the IterationSpace within which + # `c` is scheduled + index = 0 + for i in reversed(range(n)): + if not processed[i].ispace.is_subset(c.ispace): + index = i + 1 + break + processed.insert(index, halo_touch) - processed.extend(clusters) + seen.update({halo_touch, c}) return processed @@ -785,21 +787,3 @@ def normalize_reductions_sparse(cluster, sregistry): processed.append(e) return cluster.rebuild(processed) - - -def _halo_write(c, hs): - """ - Check if the cluster `c` writes into any of the local values read by `hs`. - """ - for f in hs.fmapper: - if not any(f.grid.distributor.topology.get(d, 1) > 1 - for d in hs.dimensions): - # Not distributed halo dimension, write does not impact the halo exchange - continue - - if any(set(a.access.indices) & hs.loc_values for a in c.scope.getwrites(f)): - # Writing into a local value, which is read by the halo exchange, - # creates a write dependency - return True - - return False diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 32f16b123b..7746e06155 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -2752,9 +2752,10 @@ def test_haloupdate_same_timestep_v2(self, mode): titer = op.body.body[-1].body[0] assert titer.dim is grid.time_dim - assert titer.nodes[0].body[0].body[0].is_List - assert len(titer.nodes[0].body[0].body[0].body[0].body) == 1 - assert titer.nodes[0].body[0].body[0].body[0].body[0].is_Call + block = titer.nodes[0].body[0].body[1] + assert block.is_List + assert len(block.body) == 3 + assert block.body[0].body[0].is_Call op.apply(time=0) From 076af3730afb61ca5233183757b39865397c3d2b Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 23 Apr 2026 13:31:40 -0400 Subject: [PATCH 4/4] compiler: fix terminal detection from dynamic dim bounds --- devito/mpi/halo_scheme.py | 27 +-------------------------- devito/passes/iet/mpi.py | 30 ++++++++++++++++++++++++++++-- devito/symbolics/search.py | 2 ++ tests/test_dle.py | 2 +- tests/test_operator.py | 6 +++--- 5 files changed, 35 insertions(+), 32 deletions(-) diff --git a/devito/mpi/halo_scheme.py b/devito/mpi/halo_scheme.py index 0feb8e74cc..fc41ba7ce2 100644 --- a/devito/mpi/halo_scheme.py +++ b/devito/mpi/halo_scheme.py @@ -5,7 +5,7 @@ from operator import attrgetter import sympy -from sympy import Max, Min, S +from sympy import Max, Min from devito import configuration from devito.data import CENTER, CORE, LEFT, OWNED, RIGHT @@ -514,31 +514,6 @@ def merge(self, hs): fmapper[f] = fmapper.get(f, hse).merge(hse) return HaloScheme.build(fmapper, self.honored) - def _is_iter_carried(self, scope): - """ - True if the HaloScheme is iteration-carried, i.e., it induces - a halo exchange that requires values from the previous iteration(s); False - otherwise. - """ - - def rule0(dep): - # E.g., `dep=W -> R`, `d=t` => OK - return not any(dep.distance_mapper[d] is S.Infinity for d in dep.cause) - - def rule1(dep, loc_indices): - # E.g., `dep=W -> R`, `loc_indices={t: t0}` => OK - return any(dep.distance_mapper[d] == 0 and - dep.source[d] is not v and - dep.sink[d] is not v - for d, v in loc_indices.items()) - - for f, v in self.fmapper.items(): - for dep in scope.d_flow.project(f): - if not rule0(dep) and not rule1(dep, v.loc_indices): - return False - - return True - def classify(exprs, ispace): """ diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 98a2d427b7..3a3d354905 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -206,7 +206,7 @@ def _hoist_invariant(iet): # Ensure there's another HaloScheme that could cover for # us should we get hoisted while still satisfying the # data dependences - if hsf1.issubset(hsf0) and hsf1._is_iter_carried(scope): + if hsf1.issubset(hsf0) and _is_iter_carried(hsf1, scope): hs, hsf = hs1, hsf1 elif hsf0.issubset(hsf1) and hs0 is halo_spots[0]: # Special case @@ -474,6 +474,32 @@ def _check_control_flow(hs0, hs1, cond_mapper): return cond0 != cond1 +def _is_iter_carried(hsf, scope): + """ + True if the provided HaloScheme `hsf` is iteration-carried, i.e., it induces + a halo exchange that requires values from the previous iteration(s); False + otherwise. + """ + + def rule0(dep): + # E.g., `dep=W -> R`, `d=t` => OK + return not any(dep.distance_mapper[d] is S.Infinity for d in dep.cause) + + def rule1(dep, loc_indices): + # E.g., `dep=W -> R`, `loc_indices={t: t0}` => OK + return any(dep.distance_mapper[d] == 0 and + dep.source[d] is not v and + dep.sink[d] is not v + for d, v in loc_indices.items()) + + for f, v in hsf.fmapper.items(): + for dep in scope.d_flow.project(f): + if not rule0(dep) and not rule1(dep, v.loc_indices): + return False + + return True + + def _is_mergeable(hsf0, hsf1, scope): """ True if `hsf1` can be merged into `hsf0`, i.e., if they are compatible @@ -489,7 +515,7 @@ def _is_mergeable(hsf0, hsf1, scope): return False # Finally, check the data dependences would be satisfied - return hsf1._is_iter_carried(scope) + return _is_iter_carried(hsf1, scope) def _semantical_eq_loc_indices(hsf0, hsf1): diff --git a/devito/symbolics/search.py b/devito/symbolics/search.py index f0acdeca72..9c30948064 100644 --- a/devito/symbolics/search.py +++ b/devito/symbolics/search.py @@ -60,6 +60,8 @@ def __init__(self, query: Callable[[Expression], bool], deep: bool = False) -> N def _next(self, expr: Expression) -> Iterable[Expression]: if self.deep and expr.is_Indexed: return expr.indices + elif self.deep and q_dimension(expr): + return expr.bound_symbols elif q_leaf(expr): return () return expr.args diff --git a/tests/test_dle.py b/tests/test_dle.py index 5fa1c305c2..0c0532193b 100644 --- a/tests/test_dle.py +++ b/tests/test_dle.py @@ -755,7 +755,7 @@ def test_dynamic_nthreads(self): ('[Eq(f, 2*f)]', [2, 0, 0], False), ('[Eq(u, 2*u)]', [0, 2, 0, 0], False), ('[Eq(u, 2*u + f)]', [0, 3, 0, 0, 0, 0, 0], True), - ('[Eq(u, 2*u), Eq(f, u.dzr)]', [0, 2, 0, 0, 0], False) + ('[Eq(u, 2*u), Eq(f, u.dzr)]', [0, 2, 0, 0, 2, 0, 0], False) ]) def test_collapsing(self, eqns, expected, blocking): grid = Grid(shape=(3, 3, 3)) diff --git a/tests/test_operator.py b/tests/test_operator.py index e8d8b55bed..d5c9827836 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -1582,12 +1582,12 @@ def test_no_fission_as_illegal(self, exprs): (('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])', 'Eq(tv[t,x,y,z], tu[t,x,y,z+2])', 'Eq(tu[t,x,y,0], tu[t,x,y,0] + 1.)'), - '+++++', ['txyz', 'txyz', 'txy'], 'txyzz'), + '+++++++', ['txyz', 'txyz', 'txy'], 'txyzxyz'), # 7) WAR 1->2, 2->3 (('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])', 'Eq(tv[t,x,y,z], tu[t,x,y,z+2])', 'Eq(tw[t,x,y,z], tv[t,x,y,z-1] + 1.)'), - '++++++', ['txyz', 'txyz', 'txyz'], 'txyzzz'), + '++++++++', ['txyz', 'txyz', 'txyz'], 'txyzxyzz'), # 8) WAR 1->2; WAW 1->3 (('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])', 'Eq(tv[t,x,y,z], tu[t,x+2,y,z])', @@ -1597,7 +1597,7 @@ def test_no_fission_as_illegal(self, exprs): (('Eq(tu[t,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])', 'Eq(tv[t,x,y,z], tu[t,x,y,z-2])', 'Eq(tw[t,x,y,z], tv[t,x,y+1,z] + 1.)'), - '+++++++', ['txyz', 'txyz', 'txyz'], 'txyzzyz'), + '+++++++++', ['txyz', 'txyz', 'txyz'], 'txyzxyzyz'), # 10) WAR 1->2; WAW 1->3 (('Eq(tu[t-1,x,y,z], tu[t,x,y,z] + tv[t,x,y,z])', 'Eq(tv[t,x,y,z], tu[t,x,y,z+2])',