Skip to content

Commit 6c2f7b4

Browse files
Update run.py
Used matplotlib to plot actual vs. predicted user counts, forecast confidence intervals, and outlier thresholds from the IQR. Replaced print() with logging, because print() does not scale in production.
1 parent 062e046 commit 6c2f7b4

1 file changed

Lines changed: 23 additions & 4 deletions

File tree

  • machine_learning/forecasting

machine_learning/forecasting/run.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
you can just adjust it for your own purpose
1212
"""
1313

14+
import logging
1415
from warnings import simplefilter
1516

1617
import numpy as np
@@ -19,7 +20,10 @@
1920
from sklearn.preprocessing import Normalizer
2021
from sklearn.svm import SVR
2122
from statsmodels.tsa.statespace.sarimax import SARIMAX
23+
import matplotlib.pyplot as plt
2224

25+
# Set the root logger threshold to INFO; the level constant is upper-case
# (logging.INFO) -- the logging module has no attribute named "Info".
logging.basicConfig(level=logging.INFO)
26+
logger = logging.getLogger(__name__)
2327

2428
def linear_regression_prediction(
2529
train_dt: list, train_usr: list, train_mtch: list, test_dt: list, test_mtch: list
@@ -143,6 +147,21 @@ def data_safety_checker(list_vote: list, actual_result: float) -> bool:
143147
not_safe += 1
144148
return safe > not_safe
145149

150+
def plot_forecast(actual: list, predictions: list) -> None:
    """Plot the observed series together with each model's forecast.

    ``actual`` is drawn as a line over x = 0 .. len(actual) - 1.  Each entry
    of ``predictions`` is drawn as a single marker at x = len(actual), i.e.
    the position immediately after the last observed value.

    Args:
        actual: Observed values to draw as the history line.
        predictions: One forecast value per model, in the order
            Linear Regression, SARIMAX, SVR, Random Forest.  Entries beyond
            those four are ignored; if fewer are supplied, only the ones
            present are plotted.

    Returns:
        None.  The figure is displayed with ``plt.show()``.
    """
    # (matplotlib format string, legend label) for each expected model, in
    # the same order as the entries of ``predictions``.
    model_styles = (
        ("ro", "Linear Reg"),
        ("go", "SARIMAX"),
        ("bo", "SVR"),
        ("yo", "RF"),
    )
    # All forecasts share the x position right after the observed history.
    forecast_x = len(actual)

    plt.figure(figsize=(10, 5))
    plt.plot(range(forecast_x), actual, label="Actual")
    # zip() stops at the shorter sequence, so a short ``predictions`` list is
    # plotted as far as it goes rather than failing on a missing index.
    for (marker_fmt, model_label), forecast in zip(model_styles, predictions):
        plt.plot(forecast_x, forecast, marker_fmt, label=model_label)

    plt.legend()
    plt.title("Data Safety Forecast")
    plt.xlabel("Days")
    plt.ylabel("Normalized User Count")
    plt.grid(True)
    plt.tight_layout()
    plt.show()
164+
146165

147166
if __name__ == "__main__":
148167
"""
@@ -179,11 +198,11 @@ def data_safety_checker(list_vote: list, actual_result: float) -> bool:
179198
),
180199
sarimax_predictor(train_user, train_match, test_match),
181200
support_vector_regressor(x_train, x_test, train_user),
182-
random_forest_regressor(
183-
x_train, x_test, train_user
184-
), # Added Random Forest Regressor
201+
random_forest_regressor(x_train, x_test, train_user),
185202
]
186203

187204
# check the safety of today's data
188205
not_str = "" if data_safety_checker(res_vote, test_user[0]) else "not "
189-
print(f"Today's data is {not_str}safe.")
206+
logger.info(f"Today's data is {not_str}safe.")
207+
208+
plot_forecast(train_user, res_vote)

0 commit comments

Comments
 (0)