@@ -179,24 +179,15 @@ def linear_discriminant_analysis(
179179 """
180180 assert classes > dimensions
181181
182- if features .any ():
183- sb = covariance_between_classes (features , labels , classes )
184- sw = covariance_within_classes (features , labels , classes )
185-
186- # Add regularization to Sw to avoid singular matrix
187- sw_reg = sw + 1e-6 * np .eye (sw .shape [0 ])
188-
189- # Solve the generalized eigenvalue problem: Sb v = λ Sw v
190- eigenvalues , eigenvectors = eigh (sb , sw_reg )
191-
192- # Sort eigenvectors by eigenvalues (descending)
193- idx = np .argsort (eigenvalues )[::- 1 ]
194- eigenvectors = eigenvectors [:, idx ]
195-
196- # Take top "dimensions" eigenvectors
197- filtered_eigenvectors = eigenvectors [:, :dimensions ]
198-
199- projected_data = np .dot (filtered_eigenvectors .T , features )
182+ if features .any :
183+ _ , eigenvectors = eigh (
184+ covariance_between_classes (features , labels , classes ),
185+ covariance_within_classes (features , labels , classes ),
186+ )
187+ filtered_eigenvectors = eigenvectors [:, ::- 1 ][:, :dimensions ]
188+ svd_matrix , _ , _ = np .linalg .svd (filtered_eigenvectors )
189+ filtered_svd_matrix = svd_matrix [:, 0 :dimensions ]
190+ projected_data = np .dot (filtered_svd_matrix .T , features )
200191 logging .info ("Linear Discriminant Analysis computed" )
201192 return projected_data
202193 else :
0 commit comments