Skip to content

Commit 4544a14

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3967fc0 commit 4544a14

1 file changed

Lines changed: 9 additions & 10 deletions

File tree

maths/softmax.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33

44

55
def softmax(
6-
vector: Union[np.ndarray, list, tuple],
7-
axis: Optional[int] = -1
6+
vector: Union[np.ndarray, list, tuple], axis: Optional[int] = -1
87
) -> np.ndarray:
98
"""
109
Compute the softmax of `vector` along `axis` in a numerically-stable way.
@@ -52,39 +51,39 @@ def softmax(
5251
# Example unit tests
5352
def _test_softmax():
5453
import numpy.testing as npt
55-
54+
5655
# Typical 1D input
5756
result = softmax([1, 2, 3])
5857
npt.assert_almost_equal(result.sum(), 1)
59-
58+
6059
# Typical 2D, axis=-1
6160
result = softmax([[1, 2, 3], [4, 5, 6]])
6261
npt.assert_almost_equal(result.sum(axis=-1).tolist(), [1, 1])
63-
62+
6463
# Scalar input
6564
result = softmax([0])
6665
npt.assert_almost_equal(result, [1.0])
67-
66+
6867
# Identical values
6968
result = softmax([5, 5])
7069
npt.assert_almost_equal(result, [0.5, 0.5])
71-
70+
7271
# Large values for numeric stability
7372
result = softmax([1000, 1001])
7473
npt.assert_almost_equal(result.sum(), 1)
75-
74+
7675
# axis=None flatten
7776
data = np.array([[1, 2], [3, 4]])
7877
flat_result = softmax(data, axis=None)
7978
npt.assert_almost_equal(flat_result.sum(), 1)
80-
79+
8180
# Empty input error
8281
try:
8382
softmax([])
8483
assert False, "Expected ValueError for empty input"
8584
except ValueError:
8685
pass
87-
86+
8887
print("All tests passed.")
8988

9089

0 commit comments

Comments
 (0)