Skip to content

Commit e3cca4a

Browse files
authored
Created mlp_activation_comparison.py
Added a new script (mlp_activation_comparison.py) that demonstrates the effect of different activation functions ('identity', 'logistic', 'tanh', 'relu') on a simple dataset using scikit-learn's MLPClassifier. This helps visualize and understand how activation choices influence model performance.
1 parent beb3cfd commit e3cca4a

1 file changed

Lines changed: 51 additions & 0 deletions

File tree

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
from sklearn.datasets import make_moons
4+
from sklearn.model_selection import train_test_split
5+
from sklearn.neural_network import MLPClassifier
6+
7+
8+
def compare_activations():
    """Compare MLPClassifier activation functions on the two-moons dataset.

    Trains one MLP per activation in ("identity", "logistic", "tanh",
    "relu") on a stratified train/test split of ``make_moons`` data,
    prints the train/test accuracy for each, and plots each model's
    decision boundary over the train and test points.

    Returns:
        dict: Mapping of activation name -> (train_accuracy, test_accuracy).
            (Original version returned None; the return value is a
            backward-compatible addition for programmatic use.)
    """
    X, y = make_moons(n_samples=200, noise=0.25, random_state=3)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=42
    )

    # The evaluation grid depends only on X, so build it once outside the
    # loop instead of recomputing it for every activation (hoisted
    # loop-invariant work; the 0.5 margin pads the plot around the data).
    x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
    y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
    xx, yy = np.meshgrid(
        np.linspace(x_min, x_max, 200),
        np.linspace(y_min, y_max, 200),
    )
    grid_points = np.c_[xx.ravel(), yy.ravel()]

    activations = ["identity", "logistic", "tanh", "relu"]
    scores = {}

    for activation in activations:
        mlp = MLPClassifier(
            hidden_layer_sizes=[50],
            max_iter=1000,
            activation=activation,
            random_state=0,
        )
        mlp.fit(X_train, y_train)

        # Score once and reuse the values for both printing and the
        # returned results dict.
        train_acc = mlp.score(X_train, y_train)
        test_acc = mlp.score(X_test, y_test)
        scores[activation] = (train_acc, test_acc)
        print(
            f"Activation: {activation}, "
            f"Train Accuracy: {train_acc:.2f}, "
            f"Test Accuracy: {test_acc:.2f}"
        )

        # Decision boundary: predict the class for every grid point.
        Z = mlp.predict(grid_points).reshape(xx.shape)

        # Start a fresh figure for each activation so that, on backends
        # where the figure survives plt.show(), titles/legends from
        # earlier iterations cannot accumulate on a reused figure.
        plt.figure()
        plt.contourf(xx, yy, Z, alpha=0.3)
        plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, marker="o", label="Train")
        plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, marker="s", label="Test")
        plt.title(f"Activation: {activation}")
        plt.legend()
        plt.show()

    return scores
48+
49+
50+
if __name__ == "__main__":
    # Run the activation comparison only when executed as a script,
    # not when this module is imported.
    compare_activations()

0 commit comments

Comments
 (0)