Skip to content

Commit 1ac8d5b

Browse files
authored
Update avl_tree.py
1 parent 0196e75 commit 1ac8d5b

1 file changed

Lines changed: 97 additions & 134 deletions

File tree

Lines changed: 97 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,8 @@
1-
"""
2-
Auto-balanced binary tree implementation
3-
For doctests: python3 -m doctest -v avl_tree.py
4-
For testing: python avl_tree.py
5-
"""
6-
71
from __future__ import annotations
82

9-
import math
103
import random
11-
import doctest
124
from typing import Any
135

14-
156
class MyQueue:
167
def __init__(self) -> None:
178
self.data: list[Any] = []
@@ -25,204 +16,176 @@ def push(self, data: Any) -> None:
2516
self.tail += 1
2617

2718
def pop(self) -> Any:
28-
ret = self.data[self.head]
29-
self.head += 1
30-
return ret
31-
19+
return self.data[self.head]
3220

3321
class MyNode:
22+
__slots__ = ("data", "left", "right", "height")
23+
3424
def __init__(self, data: Any) -> None:
3525
self.data = data
36-
self.left = self.right = None
26+
self.left: MyNode | None = None
27+
self.right: MyNode | None = None
3728
self.height = 1
3829

39-
def get_data(self) -> Any:
40-
return self.data
41-
42-
def get_left(self) -> MyNode | None:
43-
return self.left
44-
45-
def get_right(self) -> MyNode | None:
46-
return self.right
47-
48-
def get_height(self) -> int:
49-
return self.height
50-
51-
def set_data(self, data: Any) -> None:
52-
self.data = data
53-
54-
def set_left(self, node: MyNode | None) -> None:
55-
self.left = node
56-
57-
def set_right(self, node: MyNode | None) -> None:
58-
self.right = node
59-
60-
def set_height(self, height: int) -> None:
61-
self.height = height
62-
63-
64-
def get_height(node: MyNode | None) -> int:
30+
def get_height(node: MyNode | None) -> int:
6531
return node.height if node else 0
6632

67-
68-
def my_max(a: int, b: int) -> int:
33+
def my_max(a: int, b: int) -> int:
6934
return a if a > b else b
7035

71-
7236
def right_rotation(node: MyNode) -> MyNode:
73-
print("left rotation node:", node.data)
74-
ret = node.left
75-
node.left = ret.right
76-
ret.right = node
37+
left_child = node.left
38+
if left_child is None:
39+
return node
40+
41+
node.left = left_child.right
42+
left_child.right = node
7743
node.height = my_max(get_height(node.right), get_height(node.left)) + 1
78-
ret.height = my_max(get_height(ret.right), get_height(ret.left)) + 1
79-
return ret
80-
44+
left_child.height = my_max(get_height(left_child.right), get_height(left_child.left)) + 1
45+
return left_child
8146

8247
def left_rotation(node: MyNode) -> MyNode:
83-
print("right rotation node:", node.data)
84-
ret = node.right
85-
node.right = ret.left
86-
ret.left = node
48+
right_child = node.right
49+
if right_child is None:
50+
return node
51+
52+
node.right = right_child.left
53+
right_child.left = node
8754
node.height = my_max(get_height(node.right), get_height(node.left)) + 1
88-
ret.height = my_max(get_height(ret.right), get_height(ret.left)) + 1
89-
return ret
90-
55+
right_child.height = my_max(get_height(right_child.right), get_height(right_child.left)) + 1
56+
return right_child
9157

9258
def lr_rotation(node: MyNode) -> MyNode:
93-
node.left = left_rotation(node.left)
59+
if node.left:
60+
node.left = left_rotation(node.left)
9461
return right_rotation(node)
9562

96-
9763
def rl_rotation(node: MyNode) -> MyNode:
98-
node.right = right_rotation(node.right)
64+
if node.right:
65+
node.right = right_rotation(node.right)
9966
return left_rotation(node)
10067

101-
10268
def insert_node(node: MyNode | None, data: Any) -> MyNode | None:
103-
if not node:
69+
if node is None:
10470
return MyNode(data)
105-
71+
10672
if data < node.data:
10773
node.left = insert_node(node.left, data)
10874
if get_height(node.left) - get_height(node.right) == 2:
109-
if data < node.left.data:
75+
if node.left and data < node.left.data:
11076
node = right_rotation(node)
11177
else:
11278
node = lr_rotation(node)
11379
else:
11480
node.right = insert_node(node.right, data)
11581
if get_height(node.right) - get_height(node.left) == 2:
116-
if data < node.right.data:
82+
if node.right and data < node.right.data:
11783
node = rl_rotation(node)
11884
else:
11985
node = left_rotation(node)
120-
86+
12187
node.height = my_max(get_height(node.right), get_height(node.left)) + 1
12288
return node
12389

