|
| 1 | +import matplotlib.pyplot as plt |
| 2 | +import numpy as np |
| 3 | +from sklearn.datasets import make_moons |
| 4 | +from sklearn.model_selection import train_test_split |
| 5 | +from sklearn.neural_network import MLPClassifier |
| 6 | + |
| 7 | + |
| 8 | +def compare_activations() -> None: |
| 9 | + """ |
| 10 | + Demonstrates the effect of different activation functions on a simple dataset |
| 11 | + using scikit-learn's MLPClassifier. |
| 12 | +
|
| 13 | + >>> compare_activations() # doctest: +SKIP |
| 14 | + This function trains models and plots decision boundaries for each activation. |
| 15 | + """ |
| 16 | + x, y = make_moons(n_samples=200, noise=0.25, random_state=3) |
| 17 | + x_train, x_test, y_train, y_test = train_test_split( |
| 18 | + x, y, stratify=y, random_state=42 |
| 19 | + ) |
| 20 | + |
| 21 | + activations = ["identity", "logistic", "tanh", "relu"] |
| 22 | + |
| 23 | + for activation in activations: |
| 24 | + mlp = MLPClassifier( |
| 25 | + hidden_layer_sizes=[50], |
| 26 | + max_iter=1000, |
| 27 | + activation=activation, |
| 28 | + random_state=0, |
| 29 | + ) |
| 30 | + mlp.fit(x_train, y_train) |
| 31 | + |
| 32 | + print( |
| 33 | + f"Activation: {activation}, " |
| 34 | + f"Train Accuracy: {mlp.score(x_train, y_train):.2f}, " |
| 35 | + f"Test Accuracy: {mlp.score(x_test, y_test):.2f}" |
| 36 | + ) |
| 37 | + |
| 38 | + # Decision boundary |
| 39 | + x_min, x_max = x[:, 0].min() - 0.5, x[:, 0].max() + 0.5 |
| 40 | + y_min, y_max = x[:, 1].min() - 0.5, x[:, 1].max() + 0.5 |
| 41 | + xx, yy = np.meshgrid( |
| 42 | + np.linspace(x_min, x_max, 200), |
| 43 | + np.linspace(y_min, y_max, 200), |
| 44 | + ) |
| 45 | + z = mlp.predict(np.c_[xx.ravel(), yy.ravel()]) |
| 46 | + z = z.reshape(xx.shape) |
| 47 | + |
| 48 | + plt.contourf(xx, yy, z, alpha=0.3) |
| 49 | + plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, marker="o", label="Train") |
| 50 | + plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, marker="s", label="Test") |
| 51 | + plt.title(f"Activation: {activation}") |
| 52 | + plt.legend() |
| 53 | + plt.show() |
| 54 | + |
| 55 | + |
| 56 | +if __name__ == "__main__": |
| 57 | + compare_activations() |
0 commit comments