Skip to content

Commit ca98c95

Browse files
authored
Update lru_cache.py
1 parent 027f692 commit ca98c95

1 file changed

Lines changed: 32 additions & 37 deletions

File tree

other/lru_cache.py

Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22

33
from collections.abc import Callable, Hashable
44
from functools import wraps
5-
from typing import Any, Generic, ParamSpec, TypeVar
5+
from typing import Any, Generic, ParamSpec, TypeVar, cast
66

7-
T = TypeVar("T", bound=Hashable)
8-
U = TypeVar("U")
97
P = ParamSpec("P")
108
R = TypeVar("R")
119

@@ -43,26 +41,26 @@ def __repr__(self) -> str:
4341
def add(self, node: DoubleLinkedListNode) -> None:
4442
"""Add node to list end"""
4543
prev = self.rear.prev
46-
if not prev:
44+
if prev is None:
4745
raise ValueError("Invalid list state")
48-
46+
4947
prev.next = node
5048
node.prev = prev
5149
self.rear.prev = node
5250
node.next = self.rear
5351

5452
def remove(self, node: DoubleLinkedListNode) -> DoubleLinkedListNode | None:
5553
"""Remove node from list"""
56-
if not node.prev or not node.next:
54+
if node.prev is None or node.next is None:
5755
return None
58-
56+
5957
node.prev.next = node.next
6058
node.next.prev = node.prev
6159
node.prev = node.next = None
6260
return node
6361

6462

65-
class LRUCache(Generic[T, U]):
63+
class LRUCache:
6664
"""LRU Cache implementation"""
6765

6866
def __init__(self, capacity: int) -> None:
@@ -71,29 +69,26 @@ def __init__(self, capacity: int) -> None:
7169
self.size = 0
7270
self.hits = 0
7371
self.misses = 0
74-
self.cache: dict[T, DoubleLinkedListNode] = {}
72+
self.cache: dict[Any, DoubleLinkedListNode] = {}
7573

7674
def __repr__(self) -> str:
7775
return (
7876
f"Cache(hits={self.hits}, misses={self.misses}, "
7977
f"cap={self.capacity}, size={self.size})"
8078
)
8179

82-
def __contains__(self, key: T) -> bool:
83-
return key in self.cache
84-
85-
def get(self, key: T) -> U | None:
80+
def get(self, key: Any) -> Any | None:
8681
"""Get value for key"""
8782
if key in self.cache:
8883
self.hits += 1
8984
node = self.cache[key]
9085
if self.list.remove(node):
9186
self.list.add(node)
92-
return node.val # type: ignore[return-value]
87+
return node.val
9388
self.misses += 1
9489
return None
9590

96-
def put(self, key: T, value: U) -> None:
91+
def put(self, key: Any, value: Any) -> None:
9792
"""Set value for key"""
9893
if key in self.cache:
9994
node = self.cache[key]
@@ -105,38 +100,38 @@ def put(self, key: T, value: U) -> None:
105100
if self.size >= self.capacity:
106101
first = self.list.head.next
107102
if first and first.key and self.list.remove(first):
108-
del self.cache[first.key] # type: ignore[index]
103+
del self.cache[first.key]
109104
self.size -= 1
110105

111106
new_node = DoubleLinkedListNode(key, value)
112107
self.cache[key] = new_node
113108
self.list.add(new_node)
114109
self.size += 1
115110

116-
@classmethod
117-
def decorator(cls, size: int = 128) -> Callable[[Callable[P, R]], Callable[P, R]]:
118-
"""LRU Cache decorator"""
119-
120-
def decorator_func(func: Callable[P, R]) -> Callable[P, R]:
121-
# Create non-generic cache instance
122-
cache = cls(size) # type: ignore[assignment]
123111

124-
@wraps(func)
125-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
126-
key = (args, tuple(sorted(kwargs.items())))
127-
if (result := cache.get(key)) is None:
128-
result = func(*args, **kwargs)
129-
cache.put(key, result)
130-
return result
131-
132-
# Add cache_info attribute
133-
wrapper.cache_info = lambda: cache # type: ignore[attr-defined]
134-
return wrapper
135-
136-
return decorator_func
112+
def lru_cache(size: int = 128) -> Callable[[Callable[P, R]], Callable[P, R]]:
113+
"""LRU Cache decorator"""
114+
def decorator_func(func: Callable[P, R]) -> Callable[P, R]:
115+
cache = LRUCache(size)
116+
117+
@wraps(func)
118+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
119+
key = (args, tuple(sorted(kwargs.items())))
120+
cached = cache.get(key)
121+
if cached is not None:
122+
return cached
123+
124+
result = func(*args, **kwargs)
125+
cache.put(key, result)
126+
return result
127+
128+
# Add cache_info attribute
129+
wrapper.cache_info = lambda: cache # type: ignore[attr-defined]
130+
return wrapper
131+
132+
return decorator_func
137133

138134

139135
if __name__ == "__main__":
140136
import doctest
141-
142137
doctest.testmod()

0 commit comments

Comments
 (0)