@@ -179,15 +179,21 @@ def linear_discriminant_analysis(
179179 """
180180 assert classes > dimensions
181181
182- if features .any :
183- _ , eigenvectors = eigh (
184- covariance_between_classes (features , labels , classes ),
185- covariance_within_classes (features , labels , classes ),
186- )
187- filtered_eigenvectors = eigenvectors [:, ::- 1 ][:, :dimensions ]
188- svd_matrix , _ , _ = np .linalg .svd (filtered_eigenvectors )
189- filtered_svd_matrix = svd_matrix [:, 0 :dimensions ]
190- projected_data = np .dot (filtered_svd_matrix .T , features )
182+ if features .any (): # FIXED: Added missing parentheses
183+ sb = covariance_between_classes (features , labels , classes )
184+ sw = covariance_within_classes (features , labels , classes )
185+
186+ # Solve the generalized eigenvalue problem: Sb v = λ Sw v
187+ eigenvalues , eigenvectors = eigh (sb , sw )
188+
189+ # Sort eigenvectors by eigenvalues (descending)
190+ idx = np .argsort (eigenvalues )[::- 1 ]
191+ eigenvectors = eigenvectors [:, idx ]
192+
193+ # Take top "dimensions"
194+ filtered_eigenvectors = eigenvectors [:, :dimensions ]
195+
196+ projected_data = np .dot (filtered_eigenvectors .T , features )
191197 logging .info ("Linear Discriminant Analysis computed" )
192198 return projected_data
193199 else :
0 commit comments