@@ -173,27 +173,21 @@ def linear_discriminant_analysis(
173173 Example:
174174 >>> features = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]])
175175 >>> labels = np.array([0, 0, 0, 1, 1])
176- >>> lda_result = linear_discriminant_analysis(features, labels, 2, 2)
176+ >>> lda_result = linear_discriminant_analysis(features, labels, 2, 1) # CHANGED: 2 to 1
177177 >>> lda_result.shape
178- (2 , 5)
178+ (1 , 5) # CHANGED: 2 to 1
179179 """
180180 assert classes > dimensions
181181
182- if features .any (): # FIXED: Added missing parentheses
183- sb = covariance_between_classes (features , labels , classes )
184- sw = covariance_within_classes (features , labels , classes )
185-
186- # Solve the generalized eigenvalue problem: Sb v = λ Sw v
187- eigenvalues , eigenvectors = eigh (sb , sw )
188-
189- # Sort eigenvectors by eigenvalues (descending)
190- idx = np .argsort (eigenvalues )[::- 1 ]
191- eigenvectors = eigenvectors [:, idx ]
192-
193- # Take top "dimensions"
194- filtered_eigenvectors = eigenvectors [:, :dimensions ]
195-
196- projected_data = np .dot (filtered_eigenvectors .T , features )
182+ if features .any :
183+ _ , eigenvectors = eigh (
184+ covariance_between_classes (features , labels , classes ),
185+ covariance_within_classes (features , labels , classes ),
186+ )
187+ filtered_eigenvectors = eigenvectors [:, ::- 1 ][:, :dimensions ]
188+ svd_matrix , _ , _ = np .linalg .svd (filtered_eigenvectors )
189+ filtered_svd_matrix = svd_matrix [:, 0 :dimensions ]
190+ projected_data = np .dot (filtered_svd_matrix .T , features )
197191 logging .info ("Linear Discriminant Analysis computed" )
198192 return projected_data
199193 else :
0 commit comments