Skip to content

Commit 325d566

Browse files
committed
compiler: Rework Scope + Dependence caching
1 parent 9c2cf0f commit 325d566

11 files changed

Lines changed: 42 additions & 65 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
136136
# `clusters` are supposed to share it
137137
candidates = prefix[-1].dim._defines
138138

139-
scope = Scope.fetch(flatten(c.exprs for c in clusters))
139+
scope = Scope(flatten(c.exprs for c in clusters))
140140

141141
# Handle the nastiest case -- ambiguity due to the presence of both a
142142
# flow- and an anti-dependence.

devito/ir/clusters/analysis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _callback(self, clusters, dim, prefix):
9999
is_parallel_indep = True
100100
is_parallel_atomic = False
101101

102-
scope = Scope.fetch(flatten(c.exprs for c in clusters))
102+
scope = Scope(flatten(c.exprs for c in clusters))
103103
for dep in scope.d_all_gen():
104104
test00 = dep.is_indep(dim) and not dep.is_storage_related(dim)
105105
test01 = all(dep.is_reduce_atmost(i) for i in prev)
@@ -136,7 +136,7 @@ class Affiness(Detector):
136136
"""
137137

138138
def _callback(self, clusters, dim, prefix):
139-
scope = Scope.fetch(flatten(c.exprs for c in clusters))
139+
scope = Scope(flatten(c.exprs for c in clusters))
140140
accesses = [a for a in scope.accesses if not a.is_scalar]
141141

142142
if all(a.is_regular and a.affine_if_present(dim._defines) for a in accesses):

devito/ir/clusters/cluster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def dist_dimensions(self):
183183

184184
@cached_property
185185
def scope(self):
186-
return Scope.fetch(self.exprs)
186+
return Scope(self.exprs)
187187

188188
@cached_property
189189
def functions(self):
@@ -473,7 +473,7 @@ def exprs(self):
473473

474474
@cached_property
475475
def scope(self):
476-
return Scope.fetch(exprs=self.exprs)
476+
return Scope(exprs=self.exprs)
477477

478478
@cached_property
479479
def ispace(self):

devito/ir/support/basic.py

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from collections.abc import Iterable
22
from itertools import chain, product
3-
from functools import cached_property, lru_cache
3+
from functools import cached_property
44
from typing import Callable
55

66
from sympy import S, Expr
77
import sympy
88

9+
from devito.ir.support.caching import CacheInstances
910
from devito.ir.support.space import Backward, null_ispace
1011
from devito.ir.support.utils import AccessMode, extrema
1112
from devito.ir.support.vector import LabeledVector, Vector
@@ -626,21 +627,12 @@ def is_imaginary(self):
626627
return S.ImaginaryUnit in self.distance
627628

628629

629-
class Dependence(Relation):
630+
class Dependence(Relation, CacheInstances):
630631

631632
"""
632633
A data dependence between two TimedAccess objects.
633634
"""
634635

635-
@classmethod
636-
@lru_cache(maxsize=128)
637-
def fetch(cls: type['Dependence'],
638-
source: TimedAccess, sink: TimedAccess) -> 'Dependence':
639-
"""
640-
Obtain a (potentially cached) Dependence for analysis.
641-
"""
642-
return cls(source, sink)
643-
644636
def __repr__(self):
645637
return "%s -> %s" % (self.source, self.sink)
646638

@@ -834,14 +826,18 @@ def project(self, function):
834826
return DependenceGroup(i for i in self if i.function is function)
835827

836828

837-
class Scope:
829+
class Scope(CacheInstances):
838830

839831
# Describes a rule for dependencies
840832
Rule = Callable[[TimedAccess, TimedAccess], bool]
841833

834+
@classmethod
835+
def _preprocess_args(cls, exprs: Expr | Iterable[Expr],
836+
**kwargs) -> tuple[tuple, dict]:
837+
return (as_tuple(exprs),), kwargs
838+
842839
def __init__(self, exprs: tuple[Expr],
843-
rules: Rule | tuple[Rule] | None = None) \
844-
-> None:
840+
rules: Rule | tuple[Rule] | None = None) -> None:
845841
"""
846842
A Scope enables data dependence analysis on a totally ordered sequence
847843
of expressions.
@@ -852,24 +848,6 @@ def __init__(self, exprs: tuple[Expr],
852848
self.rules: tuple[Scope.Rule] = as_tuple(rules) # type: ignore[assignment]
853849
assert all(callable(i) for i in self.rules)
854850

855-
@classmethod
856-
@lru_cache(maxsize=128)
857-
def _fetch(cls: type['Scope'], exprs: tuple[Expr],
858-
rules: Rule | tuple[Rule] | None = None) -> 'Scope':
859-
"""
860-
Obtains a (potentially cached) Scope from a sequence of expressions.
861-
Helper function called with hashable arguments.
862-
"""
863-
return cls(exprs, rules=rules)
864-
865-
@classmethod
866-
def fetch(cls: type['Scope'], exprs: Expr | Iterable[Expr],
867-
rules: Rule | tuple[Rule] | None = None) -> 'Scope':
868-
"""
869-
Obtains a (potentially cached) Scope from a sequence of expressions.
870-
"""
871-
return cls._fetch(as_tuple(exprs), rules=rules)
872-
873851
@memoized_generator
874852
def writes_gen(self):
875853
"""
@@ -1111,7 +1089,7 @@ def d_flow_gen(self):
11111089
if any(not rule(w, r) for rule in self.rules):
11121090
continue
11131091

1114-
dependence = Dependence.fetch(w, r)
1092+
dependence = Dependence(w, r)
11151093

11161094
if dependence.is_imaginary:
11171095
continue
@@ -1141,7 +1119,7 @@ def d_anti_gen(self):
11411119
if any(not rule(r, w) for rule in self.rules):
11421120
continue
11431121

1144-
dependence = Dependence.fetch(r, w)
1122+
dependence = Dependence(r, w)
11451123

11461124
if dependence.is_imaginary:
11471125
continue
@@ -1171,7 +1149,7 @@ def d_output_gen(self):
11711149
if any(not rule(w2, w1) for rule in self.rules):
11721150
continue
11731151

1174-
dependence = Dependence.fetch(w2, w1)
1152+
dependence = Dependence(w2, w1)
11751153

11761154
if dependence.is_imaginary:
11771155
continue

devito/mpi/halo_scheme.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def classify(exprs, ispace):
519519
# performed before (reads) or after (writes) the OWNED region is computed
520520
loc_indices_from_reads = configuration['mpi'] not in ('dual',)
521521

522-
scope = Scope.fetch(exprs)
522+
scope = Scope(exprs)
523523

524524
mapper = {}
525525
for f, r in scope.reads.items():

devito/operator/operator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from devito.data import default_allocator
1414
from devito.exceptions import (CompilationError, ExecutionError, InvalidArgument,
1515
InvalidOperator)
16-
from devito.ir.support.basic import Dependence, Scope
16+
from devito.ir.support.caching import CacheInstances
1717
from devito.logger import (debug, info, perf, warning, is_log_enabled_for,
1818
switch_log_level)
1919
from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims
@@ -247,8 +247,7 @@ def _build(cls, expressions, **kwargs):
247247
op._profiler = profiler
248248

249249
# Clear Scope + Dependence caches
250-
Scope._fetch.cache_clear()
251-
Dependence.fetch.cache_clear()
250+
CacheInstances.clear_caches()
252251

253252
return op
254253

devito/passes/clusters/blocking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def callback(self, clusters, prefix):
279279
if len(clusters) > 1:
280280
# Heuristic: same as above if it induces dynamic bounds
281281
exprs = flatten(c.exprs for c in as_tuple(clusters))
282-
scope = Scope.fetch(exprs)
282+
scope = Scope(exprs)
283283
if any(i.is_lex_non_stmt for i in scope.d_all_gen()):
284284
return clusters
285285
else:

devito/passes/clusters/cse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,10 @@ def _cse(maybe_exprs, make, min_cost=1, mode='basic'):
124124
maybe_exprs = as_list(maybe_exprs)
125125
if all(e.is_Equality for e in maybe_exprs):
126126
exprs = maybe_exprs
127-
scope = Scope.fetch(maybe_exprs)
127+
scope = Scope(maybe_exprs)
128128
else:
129129
exprs = [Eq(make(e), e) for e in maybe_exprs]
130-
scope = Scope.fetch([])
130+
scope = Scope([])
131131

132132
# Some sub-expressions aren't really "common" -- that's the case of Dimension-
133133
# independent data dependences. For example:

devito/passes/clusters/misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def is_cross(source, sink):
355355
for n1, cg1 in enumerate(cgroups[n+1:], start=n+1):
356356

357357
# A Scope to compute all cross-ClusterGroup anti-dependences
358-
scope = Scope.fetch(exprs=cg0.exprs + cg1.exprs, rules=is_cross)
358+
scope = Scope(exprs=cg0.exprs + cg1.exprs, rules=is_cross)
359359

360360
# Anti-dependences along `prefix` break the execution flow
361361
# (intuitively, "the loop nests are to be kept separated")
@@ -444,7 +444,7 @@ def callback(self, clusters, prefix):
444444
return clusters
445445

446446
# Analyze and abort if fissioning would break a dependence
447-
scope = Scope.fetch(flatten(c.exprs for c in clusters))
447+
scope = Scope(flatten(c.exprs for c in clusters))
448448
if any(d._defines & dep.cause or dep.is_reduce(d) for dep in scope.d_all_gen()):
449449
return clusters
450450

devito/passes/iet/mpi.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def _drop_reduction_halospots(iet):
4040

4141
# If all HaloSpot reads pertain to reductions, then the HaloSpot is useless
4242
for hs, expressions in MapNodes(HaloSpot, Expression).visit(iet).items():
43-
scope = Scope.fetch(i.expr for i in expressions)
43+
scope = Scope(i.expr for i in expressions)
4444
for k, v in hs.fmapper.items():
4545
f = v.bundle or k
4646
if f not in scope.reads:
@@ -82,7 +82,7 @@ def _hoist_redundant_from_conditionals(iet):
8282

8383
mapper = HaloSpotMapper()
8484
for it, halo_spots in iter_mapper.items():
85-
scope = Scope.fetch(e.expr for e in FindNodes(Expression).visit(it))
85+
scope = Scope(e.expr for e in FindNodes(Expression).visit(it))
8686

8787
for hs0 in halo_spots:
8888
conditions = cond_mapper[hs0]
@@ -282,7 +282,7 @@ def _mark_overlappable(iet):
282282
if not expressions:
283283
continue
284284

285-
scope = Scope.fetch(i.expr for i in expressions)
285+
scope = Scope(i.expr for i in expressions)
286286

287287
# Comp/comm overlaps is legal only if the OWNED regions can grow
288288
# arbitrarly, which means all of the dependences must be carried
@@ -462,7 +462,7 @@ def _derive_scope(it, hs0, hs1):
462462
and ends at the HaloSpot `hs1`.
463463
"""
464464
expressions = FindWithin(Expression, hs0, stop=hs1).visit(it)
465-
return Scope.fetch(e.expr for e in expressions)
465+
return Scope(e.expr for e in expressions)
466466

467467

468468
def _check_control_flow(hs0, hs1, cond_mapper):

0 commit comments

Comments
 (0)