Skip to content

Commit b002fc9

Browse files
authored
Update avl_tree.py
1 parent 5ea1b4a commit b002fc9

1 file changed

Lines changed: 50 additions & 67 deletions

File tree

data_structures/binary_tree/avl_tree.py

Lines changed: 50 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import random
44
from typing import Any
55

6-
76
class MyQueue:
7+
__slots__ = ("data", "head", "tail")
8+
89
def __init__(self) -> None:
910
self.data: list[Any] = []
1011
self.head = self.tail = 0
@@ -17,71 +18,61 @@ def push(self, data: Any) -> None:
1718
self.tail += 1
1819

1920
def pop(self) -> Any:
20-
return self.data[self.head]
21-
21+
ret = self.data[self.head]
22+
self.head += 1
23+
return ret
2224

2325
class MyNode:
24-
__slots__ = ("data", "left", "right", "height")
25-
26+
__slots__ = ("data", "height", "left", "right") # 按字母顺序排序
27+
2628
def __init__(self, data: Any) -> None:
2729
self.data = data
30+
self.height = 1
2831
self.left: MyNode | None = None
2932
self.right: MyNode | None = None
30-
self.height = 1
31-
3233

33-
def get_height(node: MyNode | None) -> int:
34+
def get_height(node: MyNode | None) -> int:
3435
return node.height if node else 0
3536

36-
37-
def my_max(a: int, b: int) -> int:
37+
def my_max(a: int, b: int) -> int:
3838
return a if a > b else b
3939

40-
4140
def right_rotation(node: MyNode) -> MyNode:
4241
left_child = node.left
4342
if left_child is None:
4443
return node
45-
44+
4645
node.left = left_child.right
4746
left_child.right = node
4847
node.height = my_max(get_height(node.right), get_height(node.left)) + 1
49-
left_child.height = (
50-
my_max(get_height(left_child.right), get_height(left_child.left)) + 1
51-
)
48+
left_child.height = my_max(get_height(left_child.right), get_height(left_child.left)) + 1
5249
return left_child
5350

54-
5551
def left_rotation(node: MyNode) -> MyNode:
5652
right_child = node.right
5753
if right_child is None:
5854
return node
59-
55+
6056
node.right = right_child.left
6157
right_child.left = node
6258
node.height = my_max(get_height(node.right), get_height(node.left)) + 1
63-
right_child.height = (
64-
my_max(get_height(right_child.right), get_height(right_child.left)) + 1
65-
)
59+
right_child.height = my_max(get_height(right_child.right), get_height(right_child.left)) + 1
6660
return right_child
6761

68-
6962
def lr_rotation(node: MyNode) -> MyNode:
7063
if node.left:
7164
node.left = left_rotation(node.left)
7265
return right_rotation(node)
7366

74-
7567
def rl_rotation(node: MyNode) -> MyNode:
7668
if node.right:
7769
node.right = right_rotation(node.right)
7870
return left_rotation(node)
7971

80-
8172
def insert_node(node: MyNode | None, data: Any) -> MyNode | None:
8273
if node is None:
8374
return MyNode(data)
84-
75+
8576
if data < node.data:
8677
node.left = insert_node(node.left, data)
8778
if get_height(node.left) - get_height(node.right) == 2:
@@ -96,22 +87,19 @@ def insert_node(node: MyNode | None, data: Any) -> MyNode | None:
9687
node = rl_rotation(node)
9788
else:
9889
node = left_rotation(node)
99-
90+
10091
node.height = my_max(get_height(node.right), get_height(node.left)) + 1
10192
return node
10293

103-
10494
def get_left_most(root: MyNode) -> Any:
10595
while root.left:
10696
root = root.left
10797
return root.data
10898

109-
11099
def del_node(root: MyNode | None, data: Any) -> MyNode | None:
111100
if root is None:
112101
return None
113-
114-
if data == root.data:
102+
if data == root.data:
115103
if root.left and root.right:
116104
root.data = get_left_most(root.right)
117105
root.right = del_node(root.right, root.data)
@@ -121,90 +109,85 @@ def del_node(root: MyNode | None, data: Any) -> MyNode | None:
121109
root.left = del_node(root.left, data)
122110
else:
123111
root.right = del_node(root.right, data)
124-
112+
125113
if root.left is None and root.right is None:
126114
root.height = 1
127115
return root
128-
116+
129117
left_height = get_height(root.left)
130118
right_height = get_height(root.right)
131-
119+
132120
if right_height - left_height == 2:
133121
right_right = get_height(root.right.right) if root.right else 0
134122
right_left = get_height(root.right.left) if root.right else 0
135-
if right_right > right_left:
136-
root = left_rotation(root)
137-
else:
138-
root = rl_rotation(root)
123+
# 使用三元表达式
124+
root = left_rotation(root) if right_right > right_left else rl_rotation(root)
139125
elif left_height - right_height == 2:
140126
left_left = get_height(root.left.left) if root.left else 0
141127
left_right = get_height(root.left.right) if root.left else 0
142-
if left_left > left_right:
143-
root = right_rotation(root)
144-
else:
145-
root = lr_rotation(root)
146-
128+
# 使用三元表达式
129+
root = right_rotation(root) if left_left > left_right else lr_rotation(root)
130+
147131
root.height = my_max(get_height(root.right), get_height(root.left)) + 1
148132
return root
149133

150-
151-
class AVLtree:
134+
class AVLTree:
152135
__slots__ = ("root",)
153-
154-
def __init__(self) -> None:
136+
137+
def __init__(self) -> None:
155138
self.root: MyNode | None = None
156-
157-
def get_height(self) -> int:
139+
140+
def get_height(self) -> int:
158141
return get_height(self.root)
159-
142+
160143
def insert(self, data: Any) -> None:
161144
self.root = insert_node(self.root, data)
162-
163-
def del_node(self, data: Any) -> None:
145+
146+
def delete(self, data: Any) -> None:
164147
self.root = del_node(self.root, data)
165-
148+
166149
def __str__(self) -> str:
167150
if self.root is None:
168151
return ""
169-
152+
170153
levels = []
171-
queue = [self.root]
172-
154+
# 明确指定类型为 MyNode | None
155+
queue: list[MyNode | None] = [self.root]
156+
173157
while queue:
174158
current = []
175-
next_level = []
176-
159+
next_level: list[MyNode | None] = []
160+
177161
for node in queue:
178162
if node:
179163
current.append(str(node.data))
180164
next_level.append(node.left)
181165
next_level.append(node.right)
182166
else:
183167
current.append("*")
184-
next_level.append(None)
185-
next_level.append(None)
186-
168+
next_level.extend([None, None])
169+
187170
if any(node is not None for node in next_level):
188171
levels.append(" ".join(current))
189172
queue = next_level
190173
else:
174+
if current: # 添加最后一行
175+
levels.append(" ".join(current))
191176
break
192-
193-
return "\n".join(levels) + "\n" + "*" * 36
194-
177+
178+
return "\n".join(levels) + "\n" + "*"*36
195179

196180
def test_avl_tree() -> None:
197-
t = AVLtree()
181+
t = AVLTree()
198182
lst = list(range(10))
199183
random.shuffle(lst)
200-
184+
201185
for i in lst:
202186
t.insert(i)
203-
187+
204188
random.shuffle(lst)
205189
for i in lst:
206-
t.del_node(i)
207-
190+
t.delete(i)
208191

209192
if __name__ == "__main__":
210193
test_avl_tree()

0 commit comments

Comments
 (0)