Skip to content

Commit 821502f

Browse files
committed
misc: Move IR caches to tool/memoization
1 parent 1c12003 commit 821502f

7 files changed

Lines changed: 133 additions & 139 deletions

File tree

devito/ir/support/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from .vector import * # noqa
22
from .utils import * # noqa
33
from .basic import * # noqa
4-
from .caching import * # noqa
54
from .space import * # noqa
65
from .guards import * # noqa
76
from .syncs import * # noqa

devito/ir/support/basic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from sympy import S, Expr
77
import sympy
88

9-
from devito.ir.support.caching import CacheInstances
109
from devito.ir.support.space import Backward, null_ispace
1110
from devito.ir.support.utils import AccessMode, extrema
1211
from devito.ir.support.vector import LabeledVector, Vector
@@ -15,7 +14,7 @@
1514
uxreplace)
1615
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
1716
flatten, memoized_meth, memoized_generator, smart_gt,
18-
smart_lt)
17+
smart_lt, CacheInstances)
1918
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
2019
CriticalRegion, Function, Symbol, Temp, TempArray,
2120
TBArray)

devito/ir/support/caching.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

devito/operator/operator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from devito.data import default_allocator
1414
from devito.exceptions import (CompilationError, ExecutionError, InvalidArgument,
1515
InvalidOperator)
16-
from devito.ir.support.caching import CacheInstances
1716
from devito.logger import (debug, info, perf, warning, is_log_enabled_for,
1817
switch_log_level)
1918
from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims
@@ -32,7 +31,8 @@
3231
from devito.symbolics import estimate_cost, subs_op_args
3332
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3433
flatten, filter_sorted, frozendict, is_integer,
35-
split, timed_pass, timed_region, contains_val)
34+
split, timed_pass, timed_region, contains_val,
35+
CacheInstances)
3636
from devito.types import (Buffer, Evaluable, host_layer, device_layer,
3737
disk_layer)
3838
from devito.types.dimension import Thickness
@@ -246,7 +246,7 @@ def _build(cls, expressions, **kwargs):
246246
op._dtype, op._dspace = irs.clusters.meta
247247
op._profiler = profiler
248248

249-
# Clear Scope + Dependence caches
249+
# Clear build-scoped instance caches
250250
CacheInstances.clear_caches()
251251

252252
return op

devito/tools/memoization.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from collections.abc import Hashable
2-
from functools import partial
1+
from collections.abc import Callable, Hashable
2+
from functools import lru_cache, partial
33
from itertools import tee
4+
from typing import TypeVar
45

5-
__all__ = ['memoized_func', 'memoized_meth', 'memoized_generator']
6+
__all__ = ['memoized_func', 'memoized_meth', 'memoized_generator', 'CacheInstances']
67

78

89
class memoized_func:
@@ -125,3 +126,63 @@ def __call__(self, *args, **kwargs):
125126
it = cache[key] if key in cache else self.func(*args, **kwargs)
126127
cache[key], result = tee(it)
127128
return result
129+
130+
131+
# Describes the type of a subclass of CacheInstances
132+
InstanceType = TypeVar('InstanceType', bound='CacheInstances', covariant=True)
133+
134+
135+
class CacheInstancesMeta(type):
136+
"""
137+
Metaclass to wrap construction in an LRU cache.
138+
"""
139+
140+
_cached_types: set[type['CacheInstances']] = set()
141+
142+
def __init__(cls: type[InstanceType], *args) -> None: # type: ignore
143+
super().__init__(*args)
144+
145+
# Register the cached type
146+
CacheInstancesMeta._cached_types.add(cls)
147+
148+
def __call__(cls: type[InstanceType], # type: ignore
149+
*args, **kwargs) -> InstanceType:
150+
if cls._instance_cache is None:
151+
maxsize = cls._instance_cache_size
152+
cls._instance_cache = lru_cache(maxsize=maxsize)(super().__call__)
153+
154+
args, kwargs = cls._preprocess_args(*args, **kwargs)
155+
return cls._instance_cache(*args, **kwargs)
156+
157+
@classmethod
158+
def clear_caches(cls: type['CacheInstancesMeta']) -> None:
159+
"""
160+
Clear all caches for classes using this metaclass.
161+
"""
162+
for cached_type in cls._cached_types:
163+
if cached_type._instance_cache is not None:
164+
cached_type._instance_cache.cache_clear()
165+
166+
167+
class CacheInstances(metaclass=CacheInstancesMeta):
168+
"""
169+
Parent class that wraps construction in an LRU cache.
170+
"""
171+
172+
_instance_cache: Callable | None = None
173+
_instance_cache_size: int = 128
174+
175+
@classmethod
176+
def _preprocess_args(cls, *args, **kwargs):
177+
"""
178+
Preprocess the arguments before caching. This can be overridden in subclasses
179+
to customize argument handling (e.g. to convert to hashable types).
180+
"""
181+
return args, kwargs
182+
183+
@staticmethod
184+
def clear_caches() -> None:
185+
"""
186+
Clears all IR instance caches.
187+
"""
188+
CacheInstancesMeta.clear_caches()

