@@ -173,28 +173,45 @@ def linear_discriminant_analysis(
173173 Example:
174174 >>> features = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 6], [3, 4, 5, 6, 7]])
175175 >>> labels = np.array([0, 0, 0, 1, 1])
176- >>> lda_result = linear_discriminant_analysis(features, labels, 2, 1) # CHANGED: 2 to 1
176+ >>> lda_result = linear_discriminant_analysis(features, labels, 2, 1)
177177 >>> lda_result.shape
178- (1, 5) # CHANGED: 2 to 1
178+ (1, 5)
179179 """
180180 assert classes > dimensions
181181
182- if features.any:
183- _, eigenvectors = eigh(
184- covariance_between_classes(features, labels, classes),
185- covariance_within_classes(features, labels, classes),
186- )
187- filtered_eigenvectors = eigenvectors[:, ::-1][:, :dimensions]
188- svd_matrix, _, _ = np.linalg.svd(filtered_eigenvectors)
189- filtered_svd_matrix = svd_matrix[:, 0:dimensions]
190- projected_data = np.dot(filtered_svd_matrix.T, features)
191- logging.info("Linear Discriminant Analysis computed")
192- return projected_data
182+ if features.any():
183+ # Add regularization to avoid singular matrix
184+ sw = covariance_within_classes(features, labels, classes)
185+ sb = covariance_between_classes(features, labels, classes)
186+
187+ # Regularize the within-class covariance matrix
188+ reg_param = 1e-6
189+ sw_reg = sw + reg_param * np.eye(sw.shape[0])
190+
191+ try:
192+ _, eigenvectors = eigh(sb, sw_reg)
193+ filtered_eigenvectors = eigenvectors[:, ::-1][:, :dimensions]
194+ svd_matrix, _, _ = np.linalg.svd(filtered_eigenvectors)
195+ filtered_svd_matrix = svd_matrix[:, 0:dimensions]
196+ projected_data = np.dot(filtered_svd_matrix.T, features)
197+ logging.info("Linear Discriminant Analysis computed")
198+ return projected_data
199+ except np.linalg.LinAlgError:
200+ # Fallback: use pseudoinverse if still singular
201+ try:
202+ sw_pinv = np.linalg.pinv(sw_reg)
203+ _, eigenvectors = eigh(sb, sw_pinv)
204+ filtered_eigenvectors = eigenvectors[:, ::-1][:, :dimensions]
205+ projected_data = np.dot(filtered_eigenvectors.T, features)
206+ logging.info("Linear Discriminant Analysis computed with pseudoinverse")
207+ return projected_data
208+ except np.linalg.LinAlgError:
209+ logging.error("LDA failed: matrix is too ill-conditioned")
210+ raise AssertionError("LDA computation failed")
193211 else:
194212 logging.error("Dataset empty")
195213 raise AssertionError
196214
197-
198215def locally_linear_embedding(
199216 features: np.ndarray, dimensions: int, n_neighbors: int = 12, reg: float = 1e-3
200217) -> np.ndarray:
0 commit comments