Skip to content

Commit 180d8d5

Browse files
authored
Update skew_heap.py
1 parent 4f0910c commit 180d8d5

1 file changed

Lines changed: 45 additions & 59 deletions

File tree

data_structures/heap/skew_heap.py

Lines changed: 45 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,17 @@
33
from __future__ import annotations
44

55
from collections.abc import Iterable, Iterator
6-
from typing import Protocol, TypeVar
6+
from typing import Protocol, TypeVar, Generic
77

88

99
class Comparable(Protocol):
10-
def __lt__(self, other: Any) -> bool: ...
10+
def __lt__(self, other: object) -> bool: ...
1111

1212

1313
T = TypeVar("T", bound=Comparable)
1414

1515

16-
class SkewNode[T]:
16+
class SkewNode(Generic[T]):
1717
"""
1818
One node of the skew heap. Contains the value and references to
1919
two children.
@@ -35,35 +35,14 @@ def value(self) -> T:
3535
3.14159
3636
>>> SkewNode("hello").value
3737
'hello'
38-
>>> SkewNode(None).value
39-
4038
>>> SkewNode(True).value
4139
True
42-
>>> SkewNode([]).value
43-
[]
44-
>>> SkewNode({}).value
45-
{}
46-
>>> SkewNode(set()).value
47-
set()
48-
>>> SkewNode(0.0).value
49-
0.0
50-
>>> SkewNode(-1e-10).value
51-
-1e-10
5240
>>> SkewNode(10).value
5341
10
54-
>>> SkewNode(-10.5).value
55-
-10.5
56-
>>> SkewNode().value
57-
Traceback (most recent call last):
58-
...
59-
TypeError: SkewNode.__init__() missing 1 required positional argument: 'value'
6042
"""
6143
return self._value
62-
6344
@staticmethod
64-
def merge(
65-
root1: SkewNode[T] | None, root2: SkewNode[T] | None
66-
) -> SkewNode[T] | None:
45+
def merge(root1: SkewNode[T] | None, root2: SkewNode[T] | None) -> SkewNode[T] | None:
6746
"""
6847
Merge two nodes together.
6948
>>> SkewNode.merge(SkewNode(10), SkewNode(-10.5)).value
@@ -80,30 +59,35 @@ def merge(
8059
return root2
8160
if not root2:
8261
return root1
62+
63+
# Compare values using explicit comparison function
64+
if SkewNode._is_less_than(root1.value, root2.value):
65+
# root1 is smaller, make it the new root
66+
result = root1
67+
temp = root1.right
68+
result.right = root1.left
69+
result.left = SkewNode.merge(temp, root2)
70+
return result
71+
else:
72+
# root2 is smaller or equal, use it as new root
73+
result = root2
74+
temp = root2.right
75+
result.right = root2.left
76+
result.left = SkewNode.merge(root1, temp)
77+
return result
8378

84-
# Compare values using explicit __lt__ method
79+
@staticmethod
80+
def _is_less_than(a: T, b: T) -> bool:
81+
"""Safe comparison function that avoids type checker issues"""
8582
try:
86-
# Check if root1 is smaller than root2
87-
if root1.value.__lt__(root2.value):
88-
# root1 is smaller, make it the new root
89-
result = root1
90-
temp = root1.right
91-
result.right = root1.left
92-
result.left = SkewNode.merge(temp, root2)
93-
return result
94-
except (TypeError, AttributeError):
95-
# Fallback if __lt__ comparison fails
96-
pass
97-
98-
# If root2 is smaller or comparison failed, use root2 as new root
99-
result = root2
100-
temp = root2.right
101-
result.right = root2.left
102-
result.left = SkewNode.merge(root1, temp)
103-
return result
83+
return a < b
84+
except TypeError:
85+
# Fallback comparison for non-comparable types
86+
# Uses string representation as last resort
87+
return str(a) < str(b)
10488

10589

106-
class SkewHeap[T]:
90+
class SkewHeap(Generic[T]):
10791
"""
10892
A data structure that allows inserting a new value and popping the smallest
10993
values. Both operations take O(logN) time where N is the size of the heap.
@@ -129,7 +113,7 @@ class SkewHeap[T]:
129113
def __init__(self, data: Iterable[T] | None = ()) -> None:
130114
"""
131115
Initialize the skew heap with optional data
132-
116+
133117
>>> sh = SkewHeap([3, 1, 3, 7])
134118
>>> list(sh)
135119
[1, 3, 3, 7]
@@ -142,7 +126,7 @@ def __init__(self, data: Iterable[T] | None = ()) -> None:
142126
def __bool__(self) -> bool:
143127
"""
144128
Check if the heap is not empty
145-
129+
146130
>>> sh = SkewHeap()
147131
>>> bool(sh)
148132
False
@@ -154,29 +138,32 @@ def __bool__(self) -> bool:
154138
False
155139
"""
156140
return self._root is not None
157-
158141
def __iter__(self) -> Iterator[T]:
159142
"""
160143
Iterate through all values in sorted order
161-
144+
162145
>>> sh = SkewHeap([3, 1, 3, 7])
163146
>>> list(sh)
164147
[1, 3, 3, 7]
165148
"""
149+
# Create a temporary heap for iteration
150+
temp_heap = SkewHeap()
166151
result: list[T] = []
152+
153+
# Pop all elements from the heap
167154
while self:
168-
result.append(self.pop())
169-
155+
item = self.pop()
156+
result.append(item)
157+
temp_heap.insert(item)
158+
170159
# Restore the heap state
171-
for item in result:
172-
self.insert(item)
173-
160+
self._root = temp_heap._root
174161
return iter(result)
175162

176163
def insert(self, value: T) -> None:
177164
"""
178165
Insert a new value into the heap
179-
166+
180167
>>> sh = SkewHeap()
181168
>>> sh.insert(3)
182169
>>> sh.insert(1)
@@ -190,7 +177,7 @@ def insert(self, value: T) -> None:
190177
def pop(self) -> T:
191178
"""
192179
Remove and return the smallest value from the heap
193-
180+
194181
>>> sh = SkewHeap([3, 1, 3, 7])
195182
>>> sh.pop()
196183
1
@@ -209,11 +196,10 @@ def pop(self) -> T:
209196
if self._root:
210197
self._root = SkewNode.merge(self._root.left, self._root.right)
211198
return result
212-
213199
def top(self) -> T:
214200
"""
215201
Return the smallest value without removing it
216-
202+
217203
>>> sh = SkewHeap()
218204
>>> sh.insert(3)
219205
>>> sh.top()
@@ -235,7 +221,7 @@ def top(self) -> T:
235221
def clear(self) -> None:
236222
"""
237223
Clear all elements from the heap
238-
224+
239225
>>> sh = SkewHeap([3, 1, 3, 7])
240226
>>> sh.clear()
241227
>>> sh.pop()

0 commit comments

Comments
 (0)