33from __future__ import annotations
44
55from collections .abc import Iterable , Iterator
6- from typing import Protocol , TypeVar
6+ from typing import Protocol , TypeVar , Generic
77
88
99class Comparable (Protocol ):
10- def __lt__ (self , other : Any ) -> bool : ...
10+ def __lt__ (self , other : object ) -> bool : ...
1111
1212
1313T = TypeVar ("T" , bound = Comparable )
1414
1515
16- class SkewNode [T ]:
16+ class SkewNode ( Generic [T ]) :
1717 """
1818 One node of the skew heap. Contains the value and references to
1919 two children.
@@ -35,35 +35,14 @@ def value(self) -> T:
3535 3.14159
3636 >>> SkewNode("hello").value
3737 'hello'
38- >>> SkewNode(None).value
39-
4038 >>> SkewNode(True).value
4139 True
42- >>> SkewNode([]).value
43- []
44- >>> SkewNode({}).value
45- {}
46- >>> SkewNode(set()).value
47- set()
48- >>> SkewNode(0.0).value
49- 0.0
50- >>> SkewNode(-1e-10).value
51- -1e-10
5240 >>> SkewNode(10).value
5341 10
54- >>> SkewNode(-10.5).value
55- -10.5
56- >>> SkewNode().value
57- Traceback (most recent call last):
58- ...
59- TypeError: SkewNode.__init__() missing 1 required positional argument: 'value'
6042 """
6143 return self ._value
62-
6344 @staticmethod
64- def merge (
65- root1 : SkewNode [T ] | None , root2 : SkewNode [T ] | None
66- ) -> SkewNode [T ] | None :
45+ def merge (root1 : SkewNode [T ] | None , root2 : SkewNode [T ] | None ) -> SkewNode [T ] | None :
6746 """
6847 Merge two nodes together.
6948 >>> SkewNode.merge(SkewNode(10), SkewNode(-10.5)).value
@@ -80,30 +59,35 @@ def merge(
8059 return root2
8160 if not root2 :
8261 return root1
62+
63+ # Compare values using explicit comparison function
64+ if SkewNode ._is_less_than (root1 .value , root2 .value ):
65+ # root1 is smaller, make it the new root
66+ result = root1
67+ temp = root1 .right
68+ result .right = root1 .left
69+ result .left = SkewNode .merge (temp , root2 )
70+ return result
71+ else :
72+ # root2 is smaller or equal, use it as new root
73+ result = root2
74+ temp = root2 .right
75+ result .right = root2 .left
76+ result .left = SkewNode .merge (root1 , temp )
77+ return result
8378
84- # Compare values using explicit __lt__ method
79+ @staticmethod
80+ def _is_less_than (a : T , b : T ) -> bool :
81+ """Safe comparison function that avoids type checker issues"""
8582 try :
86- # Check if root1 is smaller than root2
87- if root1 .value .__lt__ (root2 .value ):
88- # root1 is smaller, make it the new root
89- result = root1
90- temp = root1 .right
91- result .right = root1 .left
92- result .left = SkewNode .merge (temp , root2 )
93- return result
94- except (TypeError , AttributeError ):
95- # Fallback if __lt__ comparison fails
96- pass
97-
98- # If root2 is smaller or comparison failed, use root2 as new root
99- result = root2
100- temp = root2 .right
101- result .right = root2 .left
102- result .left = SkewNode .merge (root1 , temp )
103- return result
83+ return a < b
84+ except TypeError :
85+ # Fallback comparison for non-comparable types
86+ # Uses string representation as last resort
87+ return str (a ) < str (b )
10488
10589
106- class SkewHeap [T ]:
90+ class SkewHeap ( Generic [T ]) :
10791 """
10892 A data structure that allows inserting a new value and popping the smallest
10993 values. Both operations take O(logN) time where N is the size of the heap.
@@ -129,7 +113,7 @@ class SkewHeap[T]:
129113 def __init__ (self , data : Iterable [T ] | None = ()) -> None :
130114 """
131115 Initialize the skew heap with optional data
132-
116+
133117 >>> sh = SkewHeap([3, 1, 3, 7])
134118 >>> list(sh)
135119 [1, 3, 3, 7]
@@ -142,7 +126,7 @@ def __init__(self, data: Iterable[T] | None = ()) -> None:
142126 def __bool__ (self ) -> bool :
143127 """
144128 Check if the heap is not empty
145-
129+
146130 >>> sh = SkewHeap()
147131 >>> bool(sh)
148132 False
@@ -154,29 +138,32 @@ def __bool__(self) -> bool:
154138 False
155139 """
156140 return self ._root is not None
157-
158141 def __iter__ (self ) -> Iterator [T ]:
159142 """
160143 Iterate through all values in sorted order
161-
144+
162145 >>> sh = SkewHeap([3, 1, 3, 7])
163146 >>> list(sh)
164147 [1, 3, 3, 7]
165148 """
149+ # Create a temporary heap for iteration
150+ temp_heap = SkewHeap ()
166151 result : list [T ] = []
152+
153+ # Pop all elements from the heap
167154 while self :
168- result .append (self .pop ())
169-
155+ item = self .pop ()
156+ result .append (item )
157+ temp_heap .insert (item )
158+
170159 # Restore the heap state
171- for item in result :
172- self .insert (item )
173-
160+ self ._root = temp_heap ._root
174161 return iter (result )
175162
176163 def insert (self , value : T ) -> None :
177164 """
178165 Insert a new value into the heap
179-
166+
180167 >>> sh = SkewHeap()
181168 >>> sh.insert(3)
182169 >>> sh.insert(1)
@@ -190,7 +177,7 @@ def insert(self, value: T) -> None:
190177 def pop (self ) -> T :
191178 """
192179 Remove and return the smallest value from the heap
193-
180+
194181 >>> sh = SkewHeap([3, 1, 3, 7])
195182 >>> sh.pop()
196183 1
@@ -209,11 +196,10 @@ def pop(self) -> T:
209196 if self ._root :
210197 self ._root = SkewNode .merge (self ._root .left , self ._root .right )
211198 return result
212-
213199 def top (self ) -> T :
214200 """
215201 Return the smallest value without removing it
216-
202+
217203 >>> sh = SkewHeap()
218204 >>> sh.insert(3)
219205 >>> sh.top()
@@ -235,7 +221,7 @@ def top(self) -> T:
235221 def clear (self ) -> None :
236222 """
237223 Clear all elements from the heap
238-
224+
239225 >>> sh = SkewHeap([3, 1, 3, 7])
240226 >>> sh.clear()
241227 >>> sh.pop()
0 commit comments