diff --git a/dask_ml/decomposition/pca.py b/dask_ml/decomposition/pca.py
index 76811e680..9510ae9bf 100644
--- a/dask_ml/decomposition/pca.py
+++ b/dask_ml/decomposition/pca.py
@@ -193,6 +193,10 @@ def __init__(
         self.svd_solver = svd_solver
         self.tol = tol
         self.iterated_power = iterated_power
+        # Fixed value, not taken from a constructor argument: scikit-learn's
+        # PCA.__sklearn_tags__ reads this attribute for randomized solvers
+        # when check_is_fitted calls get_tags.
+        self.power_iteration_normalizer = "auto"
         self.random_state = random_state
 
     def fit(self, X, y=None):
diff --git a/tests/test_pca.py b/tests/test_pca.py
index df81c6769..bd3785444 100644
--- a/tests/test_pca.py
+++ b/tests/test_pca.py
@@ -104,6 +104,20 @@ def test_pca_randomized_solver():
     )
 
 
+def test_pca_randomized_transform_after_fit():
+    """Fit, then transform, a PCA that uses the randomized solver."""
+    pca = dd.PCA(n_components=2, svd_solver="randomized", random_state=0)
+
+    pca.fit(dX)
+    # Must equal the attribute on a default-constructed sd.PCA.
+    assert pca.power_iteration_normalizer == sd.PCA().power_iteration_normalizer
+
+    X_r = pca.transform(dX)
+    assert X_r.shape == (n_samples, 2)
+    # Execute the graph too; presumably X_r is lazy (dask) -- confirm.
+    X_r.compute()
+
+
 def test_no_empty_slice_warning():
     if not DASK_2_26_0:
         # See https://github.com/dask/dask/pull/6591