Skip to content

Commit 8d33d90

Browse files
Add doctests for dataset collection and gradient descent functions
1 parent ebf3ab2 commit 8d33d90

2 files changed

Lines changed: 52 additions & 11 deletions

File tree

machine_learning/linear_regression_naive.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,12 @@ def collect_dataset() -> np.ndarray:
4343
"""Collect dataset of CSGO (ADR vs Rating)
4444
4545
:return: dataset as numpy matrix
46+
47+
>>> ds = collect_dataset()
48+
>>> isinstance(ds, np.matrix)
49+
True
50+
>>> ds.shape[1] >= 2
51+
True
4652
"""
4753
response = httpx.get(
4854
"https://raw.githubusercontent.com/yashLadha/The_Math_of_Intelligence/"
@@ -111,6 +117,19 @@ def run_linear_regression(data_x: np.ndarray, data_y: np.ndarray) -> np.ndarray:
111117
:param data_x: dataset features
112118
:param data_y: dataset labels
113119
:return: learned feature vector theta
120+
121+
>>> import numpy as np
122+
>>> x = np.array([[1, 1], [1, 2], [1, 3]])
123+
>>> y = np.array([1, 2, 3])
124+
>>> theta = run_linear_regression(x, y)
125+
Iteration 1: Error = ...
126+
... # lots of output omitted
127+
>>> theta.shape
128+
(1, 2)
129+
>>> abs(theta[0, 0] - 0) < 0.1 # intercept close to 0
130+
True
131+
>>> abs(theta[0, 1] - 1) < 0.1 # slope close to 1
132+
True
114133
"""
115134
iterations = 100000
116135
alpha = 0.000155

machine_learning/linear_regression_vectorized.py

Lines changed: 33 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,12 @@ def collect_dataset() -> np.ndarray:
4343
"""Collect dataset of CSGO (ADR vs Rating).
4444
4545
:return: dataset as numpy array
46+
47+
>>> ds = collect_dataset()
48+
>>> isinstance(ds, np.ndarray)
49+
True
50+
>>> ds.shape[1] >= 2
51+
True
4652
"""
4753
response = httpx.get(
4854
"https://raw.githubusercontent.com/yashLadha/The_Math_of_Intelligence/"
@@ -56,23 +62,36 @@ def collect_dataset() -> np.ndarray:
5662

5763

5864
def gradient_descent(
59-
x: np.ndarray, y: np.ndarray, alpha: float = 0.000155, iterations: int = 100000
65+
features: np.ndarray, labels: np.ndarray, alpha: float = 0.000155, iterations: int = 100000
6066
) -> np.ndarray:
6167
"""Run gradient descent in a fully vectorized form.
6268
63-
:param x: dataset features
64-
:param y: dataset labels
69+
:param features: dataset features
70+
:param labels: dataset labels
6571
:param alpha: learning rate
6672
:param iterations: number of iterations
6773
:return: learned feature vector theta
74+
75+
>>> import numpy as np
76+
>>> features = np.array([[1, 1], [1, 2], [1, 3]])
77+
>>> labels = np.array([[1], [2], [3]])
78+
>>> theta = gradient_descent(features, labels, alpha=0.01, iterations=1000)
79+
Iteration 1: Error = ...
80+
... # output omitted
81+
>>> theta.shape
82+
(2, 1)
83+
>>> abs(theta[0, 0] - 0) < 0.1 # intercept close to 0
84+
True
85+
>>> abs(theta[1, 0] - 1) < 0.1 # slope close to 1
86+
True
6887
"""
69-
m, n = x.shape
88+
m, n = features.shape
7089
theta = np.zeros((n, 1))
7190

7291
for i in range(iterations):
73-
predictions = x @ theta
74-
errors = predictions - y
75-
gradients = (x.T @ errors) / m
92+
predictions = features @ theta
93+
errors = predictions - labels
94+
gradients = (features.T @ errors) / m
7695
theta -= alpha * gradients
7796

7897
if i % (iterations // 10) == 0: # log occasionally
@@ -94,14 +113,17 @@ def mean_absolute_error(predicted_y: np.ndarray, original_y: np.ndarray) -> float:
94113

95114

96115
def main() -> None:
97-
"""Driver function."""
116+
"""Driver function.
117+
118+
>>> main() # doctest: +SKIP
119+
"""
98120
dataset = collect_dataset()
99121

100122
m = dataset.shape[0]
101-
x = np.c_[np.ones(m), dataset[:, :-1]] # add intercept term
102-
y = dataset[:, -1].reshape(-1, 1)
123+
features = np.c_[np.ones(m), dataset[:, :-1]] # add intercept term
124+
labels = dataset[:, -1].reshape(-1, 1)
103125

104-
theta = gradient_descent(x, y)
126+
theta = gradient_descent(features, labels)
105127
print("Resultant Feature vector:")
106128
for value in theta.ravel():
107129
print(f"{value:.5f}")

0 commit comments

Comments
 (0)