Skip to content

Commit f4462f0

Browse files
Fix softmax: improve type hints, exception handling, lint compliance
1 parent 3967fc0 commit f4462f0

1 file changed

Lines changed: 17 additions & 17 deletions

File tree

maths/softmax.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import numpy as np
2-
from typing import Optional, Union
2+
from typing import Optional
33

44

55
def softmax(
6-
vector: Union[np.ndarray, list, tuple],
7-
axis: Optional[int] = -1
6+
vector: np.ndarray | list | tuple,
7+
axis: int | None = -1
88
) -> np.ndarray:
99
"""
1010
Compute the softmax of `vector` along `axis` in a numerically-stable way.
@@ -20,8 +20,8 @@ def softmax(
2020
Returns
2121
-------
2222
np.ndarray
23-
Same shape as `vector`, with softmax applied along `axis`. Probabilities sum to 1
24-
along `axis` (or to 1 overall if axis is None).
23+
Same shape as `vector`, with softmax applied along `axis`. Probabilities
24+
sum to 1 along `axis` (or to 1 overall if axis is None).
2525
2626
Raises
2727
------
@@ -30,8 +30,9 @@ def softmax(
3030
"""
3131
try:
3232
vector = np.asarray(vector, dtype=float)
33-
except Exception as e:
34-
raise ValueError(f"Could not convert input to float ndarray: {e}")
33+
except TypeError as e:
34+
msg = f"Could not convert input to float ndarray: {e}"
35+
raise ValueError(msg)
3536

3637
if vector.size == 0:
3738
raise ValueError("softmax input must be non-empty")
@@ -49,42 +50,41 @@ def softmax(
4950
return e_vector / denom
5051

5152

52-
# Example unit tests
5353
def _test_softmax():
5454
import numpy.testing as npt
55-
55+
5656
# Typical 1D input
5757
result = softmax([1, 2, 3])
5858
npt.assert_almost_equal(result.sum(), 1)
59-
59+
6060
# Typical 2D, axis=-1
6161
result = softmax([[1, 2, 3], [4, 5, 6]])
6262
npt.assert_almost_equal(result.sum(axis=-1).tolist(), [1, 1])
63-
63+
6464
# Scalar input
6565
result = softmax([0])
6666
npt.assert_almost_equal(result, [1.0])
67-
67+
6868
# Identical values
6969
result = softmax([5, 5])
7070
npt.assert_almost_equal(result, [0.5, 0.5])
71-
71+
7272
# Large values for numeric stability
7373
result = softmax([1000, 1001])
7474
npt.assert_almost_equal(result.sum(), 1)
75-
75+
7676
# axis=None flatten
7777
data = np.array([[1, 2], [3, 4]])
7878
flat_result = softmax(data, axis=None)
7979
npt.assert_almost_equal(flat_result.sum(), 1)
80-
80+
8181
# Empty input error
8282
try:
8383
softmax([])
84-
assert False, "Expected ValueError for empty input"
84+
raise AssertionError("Expected ValueError for empty input")
8585
except ValueError:
8686
pass
87-
87+
8888
print("All tests passed.")
8989

9090

0 commit comments

Comments
 (0)