
Commit 94c7f32

added mini batch gradient descent algo in ml dir
1 parent a71618f commit 94c7f32

1 file changed

Lines changed: 65 additions & 0 deletions

"""
Mini-Batch Gradient Descent : https://en.wikipedia.org/wiki/Stochastic_gradient_descent

Mini-batch gradient descent is an optimization method that trains a model by
computing each gradient update on a small random subset (a mini-batch) of the
training data, combining the stability of full-batch gradient descent with the
speed and memory footprint of stochastic updates.
"""

import numpy as np


def mini_batch_gradient_descent(
    X: np.ndarray,
    y: np.ndarray,
    lr: float = 0.01,
    batch_size: int = 16,
    n_epochs: int = 50,
):
    """
    Mini-batch gradient descent for linear regression.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix of shape (n_samples, n_features).
    y : np.ndarray
        Target values of shape (n_samples,).
    lr : float
        Learning rate.
    batch_size : int
        Size of each mini-batch.
    n_epochs : int
        Number of full passes over the training data.

    Returns
    -------
    weights : np.ndarray
        Learned weights.
    bias : float
        Learned bias.

    Example
    -------
    >>> import numpy as np
    >>> X = np.array([[1], [2], [3], [4]])
    >>> y = np.array([2, 4, 6, 8])
    >>> w, b = mini_batch_gradient_descent(X, y, lr=0.1, batch_size=2, n_epochs=100)
    >>> float(round(w[0], 1))  # slope close to 2
    2.0
    """
    n_samples, n_features = X.shape
    weights = np.zeros(n_features)
    bias = 0.0

    for _ in range(n_epochs):
        # Reshuffle every epoch so each pass draws batches in a new order.
        indices = np.random.permutation(n_samples)
        X_shuffled, y_shuffled = X[indices], y[indices]
        for start in range(0, n_samples, batch_size):
            end = start + batch_size
            X_batch, y_batch = X_shuffled[start:end], y_shuffled[start:end]
            # Mean-squared-error gradients on this batch:
            #   dL/dw = X_batch^T @ error / |batch|,  dL/db = mean(error)
            y_pred = X_batch @ weights + bias
            error = y_pred - y_batch
            weights -= lr * (X_batch.T @ error) / len(y_batch)
            bias -= lr * np.mean(error)
    return weights, bias


if __name__ == "__main__":
    import doctest

    doctest.testmod()
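
Beyond the doctest, the function can be sanity-checked on noisier, multi-feature data. The snippet below is a minimal sketch assuming the function above is in scope; the synthetic data, seed, and hyperparameters are illustrative choices, not part of the commit.

import numpy as np

np.random.seed(0)  # illustrative seed, chosen only for repeatability
X = np.random.randn(200, 2)  # 200 samples, 2 features
true_w, true_b = np.array([3.0, -1.5]), 0.5
y = X @ true_w + true_b + 0.01 * np.random.randn(200)  # small Gaussian noise

w, b = mini_batch_gradient_descent(X, y, lr=0.05, batch_size=32, n_epochs=200)
print(w, b)  # expected to land near [3.0, -1.5] and 0.5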
