Skip to content

Commit c972d1d

Browse files
authored
Update dimensionality_reduction.py
1 parent 45964b0 commit c972d1d

1 file changed

Lines changed: 9 additions & 18 deletions

File tree

machine_learning/dimensionality_reduction.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -179,24 +179,15 @@ def linear_discriminant_analysis(
179179
"""
180180
assert classes > dimensions
181181

182-
if features.any():
183-
sb = covariance_between_classes(features, labels, classes)
184-
sw = covariance_within_classes(features, labels, classes)
185-
186-
# Add regularization to Sw to avoid singular matrix
187-
sw_reg = sw + 1e-6 * np.eye(sw.shape[0])
188-
189-
# Solve the generalized eigenvalue problem: Sb v = λ Sw v
190-
eigenvalues, eigenvectors = eigh(sb, sw_reg)
191-
192-
# Sort eigenvectors by eigenvalues (descending)
193-
idx = np.argsort(eigenvalues)[::-1]
194-
eigenvectors = eigenvectors[:, idx]
195-
196-
# Take top "dimensions" eigenvectors
197-
filtered_eigenvectors = eigenvectors[:, :dimensions]
198-
199-
projected_data = np.dot(filtered_eigenvectors.T, features)
182+
if features.any:
183+
_, eigenvectors = eigh(
184+
covariance_between_classes(features, labels, classes),
185+
covariance_within_classes(features, labels, classes),
186+
)
187+
filtered_eigenvectors = eigenvectors[:, ::-1][:, :dimensions]
188+
svd_matrix, _, _ = np.linalg.svd(filtered_eigenvectors)
189+
filtered_svd_matrix = svd_matrix[:, 0:dimensions]
190+
projected_data = np.dot(filtered_svd_matrix.T, features)
200191
logging.info("Linear Discriminant Analysis computed")
201192
return projected_data
202193
else:

0 commit comments

Comments
 (0)