Skip to content

Commit feb82b7

Browse files
committed
compiler: Tweaks to CTemp purging in CSE
1 parent d9ca141 commit feb82b7

6 files changed

Lines changed: 42 additions & 6 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ class Schedule(QueueStateful):
122122

123123
@timed_pass(name='schedule')
124124
def process(self, clusters):
125+
# from IPython import embed; embed()
125126
return self._process_fatd(clusters, 1)
126127

127128
def callback(self, clusters, prefix, backlog=None, known_break=None):
@@ -156,6 +157,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
156157
# Schedule Clusters over different IterationSpaces if this increases
157158
# parallelism
158159
for i in range(1, len(clusters)):
160+
# FIXME: This call is expensive (four seconds per invocation)
159161
if self._break_for_parallelism(scope, candidates, i):
160162
return self.callback(clusters[:i], prefix, clusters[i:] + backlog,
161163
candidates | known_break)
@@ -191,6 +193,8 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
191193
def _break_for_parallelism(self, scope, candidates, i):
192194
# `test` will be True if there's at least one data-dependence that would
193195
# break parallelism
196+
197+
# TODO: Can this loop be made to short-circuit?
194198
test = False
195199
for d in scope.d_from_access_gen(scope.a_query(i)):
196200
if d.is_local or d.is_storage_related(candidates):

devito/ir/support/basic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def distance(self, other):
363363
# Case 1: `sit` is an IterationInterval with statically known
364364
# trip count. E.g. it ranges from 0 to 3; `other` performs a
365365
# constant access at 4
366+
# TODO: This case accounts for the majority of the time spent constructing a DAG
366367
for v in (self[n], other[n]):
367368
try:
368369
if bool(v < sit.symbolic_min or v > sit.symbolic_max):
@@ -830,6 +831,7 @@ def __init__(self, exprs, rules=None):
830831
self.rules = as_tuple(rules)
831832
assert all(callable(i) for i in self.rules)
832833

834+
# FIXME: This property should be moved somewhere more sensible
833835
@cached_property
834836
def thingy(self):
835837
return any(i.cause for i in self.d_anti_gen())

devito/operator/operator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -967,8 +967,10 @@ def _emit_build_profiling(self):
967967
tot = timings.pop('op-compile')
968968
perf(f"Operator `{self.name}` generated in {fround(tot):.2f} s")
969969

970-
max_hotspots = 3
971-
threshold = 20.
970+
# max_hotspots = 3
971+
# threshold = 20.
972+
max_hotspots = 300
973+
threshold = 0.5
972974

973975
def _emit_timings(timings, indent=''):
974976
timings.pop('total', None)

devito/passes/clusters/cse.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class CTemp(Temp):
2525
"""
2626
A cluster-level Temp, similar to Temp, ensured to have different priority
2727
"""
28+
is_CTemp = True
29+
2830
ordering_of_classes.insert(ordering_of_classes.index('Temp') + 1, 'CTemp')
2931

3032

@@ -220,13 +222,39 @@ def _compact(exprs, exclude):
220222
`for (i = ...) { a = b; for (j = a ...) ... }`. Hence, this routine
221223
only targets CTemps.
222224
"""
225+
# FIXME: Can use is_CTemp rather than isinstance
223226
candidates = [e for e in exprs
224227
if isinstance(e.lhs, CTemp) and e.lhs not in exclude]
225228

226229
mapper = {e.lhs: e.rhs for e in candidates if q_leaf(e.rhs)}
227230

228-
mapper.update({e.lhs: e.rhs for e in candidates
229-
if sum([i.rhs.count(e.lhs) for i in exprs]) == 1})
231+
# FIXME: Move this to searches as retrieve_ctemps
232+
from devito.symbolics.search import search
233+
234+
def q_ctemp(expr):
235+
try:
236+
return expr.is_CTemp
237+
except AttributeError:
238+
return False
239+
240+
# Find all the CTemps in expressions without removing duplicates
241+
# ctemps = search(exprs, q_ctemp, 'all', 'dfs')
242+
# More likely, only the right-hand sides should be searched:
243+
ctemps = search([e.rhs for e in exprs], q_ctemp, 'all', 'dfs')
244+
245+
# print(ctemps, len(ctemps), len(set(ctemps)), len(candidates))
246+
247+
# FIXME: The original update below is slow and should be replaced.
248+
# FIXME: Specifically, sum([i.rhs.count(e.lhs) for i in exprs]) == 1 is very slow
249+
# mapper.update({e.lhs: e.rhs for e in candidates
250+
# if sum([i.rhs.count(e.lhs) for i in exprs]) == 1})
251+
252+
# If there are ctemps in the expressions, then add any to the mapper which only
253+
# appear once
254+
# TODO: Double-check that this exactly matches the prior behaviour
255+
if ctemps:
256+
mapper.update({e.lhs: e.rhs for e in candidates
257+
if ctemps.count(e.lhs) == 1})
230258

231259
processed = []
232260
for e in exprs:

devito/passes/clusters/misc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,7 @@ def is_cross(source, sink):
362362
# (intuitively, "the loop nests are to be kept separated")
363363
# * All ClusterGroups between `cg0` and `cg1` must precede `cg1`
364364
# * All ClusterGroups after `cg1` cannot precede `cg1`
365-
# FIXME: Attach to the scope
366-
# if any(i.cause & prefix for i in scope.d_anti_gen()):
365+
# FIXME: `thingy` is a poor name; rename it to something descriptive
367366
if prefix and scope.thingy:
368367
for cg2 in cgroups[n:cgroups.index(cg1)]:
369368
dag.add_edge(cg2, cg1)

devito/types/basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ class Basic(CodeSymbol):
298298
is_Object = False
299299
is_LocalObject = False
300300
is_LocalType = False
301+
is_CTemp = False
301302

302303
# Created by the user
303304
is_Input = False

0 commit comments

Comments
 (0)