Skip to content

Commit 7690063

Browse files
committed
Refactor power_sort.py to enhance clarity and consistency. Updated parameter names for better understanding, improved comments for key functions, and ensured consistent handling of total length in calculations. This improves readability and maintainability of the code.
1 parent b8a0c4e commit 7690063

1 file changed

Lines changed: 50 additions & 22 deletions

File tree

sorts/power_sort.py

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def _find_run(
6767
if start >= end - 1:
6868
return start + 1
6969

70-
key_func = key if key else lambda x: x
70+
key_func = key if key else lambda element: element
7171
run_end = start + 1
7272

7373
# Check if run is ascending or descending
@@ -85,7 +85,7 @@ def _find_run(
8585
return run_end
8686

8787

88-
def _node_power(n: int, b1: int, n1: int, b2: int, n2: int) -> int:
88+
def _node_power(total_length: int, b1: int, n1: int, b2: int, n2: int) -> int:
8989
"""
9090
Calculate the node power for two adjacent runs.
9191
@@ -97,7 +97,7 @@ def _node_power(n: int, b1: int, n1: int, b2: int, n2: int) -> int:
9797
9898
9999
Args:
100-
n: Total length of the array
100+
total_length: Total length of the array
101101
b1: Start index of first run
102102
n1: Length of first run
103103
b2: Start index of second run
@@ -113,16 +113,16 @@ def _node_power(n: int, b1: int, n1: int, b2: int, n2: int) -> int:
113113
>>> _node_power(100, 0, 50, 50, 50)
114114
1
115115
"""
116-
# Calculate midpoints: a = (b1 + n1/2) / n, b = (b2 + n2/2) / n
117-
# To avoid floating point, we work with a = (2*b1 + n1) / (2*n) and
118-
# b = (2*b2 + n2) / (2*n)
116+
# Calculate midpoints: a = (b1 + n1/2) / total_length, b = (b2 + n2/2) / total_length
117+
# To avoid floating point, we work with a = (2*b1 + n1) / (2*total_length) and
118+
# b = (2*b2 + n2) / (2*total_length)
119119
# We want smallest p where floor(a * 2^p) != floor(b * 2^p)
120-
# This is floor((2*b1 + n1) * 2^p / (2*n)) !=
121-
# floor((2*b2 + n2) * 2^p / (2*n))
120+
# This is floor((2*b1 + n1) * 2^p / (2*total_length)) !=
121+
# floor((2*b2 + n2) * 2^p / (2*total_length))
122122

123123
a = 2 * b1 + n1
124124
b = 2 * b2 + n2
125-
two_n = 2 * n
125+
two_n = 2 * total_length
126126

127127
# Find smallest power p where floor(a * 2^p / two_n) !=
128128
# floor(b * 2^p / two_n)
@@ -164,7 +164,7 @@ def _merge(
164164
>>> arr
165165
[1, 2, 3, 5, 6, 7]
166166
"""
167-
key_func = key if key else lambda x: x
167+
key_func = key if key else lambda element: element
168168

169169
# Copy the runs to temporary storage
170170
left = arr[start1:end1]
@@ -262,16 +262,30 @@ def power_sort(
262262

263263
# Make a copy to avoid modifying the original if it's immutable
264264
arr = list(collection)
265-
n = len(arr)
265+
total_length = len(arr)
266266

267267
# Adjust key function for reverse sorting
268268
needs_final_reverse = False
269269
if reverse:
270270
if key:
271271
original_key = key
272272

273-
def reverse_key(x):
274-
val = original_key(x)
273+
def reverse_key(element: Any) -> Any:
274+
"""
275+
Reverse key function for numeric values.
276+
277+
Args:
278+
element: The element to process
279+
280+
Returns:
281+
Negated value for numeric types, original value otherwise
282+
283+
>>> reverse_key(5)
284+
-5
285+
>>> reverse_key('hello')
286+
'hello'
287+
"""
288+
val = original_key(element)
275289
if isinstance(val, int | float):
276290
return -val
277291
return val
@@ -280,10 +294,24 @@ def reverse_key(x):
280294
needs_final_reverse = True
281295
else:
282296

283-
def reverse_cmp(x):
284-
if isinstance(x, int | float):
285-
return -x
286-
return x
297+
def reverse_cmp(element: Any) -> Any:
298+
"""
299+
Reverse comparison function for numeric values.
300+
301+
Args:
302+
element: The element to process
303+
304+
Returns:
305+
Negated value for numeric types, original value otherwise
306+
307+
>>> reverse_cmp(10)
308+
-10
309+
>>> reverse_cmp('test')
310+
'test'
311+
"""
312+
if isinstance(element, int | float):
313+
return -element
314+
return element
287315

288316
key = reverse_cmp
289317
needs_final_reverse = True
@@ -292,17 +320,17 @@ def reverse_cmp(x):
292320
stack: list[tuple[int, int, int]] = []
293321

294322
start = 0
295-
while start < n:
323+
while start < total_length:
296324
# Find the next run
297-
run_end = _find_run(arr, start, n, key)
325+
run_end = _find_run(arr, start, total_length, key)
298326
run_length = run_end - start
299327

300328
# Calculate power for this run
301329
if len(stack) == 0:
302330
power = 0
303331
else:
304332
prev_start, prev_length, _ = stack[-1]
305-
power = _node_power(n, prev_start, prev_length, start, run_length)
333+
power = _node_power(total_length, prev_start, prev_length, start, run_length)
306334

307335
# Merge runs from stack based on power comparison
308336
while len(stack) > 0 and stack[-1][2] >= power:
@@ -320,7 +348,7 @@ def reverse_cmp(x):
320348
else:
321349
prev_prev_start, prev_prev_length, _ = stack[-1]
322350
power = _node_power(
323-
n, prev_prev_start, prev_prev_length, start, run_length
351+
total_length, prev_prev_start, prev_prev_length, start, run_length
324352
)
325353

326354
# Push current run onto stack
@@ -339,7 +367,7 @@ def reverse_cmp(x):
339367
else:
340368
prev_start, prev_length, _ = stack[-1]
341369
merged_length = start2 + length2 - start1
342-
power = _node_power(n, prev_start, prev_length, start1, merged_length)
370+
power = _node_power(total_length, prev_start, prev_length, start1, merged_length)
343371

344372
stack.append((start1, start2 + length2 - start1, power))
345373

0 commit comments

Comments
 (0)