|
25 | 25 | logging.basicConfig(level=logging.Info) |
26 | 26 | logger = logging.getLogger(__name__) |
27 | 27 |
|
| 28 | + |
28 | 29 | def linear_regression_prediction( |
29 | 30 | train_dt: list, train_usr: list, train_mtch: list, test_dt: list, test_mtch: list |
30 | 31 | ) -> float: |
@@ -147,13 +148,14 @@ def data_safety_checker(list_vote: list, actual_result: float) -> bool: |
147 | 148 | not_safe += 1 |
148 | 149 | return safe > not_safe |
149 | 150 |
|
| 151 | + |
150 | 152 | def plot_forecast(actual, predictions): |
151 | 153 | plt.figure(figsize=(10, 5)) |
152 | 154 | plt.plot(range(len(actual)), actual, label="Actual") |
153 | | - plt.plot(len(actual), predictions[0], 'ro', label="Linear Reg") |
154 | | - plt.plot(len(actual), predictions[1], 'go', label="SARIMAX") |
155 | | - plt.plot(len(actual), predictions[2], 'bo', label="SVR") |
156 | | - plt.plot(len(actual), predictions[3], 'yo', label="RF") |
| 155 | + plt.plot(len(actual), predictions[0], "ro", label="Linear Reg") |
| 156 | + plt.plot(len(actual), predictions[1], "go", label="SARIMAX") |
| 157 | + plt.plot(len(actual), predictions[2], "bo", label="SVR") |
| 158 | + plt.plot(len(actual), predictions[3], "yo", label="RF") |
157 | 159 | plt.legend() |
158 | 160 | plt.title("Data Safety Forecast") |
159 | 161 | plt.xlabel("Days") |
@@ -204,5 +206,5 @@ def plot_forecast(actual, predictions): |
204 | 206 | # check the safety of today's data |
205 | 207 | not_str = "" if data_safety_checker(res_vote, test_user[0]) else "not " |
206 | 208 | logger.info(f"Today's data is {not_str}safe.") |
207 | | - |
| 209 | + |
208 | 210 | plot_forecast(train_user, res_vote) |
0 commit comments