Skip to content

Commit d868aba

Browse files
Refactor function signatures for improved readability in linear regression implementation
1 parent 11fa072 commit d868aba

1 file changed

Lines changed: 17 additions & 4 deletions

File tree

machine_learning/linear_regression_naive.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ def collect_dataset() -> np.ndarray:
4141

4242

4343
def run_steep_gradient_descent(
44-
data_x: np.ndarray, data_y: np.ndarray, len_data: int, alpha: float, theta: np.ndarray
44+
data_x: np.ndarray,
45+
data_y: np.ndarray,
46+
len_data: int,
47+
alpha: float,
48+
theta: np.ndarray
4549
) -> np.ndarray:
4650
"""Run one step of steep gradient descent.
4751
@@ -70,7 +74,10 @@ def run_steep_gradient_descent(
7074

7175

7276
def sum_of_square_error(
73-
data_x: np.ndarray, data_y: np.ndarray, len_data: int, theta: np.ndarray
77+
data_x: np.ndarray,
78+
data_y: np.ndarray,
79+
len_data: int,
80+
theta: np.ndarray
7481
) -> float:
7582
"""Return sum of square error for error calculation.
7683
@@ -85,7 +92,10 @@ def sum_of_square_error(
8592
return float(error)
8693

8794

88-
def run_linear_regression(data_x: np.ndarray, data_y: np.ndarray) -> np.ndarray:
95+
def run_linear_regression(
96+
data_x: np.ndarray,
97+
data_y: np.ndarray
98+
) -> np.ndarray:
8999
"""Run linear regression using gradient descent.
90100
91101
:param data_x: dataset features
@@ -108,7 +118,10 @@ def run_linear_regression(data_x: np.ndarray, data_y: np.ndarray) -> np.ndarray:
108118
return theta
109119

110120

111-
def mean_absolute_error(predicted_y: np.ndarray, original_y: np.ndarray) -> float:
121+
def mean_absolute_error(
122+
predicted_y: np.ndarray,
123+
original_y: np.ndarray
124+
) -> float:
112125
"""Return mean absolute error.
113126
114127
>>> predicted_y = np.array([3, -0.5, 2, 7])

0 commit comments

Comments
 (0)