124-
125-
def get_extreme(root: MyNode, is_right: bool) -> Any:
126-
while child := root.right if is_right else root.left:
127-
root = child
90+
def get_left_most(root: MyNode) -> Any:
91+
while root.left:
92+
root = root.left
12893
return root.data
12994

130-
131-
def del_node(root: MyNode, data: Any) -> MyNode | None:
132-
if root.data == data:
95+
def del_node(root: MyNode | None, data: Any) -> MyNode | None:
96+
if root is None:
97+
return None
98+
99+
if data == root.data:
133100
if root.left and root.right:
134-
root.data = get_extreme(root.right, False)
101+
root.data = get_left_most(root.right)
135102
root.right = del_node(root.right, root.data)
136103
else:
137104
return root.left or root.right
138-
elif root.data > data:
139-
if not root.left:
140-
return root
105+
elif data < root.data:
141106
root.left = del_node(root.left, data)
142107
else:
143108
root.right = del_node(root.right, data)
144-
145-
# Handle balancing
146-
right_height = get_height(root.right)
109+
110+
if root.left is None and root.right is None:
111+
root.height = 1
112+
return root
113+
147114
left_height = get_height(root.left)
148-
115+
right_height = get_height(root.right)
116+
149117
if right_height - left_height == 2:
150-
if get_height(root.right.right) > get_height(root.right.left):
118+
right_right = get_height(root.right.right) if root.right else 0
119+
right_left = get_height(root.right.left) if root.right else 0
120+
if right_right > right_left:
151121
root = left_rotation(root)
152122
else:
153123
root = rl_rotation(root)
154-
elif right_height - left_height == -2:
155-
if get_height(root.left.left) > get_height(root.left.right):
124+
elif left_height - right_height == 2:
125+
left_left = get_height(root.left.left) if root.left else 0
126+
left_right = get_height(root.left.right) if root.left else 0
127+
if left_left > left_right:
156128
root = right_rotation(root)
157129
else:
158130
root = lr_rotation(root)
159-
131+
160132
root.height = my_max(get_height(root.right), get_height(root.left)) + 1
161133
return root
162-
163-
164134
class AVLtree:
165-
def __init__(self) -> None:
166-
self.root = None
167-
168-
def get_height(self) -> int:
135+
__slots__ = ("root",)
136+
137+
def __init__(self) -> None:
138+
self.root: MyNode | None = None
139+
140+
def get_height(self) -> int:
169141
return get_height(self.root)
170-
142+
171143
def insert(self, data: Any) -> None:
172-
print(f"insert:{data}")
173144
self.root = insert_node(self.root, data)
174-
145+
175146
def del_node(self, data: Any) -> None:
176-
print(f"delete:{data}")
177-
if not self.root:
178-
return
179147
self.root = del_node(self.root, data)
180-
148+
181149
def __str__(self) -> str:
182-
if not self.root:
150+
if self.root is None:
183151
return ""
184-
q, output, layer, cnt = MyQueue(), "", self.get_height(), 0
185-
q.push(self.root)
186-
187-
while not q.is_empty():
188-
node = q.pop()
189-
space = " " * int(2 ** (layer - 1))
190-
output += space + (str(node.data) if node else "*") + space
191-
cnt += 1
192-
193-
if node:
194-
q.push(node.left)
195-
q.push(node.right)
152+
153+
levels = []
154+
queue = [self.root]
155+
156+
while queue:
157+
current = []
158+
next_level = []
159+
160+
for node in queue:
161+
if node:
162+
current.append(str(node.data))
163+
next_level.append(node.left)
164+
next_level.append(node.right)
165+
else:
166+
current.append("*")
167+
next_level.append(None)
168+
next_level.append(None)
169+
170+
if any(node is not None for node in next_level):
171+
levels.append(" ".join(current))
172+
queue = next_level
196173
else:
197-
q.push(None)
198-
q.push(None)
199-
200-
for i in range(10):
201-
if cnt == 2**i - 1:
202-
layer -= 1
203-
output += "\n"
204-
if layer == 0:
205-
break
206-
break
207-
208-
return output + "\n" + "*" * 36
209-
210-
211-
def _test() -> None:
212-
doctest.testmod()
213-
174+
break
175+
176+
return "\n".join(levels) + "\n" + "*"*36
214177

215-
if __name__ == "__main__":
216-
_test()
178+
def test_avl_tree() -> None:
217179
t = AVLtree()
218180
lst = list(range(10))
219181
random.shuffle(lst)
220-
182+
221183
for i in lst:
222184
t.insert(i)
223-
print(t)
224-
185+
225186
random.shuffle(lst)
226187
for i in lst:
227188
t.del_node(i)
228-
print(t)
189+
190+
if __name__ == "__main__":
191+
test_avl_tree()

0 commit comments

Comments
 (0)