diff --git a/malariagen_data/anoph/pca.py b/malariagen_data/anoph/pca.py index d18e2842d..260a21d55 100644 --- a/malariagen_data/anoph/pca.py +++ b/malariagen_data/anoph/pca.py @@ -8,6 +8,7 @@ from ..util import CacheMiss, _check_types, _jitter from . import base_params, pca_params, plotly_params +from .sample_metadata import _locate_cohorts from .snp_data import AnophelesSnpData @@ -89,7 +90,7 @@ def pca( ) -> Tuple[pca_params.df_pca, pca_params.evr]: # Change this name if you ever change the behaviour of this function, to # invalidate any previously cached data. - name = "pca_v8" + name = "pca_v9" # Check that either sample_query xor sample_indices are provided. base_params._validate_sample_selection_params( @@ -118,9 +119,12 @@ def pca( sample_query_options=sample_query_options, ) # N.B., we are going to overwrite the sample_indices parameter here. - groups = df_samples.groupby(cohorts, sort=False) + coh_dict = _locate_cohorts( + cohorts=cohorts, data=df_samples, min_cohort_size=0 + ) ix = [] - for _, group in groups: + for _label, loc_coh in coh_dict.items(): + group = df_samples[loc_coh] if len(group) > max_cohort_size: ix.extend( group.sample(