@@ -67,7 +67,7 @@ def _find_run(
6767 if start >= end - 1 :
6868 return start + 1
6969
70- key_func = key if key else lambda x : x
70+ key_func = key if key else lambda element : element
7171 run_end = start + 1
7272
7373 # Check if run is ascending or descending
@@ -85,7 +85,7 @@ def _find_run(
8585 return run_end
8686
8787
88- def _node_power (n : int , b1 : int , n1 : int , b2 : int , n2 : int ) -> int :
88+ def _node_power (total_length : int , b1 : int , n1 : int , b2 : int , n2 : int ) -> int :
8989 """
9090 Calculate the node power for two adjacent runs.
9191
@@ -97,7 +97,7 @@ def _node_power(n: int, b1: int, n1: int, b2: int, n2: int) -> int:
9797
9898
9999 Args:
100- n : Total length of the array
100+ total_length : Total length of the array
101101 b1: Start index of first run
102102 n1: Length of first run
103103 b2: Start index of second run
@@ -113,16 +113,16 @@ def _node_power(n: int, b1: int, n1: int, b2: int, n2: int) -> int:
113113 >>> _node_power(100, 0, 50, 50, 50)
114114 1
115115 """
116- # Calculate midpoints: a = (b1 + n1/2) / n , b = (b2 + n2/2) / n
117- # To avoid floating point, we work with a = (2*b1 + n1) / (2*n ) and
118- # b = (2*b2 + n2) / (2*n )
116+ # Calculate midpoints: a = (b1 + n1/2) / total_length , b = (b2 + n2/2) / total_length
117+ # To avoid floating point, we work with a = (2*b1 + n1) / (2*total_length ) and
118+ # b = (2*b2 + n2) / (2*total_length )
119119 # We want smallest p where floor(a * 2^p) != floor(b * 2^p)
120- # This is floor((2*b1 + n1) * 2^p / (2*n )) !=
121- # floor((2*b2 + n2) * 2^p / (2*n ))
120+ # This is floor((2*b1 + n1) * 2^p / (2*total_length )) !=
121+ # floor((2*b2 + n2) * 2^p / (2*total_length ))
122122
123123 a = 2 * b1 + n1
124124 b = 2 * b2 + n2
125- two_n = 2 * n
125+ two_n = 2 * total_length
126126
127127 # Find smallest power p where floor(a * 2^p / two_n) !=
128128 # floor(b * 2^p / two_n)
@@ -164,7 +164,7 @@ def _merge(
164164 >>> arr
165165 [1, 2, 3, 5, 6, 7]
166166 """
167- key_func = key if key else lambda x : x
167+ key_func = key if key else lambda element : element
168168
169169 # Copy the runs to temporary storage
170170 left = arr [start1 :end1 ]
@@ -262,16 +262,30 @@ def power_sort(
262262
263263 # Make a copy to avoid modifying the original if it's immutable
264264 arr = list (collection )
265- n = len (arr )
265+ total_length = len (arr )
266266
267267 # Adjust key function for reverse sorting
268268 needs_final_reverse = False
269269 if reverse :
270270 if key :
271271 original_key = key
272272
273- def reverse_key (x ):
274- val = original_key (x )
273+ def reverse_key (element : Any ) -> Any :
274+ """
275+ Reverse key function for numeric values.
276+
277+ Args:
278+ element: The element to process
279+
280+ Returns:
281+ Negated value for numeric types, original value otherwise
282+
283+ >>> reverse_key(5)
284+ -5
285+ >>> reverse_key('hello')
286+ 'hello'
287+ """
288+ val = original_key (element )
275289 if isinstance (val , int | float ):
276290 return - val
277291 return val
@@ -280,10 +294,24 @@ def reverse_key(x):
280294 needs_final_reverse = True
281295 else :
282296
283- def reverse_cmp (x ):
284- if isinstance (x , int | float ):
285- return - x
286- return x
297+ def reverse_cmp (element : Any ) -> Any :
298+ """
299+ Reverse comparison function for numeric values.
300+
301+ Args:
302+ element: The element to process
303+
304+ Returns:
305+ Negated value for numeric types, original value otherwise
306+
307+ >>> reverse_cmp(10)
308+ -10
309+ >>> reverse_cmp('test')
310+ 'test'
311+ """
312+ if isinstance (element , int | float ):
313+ return - element
314+ return element
287315
288316 key = reverse_cmp
289317 needs_final_reverse = True
@@ -292,17 +320,17 @@ def reverse_cmp(x):
292320 stack : list [tuple [int , int , int ]] = []
293321
294322 start = 0
295- while start < n :
323+ while start < total_length :
296324 # Find the next run
297- run_end = _find_run (arr , start , n , key )
325+ run_end = _find_run (arr , start , total_length , key )
298326 run_length = run_end - start
299327
300328 # Calculate power for this run
301329 if len (stack ) == 0 :
302330 power = 0
303331 else :
304332 prev_start , prev_length , _ = stack [- 1 ]
305- power = _node_power (n , prev_start , prev_length , start , run_length )
333+ power = _node_power (total_length , prev_start , prev_length , start , run_length )
306334
307335 # Merge runs from stack based on power comparison
308336 while len (stack ) > 0 and stack [- 1 ][2 ] >= power :
@@ -320,7 +348,7 @@ def reverse_cmp(x):
320348 else :
321349 prev_prev_start , prev_prev_length , _ = stack [- 1 ]
322350 power = _node_power (
323- n , prev_prev_start , prev_prev_length , start , run_length
351+ total_length , prev_prev_start , prev_prev_length , start , run_length
324352 )
325353
326354 # Push current run onto stack
@@ -339,7 +367,7 @@ def reverse_cmp(x):
339367 else :
340368 prev_start , prev_length , _ = stack [- 1 ]
341369 merged_length = start2 + length2 - start1
342- power = _node_power (n , prev_start , prev_length , start1 , merged_length )
370+ power = _node_power (total_length , prev_start , prev_length , start1 , merged_length )
343371
344372 stack .append ((start1 , start2 + length2 - start1 , power ))
345373
0 commit comments