Skip to content

Commit c8df2cc

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6c2f7b4 commit c8df2cc

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

  • machine_learning/forecasting

machine_learning/forecasting/run.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
logging.basicConfig(level=logging.Info)
2626
logger = logging.getLogger(__name__)
2727

28+
2829
def linear_regression_prediction(
2930
train_dt: list, train_usr: list, train_mtch: list, test_dt: list, test_mtch: list
3031
) -> float:
@@ -147,13 +148,14 @@ def data_safety_checker(list_vote: list, actual_result: float) -> bool:
147148
not_safe += 1
148149
return safe > not_safe
149150

151+
150152
def plot_forecast(actual, predictions):
151153
plt.figure(figsize=(10, 5))
152154
plt.plot(range(len(actual)), actual, label="Actual")
153-
plt.plot(len(actual), predictions[0], 'ro', label="Linear Reg")
154-
plt.plot(len(actual), predictions[1], 'go', label="SARIMAX")
155-
plt.plot(len(actual), predictions[2], 'bo', label="SVR")
156-
plt.plot(len(actual), predictions[3], 'yo', label="RF")
155+
plt.plot(len(actual), predictions[0], "ro", label="Linear Reg")
156+
plt.plot(len(actual), predictions[1], "go", label="SARIMAX")
157+
plt.plot(len(actual), predictions[2], "bo", label="SVR")
158+
plt.plot(len(actual), predictions[3], "yo", label="RF")
157159
plt.legend()
158160
plt.title("Data Safety Forecast")
159161
plt.xlabel("Days")
@@ -204,5 +206,5 @@ def plot_forecast(actual, predictions):
204206
# check the safety of today's data
205207
not_str = "" if data_safety_checker(res_vote, test_user[0]) else "not "
206208
logger.info(f"Today's data is {not_str}safe.")
207-
209+
208210
plot_forecast(train_user, res_vote)

0 commit comments

Comments
 (0)