Skip to content

Commit 0998923

Browse files
FabioLuporinimloubout
authored andcommitted
compiler: Fix DDA with degenerating ModuloDimensions
1 parent 2a0076f commit 0998923

3 files changed

Lines changed: 55 additions & 8 deletions

File tree

devito/ir/support/basic.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -382,32 +382,36 @@ def distance(self, other):
382382
if disjoint_test(self[n], other[n], sai, sit):
383383
return Vector(S.ImaginaryUnit)
384384

385+
# Compute the distance along the current IterationInterval
385386
if self.function._mem_shared:
386387
# Special case: the distance between two regular, thread-shared
387-
# objects fallbacks to zero, as any other value would be nonsensical
388+
# objects fallbacks to zero, as any other value would be
389+
# nonsensical
390+
ret.append(S.Zero)
391+
elif degenerating_dimensions(sai, oai):
392+
# Special case: `sai` and `oai` may be different symbolic objects
393+
# but they can be proved to systematically generate the same value
388394
ret.append(S.Zero)
389-
390395
elif sai and oai and sai._defines & sit.dim._defines:
391-
# E.g., `self=R<f,[t + 1, x]>`, `self.itintervals=(time, x)`, `ai=t`
396+
# E.g., `self=R<f,[t + 1, x]>`, `self.itintervals=(time, x)`,
397+
# and `ai=t`
392398
if sit.direction is Backward:
393399
ret.append(other[n] - self[n])
394400
else:
395401
ret.append(self[n] - other[n])
396-
397402
elif not sai and not oai:
398403
# E.g., `self=R<a,[3]>` and `other=W<a,[4]>`
399404
if self[n] - other[n] == 0:
400405
ret.append(S.Zero)
401406
else:
402407
break
403-
404408
elif sai in self.ispace and oai in other.ispace:
405409
# E.g., `self=R<f,[x, y]>`, `sai=time`,
406410
# `self.itintervals=(time, x, y)`, `n=0`
407411
continue
408-
409412
else:
410-
# E.g., `self=R<u,[t+1, ii_src_0+1, ii_src_1+2]>`, `fi=p_src`, `n=1`
413+
# E.g., `self=R<u,[t+1, ii_src_0+1, ii_src_1+2]>`, `fi=p_src`,
414+
# and `n=1`
411415
return vinf(ret)
412416

413417
n = len(ret)
@@ -1408,3 +1412,19 @@ def disjoint_test(e0, e1, d, it):
14081412
i1 = sympy.Interval(min(p10, p11), max(p10, p11))
14091413

14101414
return not bool(i0.intersect(i1))
1415+
1416+
1417+
def degenerating_dimensions(d0, d1):
1418+
"""
1419+
True if `d0` and `d1` are Dimensions that are possibly symbolically
1420+
different, but they can be proved to systematically degenerate to the
1421+
same value, False otherwise.
1422+
"""
1423+
# Case 1: ModuloDimensions of size 1
1424+
try:
1425+
if d0.is_Modulo and d1.is_Modulo and d0.modulo == d1.modulo == 1:
1426+
return True
1427+
except AttributeError:
1428+
pass
1429+
1430+
return False

devito/passes/iet/mpi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,9 @@ def _rule0(dep, hs, loc_indices):
386386

387387
def _rule1(dep, hs, loc_indices):
388388
# E.g., `dep=W<f,[t1, x+1]> -> R<f,[t1, xl+1]>` and `loc_indices={t: t0}` => True
389-
return any(dep.distance_mapper[d] == 0 and dep.source[d] is not v
389+
return any(dep.distance_mapper[d] == 0 and
390+
dep.source[d] is not v and
391+
dep.sink[d] is not v
390392
for d, v in loc_indices.items())
391393

392394

tests/test_mpi.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,6 +1832,31 @@ def test_haloupdate_buffer1(self, mode):
18321832
# ModuloDimensions
18331833
assert len([i for i in FindSymbols('dimensions').visit(op) if i.is_Modulo]) == 0
18341834

1835+
@pytest.mark.parallel(mode=1)
1836+
def test_haloupdate_buffer1_v2(self, mode):
1837+
grid = Grid((65, 65, 65), topology=('*', 1, '*'))
1838+
1839+
v1 = TimeFunction(name='v1', grid=grid, space_order=2, time_order=1,
1840+
save=Buffer(1))
1841+
v2 = TimeFunction(name='v2', grid=grid, space_order=2, time_order=1,
1842+
save=Buffer(1))
1843+
1844+
rec = SparseTimeFunction(name='rec', grid=grid, nt=500, npoint=65)
1845+
1846+
eqns = [Eq(v1.forward, v2.dx2 + v1),
1847+
Eq(v2.forward, v1.forward.dx2 + v2),
1848+
rec.interpolate(v2)]
1849+
1850+
op = Operator(eqns)
1851+
op.cfunction
1852+
1853+
# Ensure there's a halo exchange over v2 before the rec interpolation
1854+
section1 = op.body.body[-1].body[1].nodes[1]
1855+
assert section1.is_Section
1856+
calls = FindNodes(HaloUpdateCall).visit(section1)
1857+
assert len(calls) == 1
1858+
assert calls[0].arguments[0] is v2
1859+
18351860

18361861
class TestOperatorAdvanced:
18371862

0 commit comments

Comments
 (0)