Skip to content

Commit 15d9f9b

Browse files
committed
updated code quality
1 parent a96fdd1 commit 15d9f9b

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

machine_learning/mini_batch_gradient_descent.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
"""
66

77
import numpy as np
8-
from typing import Tuple
98

109

1110
def mini_batch_gradient_descent(
@@ -14,7 +13,7 @@ def mini_batch_gradient_descent(
1413
learning_rate: float = 0.01,
1514
batch_size: int = 16,
1615
n_epochs: int = 50,
17-
) -> Tuple[np.ndarray, float]:
16+
) -> tuple[np.ndarray, float]:
1817
"""
1918
Mini-Batch Gradient Descent for linear regression.
2019
@@ -41,20 +40,25 @@ def mini_batch_gradient_descent(
4140
Example
4241
-------
4342
>>> import numpy as np
44-
>>> X = np.array([[1],[2],[3],[4]])
45-
>>> y = np.array([2,4,6,8])
46-
>>> w, b = mini_batch_gradient_descent(X, y, learning_rate=0.1, batch_size=2, n_epochs=100)
43+
>>> X = np.array([[1], [2], [3], [4]])
44+
>>> y = np.array([2, 4, 6, 8])
45+
>>> w, b = mini_batch_gradient_descent(
46+
... X, y, learning_rate=0.1, batch_size=2, n_epochs=100
47+
... )
4748
>>> round(w[0], 1) # slope close to 2
4849
2.0
4950
"""
5051
n_samples, n_features = feature_matrix.shape
5152
weights = np.zeros(n_features)
5253
bias = 0
5354

55+
rng = np.random.default_rng()
56+
5457
for _ in range(n_epochs):
55-
indices = np.random.permutation(n_samples)
58+
indices = rng.permutation(n_samples)
5659
shuffled_features = feature_matrix[indices]
5760
shuffled_targets = target_values[indices]
61+
5862
for start_idx in range(0, n_samples, batch_size):
5963
end_idx = start_idx + batch_size
6064
batch_features = shuffled_features[start_idx:end_idx]
@@ -63,10 +67,11 @@ def mini_batch_gradient_descent(
6367
errors = predictions - batch_targets
6468
weights -= learning_rate * (batch_features.T @ errors) / len(batch_targets)
6569
bias -= learning_rate * np.mean(errors)
70+
6671
return weights, bias
6772

6873

6974
if __name__ == "__main__":
7075
import doctest
7176

72-
doctest.testmod()
77+
doctest.testmod()

0 commit comments

Comments
 (0)