diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py index 8342dbb88..7bfe2a283 100644 --- a/malariagen_data/anopheles.py +++ b/malariagen_data/anopheles.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, Mapping, Optional, Tuple, Sequence import allel # type: ignore @@ -444,19 +445,33 @@ def cohort_diversity_stats( except CacheMiss: debug("access allele counts") - ac = self.snp_allele_counts( - region=region, - site_mask=site_mask, - site_class=site_class, - sample_query=cohort_query, - sample_sets=sample_sets, - cohort_size=cohort_size, - min_cohort_size=min_cohort_size, - max_cohort_size=max_cohort_size, - random_seed=random_seed, - chunks=chunks, - inline_array=inline_array, - ) + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + ac = self.snp_allele_counts( + region=region, + site_mask=site_mask, + site_class=site_class, + sample_query=cohort_query, + sample_sets=sample_sets, + cohort_size=cohort_size, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + random_seed=random_seed, + chunks=chunks, + inline_array=inline_array, + ) + for w in caught_warnings: + if "Cohort downsampled" in str(w.message): + msg = str(w.message).split(". Set")[0] + "." + warnings.warn(msg, w.category, stacklevel=2) + else: + warnings.warn_explicit( + message=w.message, + category=w.category, + filename=w.filename, + lineno=w.lineno, + source=w.source, + ) debug("compute diversity stats") stats = self._block_jackknife_cohort_diversity_stats(