Skip to content

Commit b1b80ce

Browse files
authored
Created mlp_activation_comparison.py
Added a new script (mlp_activation_comparison.py) that demonstrates the effect of different activation functions ('relu', 'tanh', 'logistic') on a simple dataset using scikit-learn's MLPClassifier. This helps visualize and understand how activation choices influence model performance.
1 parent beb3cfd commit b1b80ce

1 file changed

Lines changed: 57 additions & 0 deletions

File tree

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
from sklearn.datasets import make_moons
4+
from sklearn.model_selection import train_test_split
5+
from sklearn.neural_network import MLPClassifier
6+
7+
8+
def compare_activations() -> None:
9+
"""
10+
Demonstrates the effect of different activation functions on a simple dataset
11+
using scikit-learn's MLPClassifier.
12+
13+
>>> compare_activations() # doctest: +SKIP
14+
This function trains models and plots decision boundaries for each activation.
15+
"""
16+
x, y = make_moons(n_samples=200, noise=0.25, random_state=3)
17+
x_train, x_test, y_train, y_test = train_test_split(
18+
x, y, stratify=y, random_state=42
19+
)
20+
21+
activations = ["identity", "logistic", "tanh", "relu"]
22+
23+
for activation in activations:
24+
mlp = MLPClassifier(
25+
hidden_layer_sizes=[50],
26+
max_iter=1000,
27+
activation=activation,
28+
random_state=0,
29+
)
30+
mlp.fit(x_train, y_train)
31+
32+
print(
33+
f"Activation: {activation}, "
34+
f"Train Accuracy: {mlp.score(x_train, y_train):.2f}, "
35+
f"Test Accuracy: {mlp.score(x_test, y_test):.2f}"
36+
)
37+
38+
# Decision boundary
39+
x_min, x_max = x[:, 0].min() - 0.5, x[:, 0].max() + 0.5
40+
y_min, y_max = x[:, 1].min() - 0.5, x[:, 1].max() + 0.5
41+
xx, yy = np.meshgrid(
42+
np.linspace(x_min, x_max, 200),
43+
np.linspace(y_min, y_max, 200),
44+
)
45+
z = mlp.predict(np.c_[xx.ravel(), yy.ravel()])
46+
z = z.reshape(xx.shape)
47+
48+
plt.contourf(xx, yy, z, alpha=0.3)
49+
plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, marker="o", label="Train")
50+
plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, marker="s", label="Test")
51+
plt.title(f"Activation: {activation}")
52+
plt.legend()
53+
plt.show()
54+
55+
56+
if __name__ == "__main__":
57+
compare_activations()

0 commit comments

Comments
 (0)