Skip to content

Commit 7d618d6

Browse files
committed
compiler: Tidy up new searches and functionality tweaks
1 parent ba43ed8 commit 7d618d6

6 files changed

Lines changed: 24 additions & 32 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ class Schedule(QueueStateful):
122122

123123
@timed_pass(name='schedule')
124124
def process(self, clusters):
125-
# from IPython import embed; embed()
126125
return self._process_fatd(clusters, 1)
127126

128127
def callback(self, clusters, prefix, backlog=None, known_break=None):

devito/ir/support/basic.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -831,11 +831,6 @@ def __init__(self, exprs, rules=None):
831831
self.rules = as_tuple(rules)
832832
assert all(callable(i) for i in self.rules)
833833

834-
# FIXME: Should be put somewhere sensible
835-
@cached_property
836-
def thingy(self):
837-
return any(i.cause for i in self.d_anti_gen())
838-
839834
@memoized_generator
840835
def writes_gen(self):
841836
"""
@@ -1128,6 +1123,10 @@ def d_anti(self):
11281123
"""Anti (or "write-after-read") dependences."""
11291124
return DependenceGroup(self.d_anti_gen())
11301125

1126+
@cached_property
1127+
def has_antidependencies(self):
1128+
return any(i.cause for i in self.d_anti_gen())
1129+
11311130
@memoized_generator
11321131
def d_output_gen(self):
11331132
"""Generate the output (or "write-after-write") dependences."""

devito/passes/clusters/cse.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from devito.finite_differences.differentiable import IndexDerivative
1414
from devito.ir import Cluster, Scope, cluster_pass
1515
from devito.symbolics import estimate_cost, q_leaf, q_terminal
16+
from devito.symbolics.search import retrieve_ctemps
1617
from devito.symbolics.manipulation import _uxreplace
1718
from devito.tools import DAG, as_list, as_tuple, frozendict, extract_dtype
1819
from devito.types import Eq, Symbol, Temp
@@ -222,36 +223,17 @@ def _compact(exprs, exclude):
222223
`for (i = ...) { a = b; for (j = a ...) ... }`. Hence, this routine
223224
only targets CTemps.
224225
"""
225-
# FIXME: Can use is_CTemp rather than isinstance
226226
candidates = [e for e in exprs
227227
if isinstance(e.lhs, CTemp) and e.lhs not in exclude]
228228

229229
mapper = {e.lhs: e.rhs for e in candidates if q_leaf(e.rhs)}
230230

231-
# FIXME: Move this to searches as retrieve_ctemps
232-
from devito.symbolics.search import search
233-
234-
def q_ctemp(expr):
235-
try:
236-
return expr.is_CTemp
237-
except AttributeError:
238-
return False
239-
240-
# Find all the CTemps in expressions without removing duplicates
241-
# ctemps = search(exprs, q_ctemp, 'all', 'dfs')
242-
# I think it was more like
243-
ctemps = search([e.rhs for e in exprs], q_ctemp, 'all', 'dfs')
244-
245-
# print(ctemps, len(ctemps), len(set(ctemps)), len(candidates))
246-
247-
# FIXME: This line is kinda slow. I should find some way to replace it.
248-
# FIXME: Specifically sum([i.rhs.count(e.lhs) for i in exprs]) == 1 is slow as hell
249-
# mapper.update({e.lhs: e.rhs for e in candidates
250-
# if sum([i.rhs.count(e.lhs) for i in exprs]) == 1})
231+
# Find all the CTemps in expression right-hand-sides without removing duplicates
232+
ctemps = retrieve_ctemps([e.rhs for e in exprs])
251233

252234
# If there are ctemps in the expressions, then add any to the mapper which only
253235
# appear once
254-
# TODO: Double check this is exactly the prior behaviour?
236+
# TODO: Double check this is exactly the prior behaviour
255237
if ctemps:
256238
mapper.update({e.lhs: e.rhs for e in candidates
257239
if ctemps.count(e.lhs) == 1})

devito/passes/clusters/misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,8 @@ def is_cross(source, sink):
362362
# (intuitively, "the loop nests are to be kept separated")
363363
# * All ClusterGroups between `cg0` and `cg1` must precede `cg1`
364364
# * All ClusterGroups after `cg1` cannot precede `cg1`
365-
# FIXME: This is a terrible variable name
366-
if prefix and scope.thingy:
365+
# TODO: Check that this is indeed what the attribute does
366+
if prefix and scope.has_antidependencies:
367367
for cg2 in cgroups[n:cgroups.index(cg1)]:
368368
dag.add_edge(cg2, cg1)
369369
for cg2 in cgroups[cgroups.index(cg1)+1:]:

devito/symbolics/queries.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ def q_symbol(expr):
3232
return False
3333

3434

35+
def q_ctemp(expr):
36+
try:
37+
return expr.is_CTemp
38+
except AttributeError:
39+
return False
40+
41+
3542
def q_comp_acc(expr):
3643
return isinstance(expr, ComponentAccess)
3744

devito/symbolics/search.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sympy
22

33
from devito.symbolics.queries import (q_indexed, q_function, q_terminal, q_leaf,
4-
q_symbol, q_dimension, q_derivative)
4+
q_symbol, q_ctemp, q_dimension, q_derivative)
55
from devito.tools import as_tuple
66

77
__all__ = ['retrieve_indexed', 'retrieve_functions', 'retrieve_function_carriers',
@@ -155,10 +155,15 @@ def retrieve_functions(exprs, mode='all', deep=False):
155155

156156

157157
def retrieve_symbols(exprs, mode='all'):
158-
"""Shorthand to retrieve the Scalar in ``exprs``."""
158+
"""Shorthand to retrieve the Scalar in `exprs`."""
159159
return search(exprs, q_symbol, mode, 'dfs')
160160

161161

162+
def retrieve_ctemps(exprs, mode='all'):
163+
"""Shorthand to retrieve the CTemps in `exprs`"""
164+
return search(exprs, q_ctemp, mode, 'dfs')
165+
166+
162167
def retrieve_function_carriers(exprs, mode='all'):
163168
"""
164169
Shorthand to retrieve the DiscreteFunction carriers in ``exprs``. An

0 commit comments

Comments
 (0)