Skip to content

Commit 3967fc0

Browse files
Improve softmax: type checks, stability, axis support, tests, docs
1 parent a71618f commit 3967fc0

1 file changed

Lines changed: 89 additions & 45 deletions

File tree

maths/softmax.py

Lines changed: 89 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,100 @@
1-
"""
2-
This script demonstrates the implementation of the Softmax function.
3-
4-
It's a function that takes as input a vector of K real numbers, and normalizes
5-
it into a probability distribution consisting of K probabilities proportional
6-
to the exponentials of the input numbers. After softmax, the elements of the
7-
vector always sum up to 1.
8-
9-
Script inspired from its corresponding Wikipedia article
10-
https://en.wikipedia.org/wiki/Softmax_function
11-
"""
12-
131
import numpy as np
2+
from typing import Optional, Union
143

154

16-
def softmax(
    vector: Union[np.ndarray, list, tuple],
    axis: Optional[int] = -1
) -> np.ndarray:
    """
    Compute the softmax of `vector` along `axis` in a numerically-stable way.

    Parameters
    ----------
    vector : array_like (np.ndarray, list, or tuple)
        Input data (vector, matrix, tensor). Will be converted to float ndarray.
    axis : int or None, optional
        Axis along which to compute softmax. If None, compute softmax over
        the flattened array (single distribution). Default is -1 (last axis).

    Returns
    -------
    np.ndarray
        Same shape as `vector`, with softmax applied along `axis`. Probabilities
        sum to 1 along `axis` (or to 1 overall if axis is None).

    Raises
    ------
    ValueError
        If input is empty or cannot be converted to a float ndarray.

    >>> softmax(np.array([5, 5]))
    array([0.5, 0.5])
    >>> softmax([0])
    array([1.])
    """
    try:
        vector = np.asarray(vector, dtype=float)
    except Exception as exc:
        # Chain the original exception so the root cause of the conversion
        # failure stays visible in the traceback (PEP 3134).
        raise ValueError(f"Could not convert input to float ndarray: {exc}") from exc

    if vector.size == 0:
        raise ValueError("softmax input must be non-empty")

    if axis is None:
        # Single distribution over the flattened array.
        e_vector = np.exp(vector - np.max(vector))
        return e_vector / e_vector.sum()

    # Subtract the per-axis max (keepdims=True so it broadcasts back); the
    # largest exponent is then 0, so np.exp cannot overflow.
    vector_max = np.max(vector, axis=axis, keepdims=True)
    e_vector = np.exp(vector - vector_max)
    denom = e_vector.sum(axis=axis, keepdims=True)
    return e_vector / denom
51+
52+
# Example unit tests
def _test_softmax():
    """Run a small self-test suite for softmax; raises AssertionError on failure."""
    import numpy.testing as npt

    # Typical 1D input: probabilities sum to 1.
    result = softmax([1, 2, 3])
    npt.assert_almost_equal(result.sum(), 1)

    # Typical 2D input: each row sums to 1 along the default axis (-1).
    result = softmax([[1, 2, 3], [4, 5, 6]])
    npt.assert_almost_equal(result.sum(axis=-1).tolist(), [1, 1])

    # Single element is a degenerate distribution.
    result = softmax([0])
    npt.assert_almost_equal(result, [1.0])

    # Identical values share probability mass equally.
    result = softmax([5, 5])
    npt.assert_almost_equal(result, [0.5, 0.5])

    # Large values must not overflow (numerical stability).
    result = softmax([1000, 1001])
    npt.assert_almost_equal(result.sum(), 1)

    # axis=None flattens to a single distribution.
    data = np.array([[1, 2], [3, 4]])
    flat_result = softmax(data, axis=None)
    npt.assert_almost_equal(flat_result.sum(), 1)

    # Empty input must raise ValueError. Use try/except/else with an explicit
    # raise rather than `assert False`, which is stripped under `python -O`
    # and would silently disable this check.
    try:
        softmax([])
    except ValueError:
        pass
    else:
        raise AssertionError("Expected ValueError for empty input")

    print("All tests passed.")
4189

if __name__ == "__main__":
    # Demonstrate softmax on a few representative inputs, then self-test.
    print("Softmax demonstration:")
    demo_cases = [
        ("softmax((0,)) =", (0,)),
        ("softmax([1, 2, 3]) =", [1, 2, 3]),
        ("softmax([[1, 2, 3], [4, 5, 6]]) =", [[1, 2, 3], [4, 5, 6]]),
    ]
    for label, data in demo_cases:
        print(label, softmax(data))
    _test_softmax()

0 commit comments

Comments
 (0)