|
11 | 11 | u can just adjust it for ur own purpose |
12 | 12 | """ |
13 | 13 |
|
| 14 | +import logging |
14 | 15 | from warnings import simplefilter |
15 | 16 |
|
16 | 17 | import numpy as np |
|
19 | 20 | from sklearn.preprocessing import Normalizer |
20 | 21 | from sklearn.svm import SVR |
21 | 22 | from statsmodels.tsa.statespace.sarimax import SARIMAX |
| 23 | +import matplotlib.pyplot as plt |
22 | 24 |
|
| 25 | +logging.basicConfig(level=logging.Info) |
| 26 | +logger = logging.getLogger(__name__) |
23 | 27 |
|
24 | 28 | def linear_regression_prediction( |
25 | 29 | train_dt: list, train_usr: list, train_mtch: list, test_dt: list, test_mtch: list |
@@ -143,6 +147,21 @@ def data_safety_checker(list_vote: list, actual_result: float) -> bool: |
143 | 147 | not_safe += 1 |
144 | 148 | return safe > not_safe |
145 | 149 |
|
| 150 | +def plot_forecast(actual, predictions): |
| 151 | + plt.figure(figsize=(10, 5)) |
| 152 | + plt.plot(range(len(actual)), actual, label="Actual") |
| 153 | + plt.plot(len(actual), predictions[0], 'ro', label="Linear Reg") |
| 154 | + plt.plot(len(actual), predictions[1], 'go', label="SARIMAX") |
| 155 | + plt.plot(len(actual), predictions[2], 'bo', label="SVR") |
| 156 | + plt.plot(len(actual), predictions[3], 'yo', label="RF") |
| 157 | + plt.legend() |
| 158 | + plt.title("Data Safety Forecast") |
| 159 | + plt.xlabel("Days") |
| 160 | + plt.ylabel("Normalized User Count") |
| 161 | + plt.grid(True) |
| 162 | + plt.tight_layout() |
| 163 | + plt.show() |
| 164 | + |
146 | 165 |
|
147 | 166 | if __name__ == "__main__": |
148 | 167 | """ |
@@ -179,11 +198,11 @@ def data_safety_checker(list_vote: list, actual_result: float) -> bool: |
179 | 198 | ), |
180 | 199 | sarimax_predictor(train_user, train_match, test_match), |
181 | 200 | support_vector_regressor(x_train, x_test, train_user), |
182 | | - random_forest_regressor( |
183 | | - x_train, x_test, train_user |
184 | | - ), # Added Random Forest Regressor |
| 201 | + random_forest_regressor(x_train, x_test, train_user), |
185 | 202 | ] |
186 | 203 |
|
187 | 204 | # check the safety of today's data |
188 | 205 | not_str = "" if data_safety_checker(res_vote, test_user[0]) else "not " |
189 | | - print(f"Today's data is {not_str}safe.") |
| 206 | + logger.info(f"Today's data is {not_str}safe.") |
| 207 | + |
| 208 | + plot_forecast(train_user, res_vote) |
0 commit comments