Skip to content

Commit 010ab52

Browse files
committed
compiler: Enhance detect_accesses
1 parent bd39274 commit 010ab52

2 files changed

Lines changed: 57 additions & 6 deletions

File tree

devito/ir/support/utils.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
from itertools import product
44

55
from devito.finite_differences import IndexDerivative
6-
from devito.symbolics import CallFromPointer, retrieve_indexed, retrieve_terminals, search
6+
from devito.symbolics import (
7+
CallFromPointer, retrieve_indexed, retrieve_terminals, search
8+
)
79
from devito.tools import DefaultOrderedDict, as_tuple, filter_sorted, flatten, split
810
from devito.types import (
9-
Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension
11+
Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension,
12+
TensorMove
1013
)
1114

1215
__all__ = [
@@ -137,7 +140,14 @@ def detect_accesses(exprs):
137140
"""
138141
# Compute M : F -> S
139142
mapper = defaultdict(Stencil)
140-
for e in retrieve_indexed(exprs, deep=True):
143+
144+
# Search among the Indexeds (most accesses typically stem from Indexeds)
145+
plain_indexeds = retrieve_indexed(exprs, deep=True)
146+
147+
# Search among higher-order objects, which still represent meaningful accesses
148+
high_order_indexeds = [i.indexed for i in search(exprs, TensorMove)]
149+
150+
for e in (*plain_indexeds, *high_order_indexeds):
141151
f = e.function
142152

143153
for a, d0 in zip(e.indices, f.dimensions, strict=False):
@@ -164,13 +174,16 @@ def detect_accesses(exprs):
164174
d, others = split(dims, lambda i: d0 in i._defines) # noqa: B023
165175

166176
if any(i.is_Indexed for i in a.args) or len(d) != 1:
167-
# Case 1) -- with indirect accesses there's not much we can infer
177+
# Case 1) -- with indirect accesses there's not much we
178+
# can infer
168179
continue
169180
else:
170181
# Case 2)
171182
d, = d
172183
_, o = split(others, lambda i: i.is_Custom)
173-
off = sum(i for i in a.args if i.is_integer or i.free_symbols & o)
184+
off = sum(
185+
i for i in a.args if i.is_integer or i.free_symbols & o
186+
)
174187
else:
175188
d, = dims
176189

devito/types/parallel.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,9 +407,47 @@ class TensorMove(Expr, Reserved, Terminal):
407407

408408
"""
409409
Represent the LOAD/STORE of a multi-dimensional block of data from/to a higher
410-
level of the memory hierarchy
410+
level of the memory hierarchy.
411+
412+
Parameters
413+
----------
414+
base : IndexedBase
415+
The base of the AbstractFunction that is the subject of the TensorMove.
416+
tid0 : Dimension
417+
A representation of thread(s) issuing the TensorMove.
418+
coords : tuple
419+
The base address of the TensorMove (one point per Dimension).
411420
"""
412421

422+
__rargs__ = ('base', 'tid0', 'coords')
423+
424+
def __new__(cls, base, tid0, coords, **kwargs):
425+
return super().__new__(cls, base, tid0, coords)
426+
427+
@property
428+
def base(self):
429+
return self.args[0]
430+
431+
@property
432+
def tid0(self):
433+
return self.args[1]
434+
435+
@property
436+
def coords(self):
437+
return self.args[2]
438+
439+
@property
440+
def function(self):
441+
return self.base.function
442+
443+
@cached_property
444+
def indexed(self):
445+
return self.function[self.coords]
446+
447+
@property
448+
def ndim(self):
449+
return self.function.ndim
450+
413451
func = Reserved._rebuild
414452

415453
def _ccode(self, printer):

0 commit comments

Comments
 (0)