Skip to content

Commit 0f1a8b4

Browse files
Added Random Forest Regressor as an additional prediction model.
1 parent a8ad2db commit 0f1a8b4

1 file changed

Lines changed: 23 additions & 1 deletion

File tree

  • machine_learning/forecasting

machine_learning/forecasting/run.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
for the next 3 months sales or something,
1111
u can just adjust it for ur own purpose
1212
"""
13-
1413
from warnings import simplefilter
1514

1615
import numpy as np
1716
import pandas as pd
17+
from sklearn.ensemble import RandomForestRegressor
1818
from sklearn.preprocessing import Normalizer
1919
from sklearn.svm import SVR
2020
from statsmodels.tsa.statespace.sarimax import SARIMAX
@@ -77,6 +77,28 @@ def support_vector_regressor(x_train: list, x_test: list, train_user: list) -> f
7777
y_pred = regressor.predict(x_test)
7878
return float(y_pred[0])
7979

80+
def random_forest_regressor(x_train: list, x_test: list, train_user: list) -> float:
81+
"""
82+
Fourth method: Random Forest Regressor
83+
Random Forest is an ensemble learning method for regression that operates
84+
by constructing a multitude of decision trees at training time and outputting
85+
the mean prediction of the individual trees.
86+
87+
It is more robust than a single decision tree and less prone to overfitting.
88+
Good for capturing nonlinear relationships in data.
89+
90+
input : training data (date, total_event) in list of float
91+
where x = list of set (date and total event)
92+
output : list of total user prediction in float
93+
94+
>>> random_forest_regressor([[5,2],[1,5],[6,2]], [[3,2]], [2,1,4])
95+
2.3333333333333335
96+
"""
97+
model = RandomForestRegressor(n_estimators=100, random_state=42)
98+
model.fit(x_train, train_user)
99+
prediction = model.predict(x_test)
100+
return float(prediction[0])
101+
80102

81103
def interquartile_range_checker(train_user: list) -> float:
82104
"""

0 commit comments

Comments
 (0)