Skip to content

Commit 64139dd

Browse files
committed
compiler: Fix DDA with ComponentAccesses and whole-Bundle accesses
1 parent 4a959f6 commit 64139dd

5 files changed

Lines changed: 55 additions & 16 deletions

File tree

devito/ir/support/basic.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from devito.ir.support.utils import AccessMode, extrema
99
from devito.ir.support.vector import LabeledVector, Vector
1010
from devito.symbolics import (compare_ops, retrieve_indexed, retrieve_terminals,
11-
q_constant, q_affine, q_routine, search, uxreplace)
11+
q_constant, q_comp_acc, q_affine, q_routine, search,
12+
uxreplace)
1213
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
1314
flatten, memoized_meth, memoized_generator)
1415
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
@@ -529,9 +530,16 @@ def __hash__(self):
529530
(self.source, self.sink, self.source.timestamp == self.sink.timestamp)
530531
)
531532

532-
@property
533+
@cached_property
533534
def function(self):
534-
return self.source.function
535+
if q_comp_acc(self.source.access) and not q_comp_acc(self.sink.access):
536+
# E.g., `source=ab[x].x` and `sink=ab[x]` -> `a(x)`
537+
return self.source.access.function_access
538+
elif q_comp_acc(self.sink.access) and not q_comp_acc(self.source.access):
539+
# E.g., `source=ab[x]` and `sink=ab[x].y` -> `b(x)`
540+
return self.sink.access.function_access
541+
else:
542+
return self.source.function
535543

536544
@property
537545
def findices(self):
@@ -955,7 +963,7 @@ def reads_gen(self):
955963
@memoized_generator
956964
def reads_smart_gen(self, f):
957965
"""
958-
Generate all read access to a given function.
966+
Generate all read accesses to a given function.
959967
960968
StencilDimensions, if any, are replaced with their extrema.
961969

devito/passes/iet/mpi.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def _hoist_invariant(iet):
8383

8484
for it, halo_spots in iter_mapper.items():
8585
for hs0, hs1 in combinations(halo_spots, r=2):
86-
8786
if _check_control_flow(hs0, hs1, cond_mapper):
8887
continue
8988

@@ -129,10 +128,9 @@ def _hoist_invariant(iet):
129128

130129
def _merge_halospots(iet):
131130
"""
132-
Merge HaloSpots on the same Iteration tree level where all data dependencies
133-
would be honored. Avoids redundant halo exchanges when the same data is
134-
redundantly exchanged within the same Iteration tree level as well as to initiate
135-
multiple halo exchanges at once.
131+
Using data dependence analysis, merge HaloSpots on the same IET level. This
132+
has two effects: anticipating communication over computation, and (potentially)
133+
avoiding redundant halo exchanges.
136134
137135
Example:
138136
@@ -141,7 +139,6 @@ def _merge_halospots(iet):
141139
W v[t1]- R v[t0] W v[t1]- R v[t0]
142140
haloupd v[t0], h
143141
W g[t1]- R v[t0], h W g[t1]- R v[t0], h
144-
145142
"""
146143

147144
# Analysis
@@ -155,16 +152,15 @@ def _merge_halospots(iet):
155152
hs0 = halo_spots[0]
156153

157154
for hs1 in halo_spots[1:]:
158-
159155
if _check_control_flow(hs0, hs1, cond_mapper):
160156
continue
161157

162158
for f, v in hs1.fmapper.items():
163159
for dep in scope.d_flow.project(f):
164-
if not any(r(dep, hs1, v.loc_indices) for r in rules):
160+
if not any(rule(dep, hs1, v.loc_indices) for rule in rules):
165161
break
166162
else:
167-
# hs1 is merged with hs0
163+
# All good -- `hs1` can be merged with `hs0`
168164
hs = hs1.halo_scheme.project(f)
169165
mapper[hs0] = HaloScheme.union([mapper.get(hs0, hs0.halo_scheme), hs])
170166
mapper[hs1] = mapper.get(hs1, hs1.halo_scheme).drop(f)

devito/symbolics/queries.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
from devito.types.basic import AbstractFunction
77
from devito.types.constant import Constant
88
from devito.types.dimension import Dimension
9+
from devito.types.array import ComponentAccess
910
from devito.types.object import AbstractObject
1011

1112

1213
__all__ = ['q_leaf', 'q_indexed', 'q_terminal', 'q_function', 'q_routine',
1314
'q_terminalop', 'q_indirect', 'q_constant', 'q_affine', 'q_linear',
14-
'q_identity', 'q_symbol', 'q_multivar', 'q_monoaffine', 'q_dimension',
15-
'q_positive', 'q_negative']
15+
'q_identity', 'q_symbol', 'q_comp_acc', 'q_multivar', 'q_monoaffine',
16+
'q_dimension', 'q_positive', 'q_negative']
1617

1718

1819
# The following SymPy objects are considered tree leaves:
@@ -31,6 +32,10 @@ def q_symbol(expr):
3132
return False
3233

3334

35+
def q_comp_acc(expr):
36+
return isinstance(expr, ComponentAccess)
37+
38+
3439
def q_leaf(expr):
3540
return (expr.is_Atom or
3641
expr.is_Indexed or

devito/types/array.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,10 @@ def sindex(self):
611611
def function(self):
612612
return self.base.function
613613

614+
@property
615+
def function_access(self):
616+
return self.function.components[self.index]
617+
614618
@property
615619
def indices(self):
616620
return self.base.indices

tests/test_ir.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from devito.ir.support.guards import GuardOverflow
1818
from devito.symbolics import DefFunction, FieldFromPointer
1919
from devito.tools import prod
20-
from devito.types import Array, CriticalRegion, Jump, Scalar, Symbol
20+
from devito.types import Array, Bundle, CriticalRegion, Jump, Scalar, Symbol
2121

2222

2323
class TestVectorHierarchy:
@@ -887,6 +887,32 @@ def test_critical_region_v1(self):
887887
assert len(scope.reads[mocksym1]) == 2
888888
assert len(scope.d_all) == 9
889889

890+
def test_bundle_components(self):
891+
grid = Grid(shape=(4, 4))
892+
x, y = grid.dimensions
893+
894+
f = Function(name='f', grid=grid)
895+
g = Function(name='g', grid=grid)
896+
v = Function(name='v', grid=grid)
897+
w = Function(name='w', grid=grid)
898+
u0 = Function(name='u0', grid=grid)
899+
u1 = Function(name='u1', grid=grid)
900+
901+
fg = Bundle(name='fg', components=(f, g))
902+
vw = Bundle(name='vw', components=(v, w))
903+
904+
exprs = [Eq(fg.indexify(), 1),
905+
Eq(u0.indexify(), fg[0, x, y] + 2),
906+
Eq(vw[0, x, y], 3),
907+
Eq(u1.indexify(), vw[1, x, y] + 4)]
908+
exprs = [LoweredEq(i) for i in exprs]
909+
910+
scope = Scope(exprs)
911+
assert len(scope.d_all) == 1
912+
assert len(scope.d_flow) == 1
913+
dep, = scope.d_flow
914+
assert dep.function is f
915+
890916

891917
class TestParallelismAnalysis:
892918

0 commit comments

Comments
 (0)