1- """
2- Auto-balanced binary tree implementation
3- For doctests: python3 -m doctest -v avl_tree.py
4- For testing: python avl_tree.py
5- """
6-
71from __future__ import annotations
82
9- import math
103import random
11- import doctest
124from typing import Any
135
14-
156class MyQueue :
167 def __init__ (self ) -> None :
178 self .data : list [Any ] = []
@@ -25,204 +16,176 @@ def push(self, data: Any) -> None:
2516 self .tail += 1
2617
2718 def pop (self ) -> Any :
28- ret = self .data [self .head ]
29- self .head += 1
30- return ret
31-
19+ return self .data [self .head ]
3220
3321class MyNode :
22+ __slots__ = ("data" , "left" , "right" , "height" )
23+
3424 def __init__ (self , data : Any ) -> None :
3525 self .data = data
36- self .left = self .right = None
26+ self .left : MyNode | None = None
27+ self .right : MyNode | None = None
3728 self .height = 1
3829
39- def get_data (self ) -> Any :
40- return self .data
41-
42- def get_left (self ) -> MyNode | None :
43- return self .left
44-
45- def get_right (self ) -> MyNode | None :
46- return self .right
47-
48- def get_height (self ) -> int :
49- return self .height
50-
51- def set_data (self , data : Any ) -> None :
52- self .data = data
53-
54- def set_left (self , node : MyNode | None ) -> None :
55- self .left = node
56-
57- def set_right (self , node : MyNode | None ) -> None :
58- self .right = node
59-
60- def set_height (self , height : int ) -> None :
61- self .height = height
62-
63-
64- def get_height (node : MyNode | None ) -> int :
30+ def get_height (node : MyNode | None ) -> int :
6531 return node .height if node else 0
6632
67-
68- def my_max (a : int , b : int ) -> int :
33+ def my_max (a : int , b : int ) -> int :
6934 return a if a > b else b
7035
71-
7236def right_rotation (node : MyNode ) -> MyNode :
73- print ("left rotation node:" , node .data )
74- ret = node .left
75- node .left = ret .right
76- ret .right = node
37+ left_child = node .left
38+ if left_child is None :
39+ return node
40+
41+ node .left = left_child .right
42+ left_child .right = node
7743 node .height = my_max (get_height (node .right ), get_height (node .left )) + 1
78- ret .height = my_max (get_height (ret .right ), get_height (ret .left )) + 1
79- return ret
80-
44+ left_child .height = my_max (get_height (left_child .right ), get_height (left_child .left )) + 1
45+ return left_child
8146
8247def left_rotation (node : MyNode ) -> MyNode :
83- print ("right rotation node:" , node .data )
84- ret = node .right
85- node .right = ret .left
86- ret .left = node
48+ right_child = node .right
49+ if right_child is None :
50+ return node
51+
52+ node .right = right_child .left
53+ right_child .left = node
8754 node .height = my_max (get_height (node .right ), get_height (node .left )) + 1
88- ret .height = my_max (get_height (ret .right ), get_height (ret .left )) + 1
89- return ret
90-
55+ right_child .height = my_max (get_height (right_child .right ), get_height (right_child .left )) + 1
56+ return right_child
9157
9258def lr_rotation (node : MyNode ) -> MyNode :
93- node .left = left_rotation (node .left )
59+ if node .left :
60+ node .left = left_rotation (node .left )
9461 return right_rotation (node )
9562
96-
9763def rl_rotation (node : MyNode ) -> MyNode :
98- node .right = right_rotation (node .right )
64+ if node .right :
65+ node .right = right_rotation (node .right )
9966 return left_rotation (node )
10067
101-
10268def insert_node (node : MyNode | None , data : Any ) -> MyNode | None :
103- if not node :
69+ if node is None :
10470 return MyNode (data )
105-
71+
10672 if data < node .data :
10773 node .left = insert_node (node .left , data )
10874 if get_height (node .left ) - get_height (node .right ) == 2 :
109- if data < node .left .data :
75+ if node . left and data < node .left .data :
11076 node = right_rotation (node )
11177 else :
11278 node = lr_rotation (node )
11379 else :
11480 node .right = insert_node (node .right , data )
11581 if get_height (node .right ) - get_height (node .left ) == 2 :
116- if data < node .right .data :
82+ if node . right and data < node .right .data :
11783 node = rl_rotation (node )
11884 else :
11985 node = left_rotation (node )
120-
86+
12187 node .height = my_max (get_height (node .right ), get_height (node .left )) + 1
12288 return node
12389
124-
125- def get_extreme (root : MyNode , is_right : bool ) -> Any :
126- while child := root .right if is_right else root .left :
127- root = child
90+ def get_left_most (root : MyNode ) -> Any :
91+ while root .left :
92+ root = root .left
12893 return root .data
12994
130-
131- def del_node (root : MyNode , data : Any ) -> MyNode | None :
132- if root .data == data :
95+ def del_node (root : MyNode | None , data : Any ) -> MyNode | None :
96+ if root is None :
97+ return None
98+
99+ if data == root .data :
133100 if root .left and root .right :
134- root .data = get_extreme (root .right , False )
101+ root .data = get_left_most (root .right )
135102 root .right = del_node (root .right , root .data )
136103 else :
137104 return root .left or root .right
138- elif root .data > data :
139- if not root .left :
140- return root
105+ elif data < root .data :
141106 root .left = del_node (root .left , data )
142107 else :
143108 root .right = del_node (root .right , data )
144-
145- # Handle balancing
146- right_height = get_height (root .right )
109+
110+ if root .left is None and root .right is None :
111+ root .height = 1
112+ return root
113+
147114 left_height = get_height (root .left )
148-
115+ right_height = get_height (root .right )
116+
149117 if right_height - left_height == 2 :
150- if get_height (root .right .right ) > get_height (root .right .left ):
118+ right_right = get_height (root .right .right ) if root .right else 0
119+ right_left = get_height (root .right .left ) if root .right else 0
120+ if right_right > right_left :
151121 root = left_rotation (root )
152122 else :
153123 root = rl_rotation (root )
154- elif right_height - left_height == - 2 :
155- if get_height (root .left .left ) > get_height (root .left .right ):
124+ elif left_height - right_height == 2 :
125+ left_left = get_height (root .left .left ) if root .left else 0
126+ left_right = get_height (root .left .right ) if root .left else 0
127+ if left_left > left_right :
156128 root = right_rotation (root )
157129 else :
158130 root = lr_rotation (root )
159-
131+
160132 root .height = my_max (get_height (root .right ), get_height (root .left )) + 1
161133 return root
162-
163-
164134class AVLtree :
165- def __init__ (self ) -> None :
166- self .root = None
167-
168- def get_height (self ) -> int :
135+ __slots__ = ("root" ,)
136+
137+ def __init__ (self ) -> None :
138+ self .root : MyNode | None = None
139+
140+ def get_height (self ) -> int :
169141 return get_height (self .root )
170-
142+
171143 def insert (self , data : Any ) -> None :
172- print (f"insert:{ data } " )
173144 self .root = insert_node (self .root , data )
174-
145+
175146 def del_node (self , data : Any ) -> None :
176- print (f"delete:{ data } " )
177- if not self .root :
178- return
179147 self .root = del_node (self .root , data )
180-
148+
181149 def __str__ (self ) -> str :
182- if not self .root :
150+ if self .root is None :
183151 return ""
184- q , output , layer , cnt = MyQueue (), "" , self .get_height (), 0
185- q .push (self .root )
186-
187- while not q .is_empty ():
188- node = q .pop ()
189- space = " " * int (2 ** (layer - 1 ))
190- output += space + (str (node .data ) if node else "*" ) + space
191- cnt += 1
192-
193- if node :
194- q .push (node .left )
195- q .push (node .right )
152+
153+ levels = []
154+ queue = [self .root ]
155+
156+ while queue :
157+ current = []
158+ next_level = []
159+
160+ for node in queue :
161+ if node :
162+ current .append (str (node .data ))
163+ next_level .append (node .left )
164+ next_level .append (node .right )
165+ else :
166+ current .append ("*" )
167+ next_level .append (None )
168+ next_level .append (None )
169+
170+ if any (node is not None for node in next_level ):
171+ levels .append (" " .join (current ))
172+ queue = next_level
196173 else :
197- q .push (None )
198- q .push (None )
199-
200- for i in range (10 ):
201- if cnt == 2 ** i - 1 :
202- layer -= 1
203- output += "\n "
204- if layer == 0 :
205- break
206- break
207-
208- return output + "\n " + "*" * 36
209-
210-
211- def _test () -> None :
212- doctest .testmod ()
213-
174+ break
175+
176+ return "\n " .join (levels ) + "\n " + "*" * 36
214177
215- if __name__ == "__main__" :
216- _test ()
178+ def test_avl_tree () -> None :
217179 t = AVLtree ()
218180 lst = list (range (10 ))
219181 random .shuffle (lst )
220-
182+
221183 for i in lst :
222184 t .insert (i )
223- print (t )
224-
185+
225186 random .shuffle (lst )
226187 for i in lst :
227188 t .del_node (i )
228- print (t )
189+
190+ if __name__ == "__main__" :
191+ test_avl_tree ()
0 commit comments