Skip to content

Commit 0238000

Browse files
authored
Update dimensionality_reduction.py
1 parent c972d1d commit 0238000

1 file changed

Lines changed: 15 additions & 9 deletions

File tree

machine_learning/dimensionality_reduction.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,21 @@ def linear_discriminant_analysis(
179179
"""
180180
assert classes > dimensions
181181

182-
if features.any:
183-
_, eigenvectors = eigh(
184-
covariance_between_classes(features, labels, classes),
185-
covariance_within_classes(features, labels, classes),
186-
)
187-
filtered_eigenvectors = eigenvectors[:, ::-1][:, :dimensions]
188-
svd_matrix, _, _ = np.linalg.svd(filtered_eigenvectors)
189-
filtered_svd_matrix = svd_matrix[:, 0:dimensions]
190-
projected_data = np.dot(filtered_svd_matrix.T, features)
182+
if features.any(): # FIXED: Added missing parentheses
183+
sb = covariance_between_classes(features, labels, classes)
184+
sw = covariance_within_classes(features, labels, classes)
185+
186+
# Solve the generalized eigenvalue problem: Sb v = λ Sw v
187+
eigenvalues, eigenvectors = eigh(sb, sw)
188+
189+
# Sort eigenvectors by eigenvalues (descending)
190+
idx = np.argsort(eigenvalues)[::-1]
191+
eigenvectors = eigenvectors[:, idx]
192+
193+
# Take top "dimensions"
194+
filtered_eigenvectors = eigenvectors[:, :dimensions]
195+
196+
projected_data = np.dot(filtered_eigenvectors.T, features)
191197
logging.info("Linear Discriminant Analysis computed")
192198
return projected_data
193199
else:

0 commit comments

Comments
 (0)