Skip to content

Commit 046f63f

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent e055137 commit 046f63f

1 file changed

Lines changed: 8 additions & 4 deletions

File tree

neural_network/rbfnn.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sklearn.model_selection import train_test_split
2222
from sklearn.preprocessing import OneHotEncoder, StandardScaler
2323

24+
2425
class RBFNN:
2526
def __init__(self, num_centers, gamma):
2627
# Initialize with number of RBF centers and spread parameter (gamma)
@@ -31,8 +32,8 @@ def __init__(self, num_centers, gamma):
3132

3233
def _rbf(self, x, centers):
3334
# Compute Gaussian RBF activations for inputs x given the centers
34-
dist = cdist(x, centers, 'euclidean') # Compute Euclidean distance to centers
35-
return np.exp(-self.gamma * (dist ** 2)) # Apply Gaussian function
35+
dist = cdist(x, centers, "euclidean") # Compute Euclidean distance to centers
36+
return np.exp(-self.gamma * (dist**2)) # Apply Gaussian function
3637

3738
def train(self, x_data, y_data):
3839
# Train the RBFNN
@@ -51,6 +52,7 @@ def predict(self, x):
5152
rbf_activations = self._rbf(x, self.centers)
5253
return rbf_activations.dot(self.weights)
5354

55+
5456
if __name__ == "__main__":
5557
# Load and preprocess Iris dataset
5658
iris = load_iris()
@@ -66,7 +68,9 @@ def predict(self, x):
6668
y_encoded = encoder.fit_transform(y)
6769

6870
# Split data into training and testing sets
69-
x_train, x_test, y_train, y_test = train_test_split(x_scaled, y_encoded, test_size=0.2, random_state=42)
71+
x_train, x_test, y_train, y_test = train_test_split(
72+
x_scaled, y_encoded, test_size=0.2, random_state=42
73+
)
7074

7175
# Initialize and train the RBF Neural Network
7276
rbfnn = RBFNN(num_centers=10, gamma=1.0)
@@ -79,4 +83,4 @@ def predict(self, x):
7983

8084
# Evaluate accuracy
8185
accuracy = accuracy_score(y_true, y_pred)
82-
print(f"Classification Accuracy: {accuracy:.4f}")
86+
print(f"Classification Accuracy: {accuracy:.4f}")

0 commit comments

Comments
 (0)