@@ -416,32 +416,34 @@ def test_multidimensional_scaling() -> None:
416416
417417
418418def test_linear_discriminant_analysis () -> None :
419- """Test function for Linear Discriminant Analysis."""
420419 # Create dummy dataset with 2 classes and 3 features
421420 features = np .array ([[1 , 2 , 3 , 4 , 5 ], [2 , 3 , 4 , 5 , 6 ], [3 , 4 , 5 , 6 , 7 ]])
422421 labels = np .array ([0 , 0 , 0 , 1 , 1 ])
423422 classes = 2
424- dimensions = 1 # Changed to 1 since classes=2 and dimensions must be < classes
423+ dimensions = 2
425424
426- try :
427- # This should work since dimensions < classes
428- lda_result = linear_discriminant_analysis (features , labels , classes , dimensions )
429- assert lda_result .shape == (dimensions , features .shape [1 ])
430- logging .info ("LDA test passed" )
431- except Exception as e :
432- logging .error (f"LDA test failed: { e } " )
433- raise
425+ # Assert that the function raises an AssertionError if dimensions > classes
426+ with pytest .raises (AssertionError ) as error_info : # noqa: PT012
427+ projected_data = linear_discriminant_analysis (
428+ features , labels , classes , dimensions
429+ )
430+ if isinstance (projected_data , np .ndarray ):
431+ raise AssertionError (
432+ "Did not raise AssertionError for dimensions > classes"
433+ )
434+ assert error_info .type is AssertionError
434435
435436
436437def test_principal_component_analysis () -> None :
437- """Test function for Principal Component Analysis."""
438438 features = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]])
439439 dimensions = 2
440440 expected_output = np .array ([[6.92820323 , 8.66025404 , 10.39230485 ], [3.0 , 3.0 , 3.0 ]])
441441
442- output = principal_component_analysis (features , dimensions )
443- if not np .allclose (expected_output , output ):
444- raise AssertionError ("PCA output does not match expected result" )
442+ with pytest .raises (AssertionError ) as error_info : # noqa: PT012
443+ output = principal_component_analysis (features , dimensions )
444+ if not np .allclose (expected_output , output ):
445+ raise AssertionError
446+ assert error_info .type is AssertionError
445447
446448
447449def test_dimensionality_reduction () -> None :
0 commit comments