Skip to content

Commit 7d734c9

Browse files
committed
Add Splay Tree Implementation in Binary Tree
Signed-off-by: Arya Pratap Singh <notaryasingh@gmail.com>
1 parent a71618f commit 7d734c9

1 file changed

Lines changed: 326 additions & 0 deletions

File tree

Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
"""
2+
Splay Tree implementation - A self-adjusting binary search tree.
3+
4+
A Splay tree is a self-adjusting binary search tree where recently accessed
5+
elements are moved to the root through rotations. This provides amortized
6+
O(log n) time complexity for search, insert, and delete operations.
7+
8+
The splaying operation moves a node to the root by performing a series of
9+
rotations, making frequently accessed elements faster to access in the future.
10+
11+
Time Complexity (amortized):
12+
- Search: O(log n)
13+
- Insert: O(log n)
14+
- Delete: O(log n)
15+
16+
Space Complexity: O(n)
17+
18+
Operations:
19+
- Zig: Single rotation (when parent is root)
20+
- Zig-zig: Double rotation in same direction
21+
- Zig-zag: Double rotation in opposite directions
22+
23+
Example:
24+
>>> tree = SplayTree()
25+
>>> _ = tree.insert(10)
26+
>>> _ = tree.insert(20)
27+
>>> _ = tree.insert(30)
28+
>>> _ = tree.insert(5)
29+
>>> _ = tree.insert(15)
30+
>>> list(tree.inorder())
31+
[5, 10, 15, 20, 30]
32+
>>> tree.search(15)
33+
True
34+
>>> tree.search(25)
35+
False
36+
>>> _ = tree.delete(20)
37+
>>> list(tree.inorder())
38+
[5, 10, 15, 30]
39+
"""
40+
41+
from __future__ import annotations
42+
43+
from collections.abc import Iterator
44+
from dataclasses import dataclass
45+
from typing import Any, Self
46+
47+
48+
@dataclass
49+
class SplayNode:
50+
"""A node in the Splay Tree."""
51+
value: Any
52+
left: SplayNode | None = None
53+
right: SplayNode | None = None
54+
parent: SplayNode | None = None
55+
56+
def __iter__(self) -> Iterator[Any]:
57+
"""Inorder traversal iterator."""
58+
yield from self.left or []
59+
yield self.value
60+
yield from self.right or []
61+
62+
def __repr__(self) -> str:
63+
from pprint import pformat
64+
65+
if self.left is None and self.right is None:
66+
return str(self.value)
67+
return pformat({f"{self.value}": (self.left, self.right)}, indent=1)
68+
69+
@property
70+
def is_right(self) -> bool:
71+
"""Check if this node is the right child of its parent."""
72+
return bool(self.parent and self is self.parent.right)
73+
74+
75+
@dataclass
76+
class SplayTree:
77+
"""
78+
Splay Tree implementation - A self-adjusting BST.
79+
80+
This tree automatically moves recently accessed elements to the root
81+
through rotations, providing amortized O(log n) performance for all operations.
82+
"""
83+
84+
root: SplayNode | None = None
85+
86+
def __bool__(self) -> bool:
87+
"""Return True if the tree is not empty."""
88+
return self.root is not None
89+
90+
def __iter__(self) -> Iterator[Any]:
91+
"""Iterate over the tree in inorder traversal."""
92+
yield from self.root or []
93+
94+
def __len__(self) -> int:
95+
"""Return the number of nodes in the tree."""
96+
return sum(1 for _ in self)
97+
98+
def __str__(self) -> str:
99+
"""Return a string representation of the tree."""
100+
return str(self.root)
101+
102+
def _rotate_right(self, node: SplayNode) -> None:
103+
"""Perform a right rotation on the given node."""
104+
if not node.left:
105+
return
106+
107+
left_child = node.left
108+
node.left = left_child.right
109+
110+
if left_child.right:
111+
left_child.right.parent = node
112+
113+
left_child.parent = node.parent
114+
115+
if not node.parent:
116+
self.root = left_child
117+
elif node is node.parent.right:
118+
node.parent.right = left_child
119+
else:
120+
node.parent.left = left_child
121+
122+
left_child.right = node
123+
node.parent = left_child
124+
125+
def _rotate_left(self, node: SplayNode) -> None:
126+
"""Perform a left rotation on the given node."""
127+
if not node.right:
128+
return
129+
130+
right_child = node.right
131+
node.right = right_child.left
132+
133+
if right_child.left:
134+
right_child.left.parent = node
135+
136+
right_child.parent = node.parent
137+
138+
if not node.parent:
139+
self.root = right_child
140+
elif node is node.parent.left:
141+
node.parent.left = right_child
142+
else:
143+
node.parent.right = right_child
144+
145+
right_child.left = node
146+
node.parent = right_child
147+
148+
def _splay(self, node: SplayNode) -> None:
149+
"""
150+
Splay the given node to the root through a series of rotations.
151+
152+
The splaying operation uses three types of rotations:
153+
- Zig: Single rotation when parent is root
154+
- Zig-zig: Double rotation in same direction
155+
- Zig-zag: Double rotation in opposite directions
156+
"""
157+
while node.parent:
158+
parent = node.parent
159+
grandparent = parent.parent
160+
161+
if not grandparent:
162+
# Zig case: parent is root
163+
if node is parent.left:
164+
self._rotate_right(parent)
165+
else:
166+
self._rotate_left(parent)
167+
elif (node is parent.left and parent is grandparent.left) or \
168+
(node is parent.right and parent is grandparent.right):
169+
# Zig-zig case: same direction
170+
if parent is grandparent.left:
171+
self._rotate_right(grandparent)
172+
self._rotate_right(parent)
173+
else:
174+
self._rotate_left(grandparent)
175+
self._rotate_left(parent)
176+
else:
177+
# Zig-zag case: opposite directions
178+
if node is parent.left:
179+
self._rotate_right(parent)
180+
self._rotate_left(grandparent)
181+
else:
182+
self._rotate_left(parent)
183+
self._rotate_right(grandparent)
184+
185+
def _find_node(self, value: Any) -> SplayNode | None:
186+
"""Find a node with the given value, splaying it to root if found."""
187+
current = self.root
188+
while current:
189+
if value == current.value:
190+
self._splay(current)
191+
return current
192+
elif value < current.value:
193+
if not current.left:
194+
break
195+
current = current.left
196+
else:
197+
if not current.right:
198+
break
199+
current = current.right
200+
201+
# If we found a node (even if not exact match), splay it
202+
if current:
203+
self._splay(current)
204+
return None
205+
206+
def search(self, value: Any) -> bool:
207+
"""Search for a value in the tree. Returns True if found."""
208+
return self._find_node(value) is not None
209+
210+
def insert(self, value: Any) -> Self:
211+
"""Insert a value into the splay tree."""
212+
if not self.root:
213+
self.root = SplayNode(value)
214+
return self
215+
216+
# Find the insertion point
217+
current = self.root
218+
while True:
219+
if value < current.value:
220+
if current.left:
221+
current = current.left
222+
else:
223+
current.left = SplayNode(value, parent=current)
224+
self._splay(current.left)
225+
break
226+
elif value > current.value:
227+
if current.right:
228+
current = current.right
229+
else:
230+
current.right = SplayNode(value, parent=current)
231+
self._splay(current.right)
232+
break
233+
else:
234+
# Value already exists, splay the existing node
235+
self._splay(current)
236+
break
237+
238+
return self
239+
240+
def delete(self, value: Any) -> Self:
241+
"""Delete a value from the splay tree."""
242+
node = self._find_node(value)
243+
if not node:
244+
return self
245+
246+
# Node to delete is now at root
247+
left_tree = node.left
248+
right_tree = node.right
249+
250+
# Remove the root
251+
if left_tree:
252+
left_tree.parent = None
253+
if right_tree:
254+
right_tree.parent = None
255+
256+
# If no left subtree, right becomes root
257+
if not left_tree:
258+
self.root = right_tree
259+
return self
260+
261+
# Find the maximum in left subtree
262+
max_left = left_tree
263+
while max_left.right:
264+
max_left = max_left.right
265+
266+
# Splay the maximum to root of left subtree
267+
self._splay(max_left)
268+
269+
# Attach right subtree to the new root
270+
max_left.right = right_tree
271+
if right_tree:
272+
right_tree.parent = max_left
273+
274+
self.root = max_left
275+
return self
276+
277+
def inorder(self) -> Iterator[Any]:
278+
"""Return an inorder iterator."""
279+
yield from self
280+
281+
def preorder(self) -> Iterator[Any]:
282+
"""Return a preorder iterator."""
283+
def _preorder(node: SplayNode | None) -> Iterator[Any]:
284+
if node:
285+
yield node.value
286+
yield from _preorder(node.left)
287+
yield from _preorder(node.right)
288+
yield from _preorder(self.root)
289+
290+
def postorder(self) -> Iterator[Any]:
291+
"""Return a postorder iterator."""
292+
def _postorder(node: SplayNode | None) -> Iterator[Any]:
293+
if node:
294+
yield from _postorder(node.left)
295+
yield from _postorder(node.right)
296+
yield node.value
297+
yield from _postorder(self.root)
298+
299+
def get_min(self) -> Any:
300+
"""Get the minimum value in the tree."""
301+
if not self.root:
302+
raise ValueError("Tree is empty")
303+
current = self.root
304+
while current.left:
305+
current = current.left
306+
self._splay(current)
307+
return current.value
308+
309+
def get_max(self) -> Any:
310+
"""Get the maximum value in the tree."""
311+
if not self.root:
312+
raise ValueError("Tree is empty")
313+
current = self.root
314+
while current.right:
315+
current = current.right
316+
self._splay(current)
317+
return current.value
318+
319+
def is_empty(self) -> bool:
320+
"""Check if the tree is empty."""
321+
return self.root is None
322+
323+
def clear(self) -> Self:
324+
"""Clear the tree."""
325+
self.root = None
326+
return self

0 commit comments

Comments
 (0)