|
3 | 3 |
|
4 | 4 |
|
5 | 5 | def softmax( |
6 | | - vector: Union[np.ndarray, list, tuple], |
7 | | - axis: Optional[int] = -1 |
| 6 | + vector: Union[np.ndarray, list, tuple], axis: Optional[int] = -1 |
8 | 7 | ) -> np.ndarray: |
9 | 8 | """ |
10 | 9 | Compute the softmax of `vector` along `axis` in a numerically-stable way. |
@@ -52,39 +51,39 @@ def softmax( |
52 | 51 | # Example unit tests |
53 | 52 | def _test_softmax(): |
54 | 53 | import numpy.testing as npt |
55 | | - |
| 54 | + |
56 | 55 | # Typical 1D input |
57 | 56 | result = softmax([1, 2, 3]) |
58 | 57 | npt.assert_almost_equal(result.sum(), 1) |
59 | | - |
| 58 | + |
60 | 59 | # Typical 2D, axis=-1 |
61 | 60 | result = softmax([[1, 2, 3], [4, 5, 6]]) |
62 | 61 | npt.assert_almost_equal(result.sum(axis=-1).tolist(), [1, 1]) |
63 | | - |
| 62 | + |
64 | 63 | # Scalar input |
65 | 64 | result = softmax([0]) |
66 | 65 | npt.assert_almost_equal(result, [1.0]) |
67 | | - |
| 66 | + |
68 | 67 | # Identical values |
69 | 68 | result = softmax([5, 5]) |
70 | 69 | npt.assert_almost_equal(result, [0.5, 0.5]) |
71 | | - |
| 70 | + |
72 | 71 | # Large values for numeric stability |
73 | 72 | result = softmax([1000, 1001]) |
74 | 73 | npt.assert_almost_equal(result.sum(), 1) |
75 | | - |
| 74 | + |
76 | 75 | # axis=None flatten |
77 | 76 | data = np.array([[1, 2], [3, 4]]) |
78 | 77 | flat_result = softmax(data, axis=None) |
79 | 78 | npt.assert_almost_equal(flat_result.sum(), 1) |
80 | | - |
| 79 | + |
81 | 80 | # Empty input error |
82 | 81 | try: |
83 | 82 | softmax([]) |
84 | 83 | assert False, "Expected ValueError for empty input" |
85 | 84 | except ValueError: |
86 | 85 | pass |
87 | | - |
| 86 | + |
88 | 87 | print("All tests passed.") |
89 | 88 |
|
90 | 89 |
|
|
0 commit comments