tests/test_ir.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from devito.ir.support.basic import (IterationInstance, TimedAccess, Scope,
1313
Vector, AFFINE, REGULAR, IRREGULAR, mocksym0,
1414
mocksym1)
15-
from devito.ir.support.caching import CacheInstances
1615
from devito.ir.support.space import (NullInterval, Interval, Forward, Backward,
1716
IntervalGroup, IterationSpace)
1817
from devito.ir.support.guards import GuardOverflow
@@ -1102,65 +1101,3 @@ def test_guard_overflow(self):
11021101
guard = GuardOverflow(freespace, size)
11031102

11041103
assert ccode(guard) == 'freespace >= f_vec->size[0]*f_vec->size[1]'
1105-
1106-
1107-
class TestCacheInstances:
1108-
1109-
def test_caching(self):
1110-
"""
1111-
Tests basic functionality of cached instances.
1112-
"""
1113-
class Object(CacheInstances):
1114-
def __init__(self, value: int):
1115-
self.value = value
1116-
1117-
obj1 = Object(1)
1118-
obj2 = Object(1)
1119-
obj3 = Object(2)
1120-
1121-
assert obj1 is obj2
1122-
assert obj1 is not obj3
1123-
1124-
def test_cache_size(self):
1125-
"""
1126-
Tests specifying the size of the instance cache.
1127-
"""
1128-
class Object(CacheInstances):
1129-
_instance_cache_size = 2
1130-
1131-
def __init__(self, value: int):
1132-
self.value = value
1133-
1134-
obj1 = Object(1)
1135-
obj2 = Object(2)
1136-
obj3 = Object(3)
1137-
obj4 = Object(1)
1138-
obj5 = Object(3)
1139-
1140-
# obj1 should have been evicted before obj4 was created
1141-
assert obj1 is not obj4
1142-
assert obj1 is not obj2
1143-
assert obj3 is obj5
1144-
1145-
hits, _, _, cursize = Object._instance_cache.cache_info()
1146-
assert hits == 1 # obj5 hit the cache
1147-
assert cursize == 2
1148-
1149-
def test_cleared_after_build(self):
1150-
"""
1151-
Tests that instance caches are cleared after building an Operator.
1152-
"""
1153-
class Object(CacheInstances):
1154-
def __init__(self, value: int):
1155-
self.value = value
1156-
1157-
obj1 = Object(1)
1158-
cache_size = Object._instance_cache.cache_info()[-1]
1159-
assert cache_size == 1
1160-
1161-
x = Symbol('x')
1162-
Operator(Eq(x, obj1.value))
1163-
1164-
# Cache should be cleared after Operator construction
1165-
cache_size = Object._instance_cache.cache_info()[-1]
1166-
assert cache_size == 0

tests/test_tools.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44

55
import time
66

7+
from devito import Operator, Eq
78
from devito.tools import (UnboundedMultiTuple, ctypes_to_cstr, toposort,
8-
filter_ordered, transitive_closure, UnboundTuple)
9+
filter_ordered, transitive_closure, UnboundTuple,
10+
CacheInstances)
911
from devito.types.basic import Symbol
1012

1113

@@ -145,3 +147,65 @@ def test_unbound_tuple():
145147
assert ub.next() == 2
146148
ub.iter()
147149
assert ub.next() == 1
150+
151+
152+
class TestCacheInstances:
153+
154+
def test_caching(self):
155+
"""
156+
Tests basic functionality of cached instances.
157+
"""
158+
class Object(CacheInstances):
159+
def __init__(self, value: int):
160+
self.value = value
161+
162+
obj1 = Object(1)
163+
obj2 = Object(1)
164+
obj3 = Object(2)
165+
166+
assert obj1 is obj2
167+
assert obj1 is not obj3
168+
169+
def test_cache_size(self):
170+
"""
171+
Tests specifying the size of the instance cache.
172+
"""
173+
class Object(CacheInstances):
174+
_instance_cache_size = 2
175+
176+
def __init__(self, value: int):
177+
self.value = value
178+
179+
obj1 = Object(1)
180+
obj2 = Object(2)
181+
obj3 = Object(3)
182+
obj4 = Object(1)
183+
obj5 = Object(3)
184+
185+
# obj1 should have been evicted before obj4 was created
186+
assert obj1 is not obj4
187+
assert obj1 is not obj2
188+
assert obj3 is obj5
189+
190+
hits, _, _, cursize = Object._instance_cache.cache_info()
191+
assert hits == 1 # obj5 hit the cache
192+
assert cursize == 2
193+
194+
def test_cleared_after_build(self):
195+
"""
196+
Tests that instance caches are cleared after building an Operator.
197+
"""
198+
class Object(CacheInstances):
199+
def __init__(self, value: int):
200+
self.value = value
201+
202+
obj1 = Object(1)
203+
cache_size = Object._instance_cache.cache_info()[-1]
204+
assert cache_size == 1
205+
206+
x = Symbol('x')
207+
Operator(Eq(x, obj1.value))
208+
209+
# Cache should be cleared after Operator construction
210+
cache_size = Object._instance_cache.cache_info()[-1]
211+
assert cache_size == 0

0 commit comments

Comments
 (0)