Skip to content

Commit 37e9aaf

Browse files
authored
Update dimensionality_reduction.py
1 parent 0238000 commit 37e9aaf

1 file changed

Lines changed: 11 additions & 17 deletions

File tree

machine_learning/dimensionality_reduction.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -173,27 +173,21 @@ def linear_discriminant_analysis(
173173
Example:
174174
>>> features = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]])
175175
>>> labels = np.array([0, 0, 0, 1, 1])
176-
>>> lda_result = linear_discriminant_analysis(features, labels, 2, 2)
176+
>>> lda_result = linear_discriminant_analysis(features, labels, 2, 1) # CHANGED: 2 to 1
177177
>>> lda_result.shape
178-
(2, 5)
178+
(1, 5) # CHANGED: 2 to 1
179179
"""
180180
assert classes > dimensions
181181

182-
if features.any(): # FIXED: Added missing parentheses
183-
sb = covariance_between_classes(features, labels, classes)
184-
sw = covariance_within_classes(features, labels, classes)
185-
186-
# Solve the generalized eigenvalue problem: Sb v = λ Sw v
187-
eigenvalues, eigenvectors = eigh(sb, sw)
188-
189-
# Sort eigenvectors by eigenvalues (descending)
190-
idx = np.argsort(eigenvalues)[::-1]
191-
eigenvectors = eigenvectors[:, idx]
192-
193-
# Take top "dimensions"
194-
filtered_eigenvectors = eigenvectors[:, :dimensions]
195-
196-
projected_data = np.dot(filtered_eigenvectors.T, features)
182+
if features.any:
183+
_, eigenvectors = eigh(
184+
covariance_between_classes(features, labels, classes),
185+
covariance_within_classes(features, labels, classes),
186+
)
187+
filtered_eigenvectors = eigenvectors[:, ::-1][:, :dimensions]
188+
svd_matrix, _, _ = np.linalg.svd(filtered_eigenvectors)
189+
filtered_svd_matrix = svd_matrix[:, 0:dimensions]
190+
projected_data = np.dot(filtered_svd_matrix.T, features)
197191
logging.info("Linear Discriminant Analysis computed")
198192
return projected_data
199193
else:

0 commit comments

Comments
 (0)