Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dask_ml/decomposition/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def __init__(
self.svd_solver = svd_solver
self.tol = tol
self.iterated_power = iterated_power
# scikit-learn's PCA.__sklearn_tags__ reads this attribute for
# randomized solvers when check_is_fitted calls get_tags.
self.power_iteration_normalizer = "auto"
self.random_state = random_state

def fit(self, X, y=None):
Expand Down
11 changes: 11 additions & 0 deletions tests/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,17 @@ def test_pca_randomized_solver():
)


def test_pca_randomized_transform_after_fit():
    """Check that a randomized-solver PCA can transform data once fitted.

    The test fits ``dd.PCA`` with ``svd_solver="randomized"``, verifies that
    its ``power_iteration_normalizer`` attribute equals the value found on a
    default-constructed ``sd.PCA``, and then transforms the training data,
    checking the output shape and computing the result.
    """
    estimator = dd.PCA(
        n_components=2,
        svd_solver="randomized",
        random_state=0,
    )
    estimator.fit(dX)

    # The attribute must match whatever a default scikit-learn PCA carries.
    actual_normalizer = estimator.power_iteration_normalizer
    expected_normalizer = sd.PCA().power_iteration_normalizer
    assert actual_normalizer == expected_normalizer

    # Two components were requested, so the projection has two columns.
    projected = estimator.transform(dX)
    assert projected.shape == (n_samples, 2)

    # The computed value is discarded; presumably this call is here to force
    # evaluation so that any failure in the transform surfaces in this test.
    projected.compute()


def test_no_empty_slice_warning():
if not DASK_2_26_0:
# See https://github.com/dask/dask/pull/6591
Expand Down