11import numpy as np
2- from typing import Optional , Union
2+ from typing import Optional
33
44
55def softmax (
6- vector : Union [ np .ndarray , list , tuple ] ,
7- axis : Optional [ int ] = - 1
6+ vector : np .ndarray | list | tuple ,
7+ axis : int | None = - 1
88) -> np .ndarray :
99 """
1010 Compute the softmax of `vector` along `axis` in a numerically-stable way.
@@ -20,8 +20,8 @@ def softmax(
2020 Returns
2121 -------
2222 np.ndarray
23- Same shape as `vector`, with softmax applied along `axis`. Probabilities sum to 1
24- along `axis` (or to 1 overall if axis is None).
23+ Same shape as `vector`, with softmax applied along `axis`. Probabilities
24+ sum to 1 along `axis` (or to 1 overall if axis is None).
2525
2626 Raises
2727 ------
@@ -30,8 +30,9 @@ def softmax(
3030 """
3131 try :
3232 vector = np .asarray (vector , dtype = float )
33- except Exception as e :
34- raise ValueError (f"Could not convert input to float ndarray: { e } " )
33+ except TypeError as e :
34+ msg = f"Could not convert input to float ndarray: { e } "
35+ raise ValueError (msg )
3536
3637 if vector .size == 0 :
3738 raise ValueError ("softmax input must be non-empty" )
@@ -49,42 +50,41 @@ def softmax(
4950 return e_vector / denom
5051
5152
52- # Example unit tests
5353def _test_softmax ():
5454 import numpy .testing as npt
55-
55+
5656 # Typical 1D input
5757 result = softmax ([1 , 2 , 3 ])
5858 npt .assert_almost_equal (result .sum (), 1 )
59-
59+
6060 # Typical 2D, axis=-1
6161 result = softmax ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
6262 npt .assert_almost_equal (result .sum (axis = - 1 ).tolist (), [1 , 1 ])
63-
63+
6464 # Scalar input
6565 result = softmax ([0 ])
6666 npt .assert_almost_equal (result , [1.0 ])
67-
67+
6868 # Identical values
6969 result = softmax ([5 , 5 ])
7070 npt .assert_almost_equal (result , [0.5 , 0.5 ])
71-
71+
7272 # Large values for numeric stability
7373 result = softmax ([1000 , 1001 ])
7474 npt .assert_almost_equal (result .sum (), 1 )
75-
75+
7676 # axis=None flatten
7777 data = np .array ([[1 , 2 ], [3 , 4 ]])
7878 flat_result = softmax (data , axis = None )
7979 npt .assert_almost_equal (flat_result .sum (), 1 )
80-
80+
8181 # Empty input error
8282 try :
8383 softmax ([])
84- assert False , "Expected ValueError for empty input"
84+ raise AssertionError ( "Expected ValueError for empty input" )
8585 except ValueError :
8686 pass
87-
87+
8888 print ("All tests passed." )
8989
9090
0 commit comments