33from itertools import product
44
55from devito .finite_differences import IndexDerivative
6- from devito .symbolics import CallFromPointer , retrieve_indexed , retrieve_terminals , search
6+ from devito .symbolics import (
7+ CallFromPointer , retrieve_indexed , retrieve_terminals , search
8+ )
79from devito .tools import DefaultOrderedDict , as_tuple , filter_sorted , flatten , split
810from devito .types import (
9- Dimension , DimensionTuple , Indirection , ModuloDimension , StencilDimension
11+ Dimension , DimensionTuple , Indirection , ModuloDimension , StencilDimension ,
12+ TensorMove
1013)
1114
1215__all__ = [
@@ -137,7 +140,14 @@ def detect_accesses(exprs):
137140 """
138141 # Compute M : F -> S
139142 mapper = defaultdict (Stencil )
140- for e in retrieve_indexed (exprs , deep = True ):
143+
144+ # Search among the Indexeds (Most accesses typically stem from Indexeds)
145+ plain_indexeds = retrieve_indexed (exprs , deep = True )
146+
147+ # Search among higher order objects, which still represent meaningful accesses
148+ high_order_indexeds = [i .indexed for i in search (exprs , TensorMove )]
149+
150+ for e in (* plain_indexeds , * high_order_indexeds ):
141151 f = e .function
142152
143153 for a , d0 in zip (e .indices , f .dimensions , strict = False ):
@@ -164,13 +174,16 @@ def detect_accesses(exprs):
164174 d , others = split (dims , lambda i : d0 in i ._defines ) # noqa: B023
165175
166176 if any (i .is_Indexed for i in a .args ) or len (d ) != 1 :
167- # Case 1) -- with indirect accesses there's not much we can infer
177+ # Case 1) -- with indirect accesses there's not much we
178+ # can infer
168179 continue
169180 else :
170181 # Case 2)
171182 d , = d
172183 _ , o = split (others , lambda i : i .is_Custom )
173- off = sum (i for i in a .args if i .is_integer or i .free_symbols & o )
184+ off = sum (
185+ i for i in a .args if i .is_integer or i .free_symbols & o
186+ )
174187 else :
175188 d , = dims
176189
0 commit comments