Skip to content

Commit 4c6bd9d

Browse files
committed
Refactor power_sort.py to improve code clarity and organization. Updated import statements, removed unnecessary whitespace, and enhanced comments for better readability. Adjusted key function handling for reverse sorting and ensured proper handling of numeric and non-numeric types.
1 parent 6c46eaf commit 4c6bd9d

1 file changed

Lines changed: 86 additions & 72 deletions

File tree

sorts/power_sort.py

Lines changed: 86 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -27,27 +27,28 @@
2727

2828
from __future__ import annotations
2929

30-
from typing import Any, Callable
30+
from collections.abc import Callable
31+
from typing import Any
3132

3233

3334
def _find_run(
3435
arr: list, start: int, end: int, key: Callable[[Any], Any] | None = None
3536
) -> int:
3637
"""
3738
Detect a run (ascending or descending sequence) starting at 'start'.
38-
39+
3940
If the run is descending, reverse it in-place to make it ascending.
4041
Returns the end index (exclusive) of the detected run.
41-
42+
4243
Args:
4344
arr: The list to search in
4445
start: Starting index of the run
4546
end: End index (exclusive) of the search range
4647
key: Optional key function for comparisons
47-
48+
4849
Returns:
4950
End index (exclusive) of the detected run
50-
51+
5152
>>> arr = [3, 2, 1, 4, 5, 6]
5253
>>> _find_run(arr, 0, 6)
5354
3
@@ -61,10 +62,10 @@ def _find_run(
6162
"""
6263
if start >= end - 1:
6364
return start + 1
64-
65+
6566
key_func = key if key else lambda x: x
6667
run_end = start + 1
67-
68+
6869
# Check if run is ascending or descending
6970
if key_func(arr[run_end]) < key_func(arr[start]):
7071
# Descending run
@@ -76,48 +77,51 @@ def _find_run(
7677
# Ascending run
7778
while run_end < end and key_func(arr[run_end]) >= key_func(arr[run_end - 1]):
7879
run_end += 1
79-
80+
8081
return run_end
8182

8283

8384
def _node_power(n: int, b1: int, n1: int, b2: int, n2: int) -> int:
8485
"""
8586
Calculate the node power for two adjacent runs.
86-
87+
8788
This determines the merge priority in the stack. The power is the smallest
8889
integer p such that floor(a * 2^p) != floor(b * 2^p), where:
8990
- a = (b1 + n1/2) / n
9091
- b = (b2 + n2/2) / n
91-
92+
9293
Args:
9394
n: Total length of the array
9495
b1: Start index of first run
9596
n1: Length of first run
9697
b2: Start index of second run
9798
n2: Length of second run
98-
99+
99100
Returns:
100101
The calculated node power
101-
102+
102103
>>> _node_power(100, 0, 25, 25, 25)
103104
2
104105
>>> _node_power(100, 0, 50, 50, 50)
105106
1
106107
"""
107108
# Calculate midpoints: a = (b1 + n1/2) / n, b = (b2 + n2/2) / n
108-
# To avoid floating point, we work with a = (2*b1 + n1) / (2*n) and b = (2*b2 + n2) / (2*n)
109+
# To avoid floating point, we work with a = (2*b1 + n1) / (2*n) and
110+
# b = (2*b2 + n2) / (2*n)
109111
# We want smallest p where floor(a * 2^p) != floor(b * 2^p)
110-
# This is floor((2*b1 + n1) * 2^p / (2*n)) != floor((2*b2 + n2) * 2^p / (2*n))
111-
112+
# This is floor((2*b1 + n1) * 2^p / (2*n)) !=
113+
# floor((2*b2 + n2) * 2^p / (2*n))
114+
112115
a = 2 * b1 + n1
113116
b = 2 * b2 + n2
114117
two_n = 2 * n
115-
116-
# Find smallest power p where floor(a * 2^p / two_n) != floor(b * 2^p / two_n)
118+
119+
# Find smallest power p where floor(a * 2^p / two_n) !=
120+
# floor(b * 2^p / two_n)
117121
power = 0
118122
while (a * (1 << power)) // two_n == (b * (1 << power)) // two_n:
119123
power += 1
120-
124+
121125
return power
122126

123127

@@ -130,16 +134,16 @@ def _merge(
130134
) -> None:
131135
"""
132136
Merge two adjacent sorted runs in-place using auxiliary space.
133-
137+
134138
Merges arr[start1:end1] with arr[end1:end2].
135-
139+
136140
Args:
137141
arr: The list containing the runs
138142
start1: Start index of first run
139143
end1: End index of first run (start of second run)
140144
end2: End index of second run
141145
key: Optional key function for comparisons
142-
146+
143147
>>> arr = [1, 3, 5, 2, 4, 6]
144148
>>> _merge(arr, 0, 3, 6)
145149
>>> arr
@@ -150,14 +154,14 @@ def _merge(
150154
[1, 2, 3, 5, 6, 7]
151155
"""
152156
key_func = key if key else lambda x: x
153-
157+
154158
# Copy the runs to temporary storage
155159
left = arr[start1:end1]
156160
right = arr[end1:end2]
157-
161+
158162
i = j = 0
159163
k = start1
160-
164+
161165
# Merge the two runs
162166
while i < len(left) and j < len(right):
163167
if key_func(left[i]) <= key_func(right[j]):
@@ -167,13 +171,13 @@ def _merge(
167171
arr[k] = right[j]
168172
j += 1
169173
k += 1
170-
174+
171175
# Copy remaining elements
172176
while i < len(left):
173177
arr[k] = left[i]
174178
i += 1
175179
k += 1
176-
180+
177181
while j < len(right):
178182
arr[k] = right[j]
179183
j += 1
@@ -188,21 +192,21 @@ def power_sort(
188192
) -> list:
189193
"""
190194
Sort a list using the PowerSort algorithm.
191-
195+
192196
PowerSort is an adaptive merge sort that detects existing runs in the data
193197
and uses a power-based merging strategy for optimal performance.
194-
198+
195199
Args:
196200
collection: A mutable ordered collection with comparable items
197201
key: Optional function to extract comparison key from each element
198202
reverse: If True, sort in descending order
199-
203+
200204
Returns:
201205
The same collection ordered according to the parameters
202-
206+
203207
Time Complexity: O(n log n) worst case, O(n) for nearly sorted data
204208
Space Complexity: O(n)
205-
209+
206210
Examples:
207211
>>> power_sort([0, 5, 3, 2, 2])
208212
[0, 2, 2, 3, 5]
@@ -230,66 +234,70 @@ def power_sort(
230234
[(1, 'b'), (1, 'a'), (2, 'a')]
231235
>>> power_sort([1, 2, 3, 2, 1, 2, 3, 4])
232236
[1, 1, 2, 2, 2, 3, 3, 4]
233-
>>> power_sort(list(range(100)))
234-
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
235-
>>> power_sort(list(reversed(range(50))))
236-
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
237+
>>> result = power_sort(list(range(100)))
238+
>>> result == list(range(100))
239+
True
240+
>>> result = power_sort(list(reversed(range(50))))
241+
>>> result == list(range(50))
242+
True
237243
"""
238244
if len(collection) <= 1:
239245
return collection
240-
246+
241247
# Make a copy to avoid modifying the original if it's immutable
242248
arr = list(collection)
243249
n = len(arr)
244-
250+
245251
# Adjust key function for reverse sorting
252+
needs_final_reverse = False
246253
if reverse:
247254
if key:
248255
original_key = key
249-
key = lambda x: -original_key(x) if isinstance(original_key(x), (int, float)) else original_key(x)
250-
# For non-numeric types, we'll need a different approach
251-
# Store original key and use negation wrapper
256+
252257
def reverse_key(x):
253258
val = original_key(x)
254-
# For comparable types, we can't negate, so we'll reverse at the end
259+
if isinstance(val, int | float):
260+
return -val
255261
return val
262+
256263
key = reverse_key
257264
needs_final_reverse = True
258265
else:
259-
key = lambda x: -x if isinstance(x, (int, float)) else x
266+
267+
def reverse_cmp(x):
268+
if isinstance(x, int | float):
269+
return -x
270+
return x
271+
272+
key = reverse_cmp
260273
needs_final_reverse = True
261-
else:
262-
needs_final_reverse = False
263-
274+
264275
# Stack to hold runs: each entry is (start_index, length, power)
265-
# Capacity is ceil(log2(n)) + 1
266-
import math
267-
stack_capacity = math.ceil(math.log2(n)) + 1 if n > 1 else 2
268276
stack: list[tuple[int, int, int]] = []
269-
277+
270278
start = 0
271279
while start < n:
272280
# Find the next run
273281
run_end = _find_run(arr, start, n, key)
274282
run_length = run_end - start
275-
283+
276284
# Calculate power for this run
277285
if len(stack) == 0:
278286
power = 0
279287
else:
280288
prev_start, prev_length, _ = stack[-1]
281289
power = _node_power(n, prev_start, prev_length, start, run_length)
282-
290+
283291
# Merge runs from stack based on power comparison
284292
while len(stack) > 0 and stack[-1][2] >= power:
285293
# Merge the top run with the current run
286-
prev_start, prev_length, prev_power = stack.pop()
294+
prev_start, prev_length, _ = stack.pop()
287295
_merge(arr, prev_start, prev_start + prev_length, run_end, key)
288-
296+
289297
# Update current run to include the merged run
290298
start = prev_start
291299
run_length = run_end - start
292-
300+
293301
# Recalculate power
294302
if len(stack) == 0:
295303
power = 0
@@ -298,60 +306,66 @@ def reverse_key(x):
298306
power = _node_power(
299307
n, prev_prev_start, prev_prev_length, start, run_length
300308
)
301-
309+
302310
# Push current run onto stack
303311
stack.append((start, run_length, power))
304312
start = run_end
305-
313+
306314
# Merge all remaining runs on the stack
307315
while len(stack) > 1:
308316
start2, length2, _ = stack.pop()
309-
start1, length1, power1 = stack.pop()
317+
start1, length1, _ = stack.pop()
310318
_merge(arr, start1, start1 + length1, start2 + length2, key)
311-
319+
312320
# Recalculate power for merged run
313321
if len(stack) == 0:
314322
power = 0
315323
else:
316324
prev_start, prev_length, _ = stack[-1]
317-
power = _node_power(n, prev_start, prev_length, start1, start2 + length2 - start1)
318-
325+
merged_length = start2 + length2 - start1
326+
power = _node_power(n, prev_start, prev_length, start1, merged_length)
327+
319328
stack.append((start1, start2 + length2 - start1, power))
320-
329+
321330
# Handle reverse sorting for non-numeric types
322-
if reverse and needs_final_reverse:
331+
if (
332+
reverse
333+
and needs_final_reverse
334+
and key
335+
and len(arr) > 0
336+
and not isinstance(arr[0], int | float)
337+
):
323338
# For non-numeric types, we need to reverse the final result
324339
# Check if we used numeric negation or not
325-
if key and not isinstance(arr[0], (int, float)):
326-
arr.reverse()
327-
340+
arr.reverse()
341+
328342
return arr
329343

330344

331345
if __name__ == "__main__":
332346
import doctest
333-
347+
334348
doctest.testmod()
335-
349+
336350
print("\nPowerSort Interactive Testing")
337351
print("=" * 40)
338-
352+
339353
try:
340354
user_input = input("Enter numbers separated by a comma:\n").strip()
341355
if user_input == "":
342356
unsorted = []
343357
else:
344358
unsorted = [int(item.strip()) for item in user_input.split(",")]
345-
359+
346360
print(f"\nOriginal: {unsorted}")
347361
sorted_list = power_sort(unsorted)
348362
print(f"Sorted: {sorted_list}")
349-
363+
350364
# Test reverse
351365
sorted_reverse = power_sort(unsorted, reverse=True)
352366
print(f"Reverse: {sorted_reverse}")
353-
367+
354368
except ValueError:
355369
print("Invalid input. Please enter valid integers separated by commas.")
356370
except KeyboardInterrupt:
357-
print("\n\nGoodbye!")
371+
print("\n\nGoodbye!")

0 commit comments

Comments
 (0)