diff --git a/malariagen_data/anoph/ihs.py b/malariagen_data/anoph/ihs.py new file mode 100644 index 000000000..8b141eb88 --- /dev/null +++ b/malariagen_data/anoph/ihs.py @@ -0,0 +1,457 @@ +from typing import Optional, Tuple + +import allel # type: ignore +import bokeh.layouts +import bokeh.models +import bokeh.palettes +import bokeh.plotting +import numpy as np +from numpydoc_decorator import doc # type: ignore + +from .hap_data import AnophelesHapData +from ..util import _check_types, CacheMiss +from . import base_params +from . import ihs_params, gplt_params, hap_params + + +class AnophelesIhsAnalysis( + AnophelesHapData, +): + def __init__( + self, + **kwargs, + ): + # N.B., this class is designed to work cooperatively, and + # so it's important that any remaining parameters are passed + # to the superclass constructor. + super().__init__(**kwargs) + + @_check_types + @doc( + summary="Run iHS GWSS.", + returns=dict( + x="An array containing the window centre point genomic positions.", + ihs="An array with iHS statistic values for each window.", + ), + ) + def ihs_gwss( + self, + contig: base_params.contig, + analysis: hap_params.analysis = base_params.DEFAULT, + sample_sets: Optional[base_params.sample_sets] = None, + sample_query: Optional[base_params.sample_query] = None, + sample_query_options: Optional[base_params.sample_query_options] = None, + window_size: ihs_params.window_size = ihs_params.window_size_default, + percentiles: ihs_params.percentiles = ihs_params.percentiles_default, + standardize: ihs_params.standardize = True, + standardization_bins: Optional[ihs_params.standardization_bins] = None, + standardization_n_bins: ihs_params.standardization_n_bins = ihs_params.standardization_n_bins_default, + standardization_diagnostics: ihs_params.standardization_diagnostics = False, + filter_min_maf: ihs_params.filter_min_maf = ihs_params.filter_min_maf_default, + compute_min_maf: ihs_params.compute_min_maf = ihs_params.compute_min_maf_default, + min_ehh: 
ihs_params.min_ehh = ihs_params.min_ehh_default, + max_gap: ihs_params.max_gap = ihs_params.max_gap_default, + gap_scale: ihs_params.gap_scale = ihs_params.gap_scale_default, + include_edges: ihs_params.include_edges = True, + use_threads: ihs_params.use_threads = True, + min_cohort_size: Optional[ + base_params.min_cohort_size + ] = ihs_params.min_cohort_size_default, + max_cohort_size: Optional[ + base_params.max_cohort_size + ] = ihs_params.max_cohort_size_default, + random_seed: base_params.random_seed = 42, + chunks: base_params.chunks = base_params.native_chunks, + inline_array: base_params.inline_array = base_params.inline_array_default, + ) -> Tuple[np.ndarray, np.ndarray]: + # Change this name if you ever change the behaviour of this function, to + # invalidate any previously cached data. + name = "ihs_gwss_v1" + + params = dict( + contig=contig, + analysis=self._prep_phasing_analysis_param(analysis=analysis), + window_size=window_size, + percentiles=percentiles, + standardize=standardize, + standardization_bins=standardization_bins, + standardization_n_bins=standardization_n_bins, + standardization_diagnostics=standardization_diagnostics, + filter_min_maf=filter_min_maf, + compute_min_maf=compute_min_maf, + min_ehh=min_ehh, + include_edges=include_edges, + max_gap=max_gap, + gap_scale=gap_scale, + use_threads=use_threads, + sample_sets=self._prep_sample_sets_param(sample_sets=sample_sets), + # N.B., do not be tempted to convert this sample query into integer + # indices using _prep_sample_selection_params, because the indices + # are different in the haplotype data. 
+ sample_query=self._prep_sample_query_param(sample_query=sample_query), + sample_query_options=sample_query_options, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + random_seed=random_seed, + ) + + try: + results = self.results_cache_get(name=name, params=params) + + except CacheMiss: + results = self._ihs_gwss(chunks=chunks, inline_array=inline_array, **params) + self.results_cache_set(name=name, params=params, results=results) + + x = results["x"] + ihs = results["ihs"] + + return x, ihs + + def _ihs_gwss( + self, + *, + contig, + analysis, + sample_sets, + sample_query, + sample_query_options, + window_size, + percentiles, + standardize, + standardization_bins, + standardization_n_bins, + standardization_diagnostics, + filter_min_maf, + compute_min_maf, + min_ehh, + max_gap, + gap_scale, + include_edges, + use_threads, + min_cohort_size, + max_cohort_size, + random_seed, + chunks, + inline_array, + ): + ds_haps = self.haplotypes( + region=contig, + analysis=analysis, + sample_query=sample_query, + sample_query_options=sample_query_options, + sample_sets=sample_sets, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + random_seed=random_seed, + chunks=chunks, + inline_array=inline_array, + ) + + gt = allel.GenotypeDaskArray(ds_haps["call_genotype"].data) + with self._dask_progress(desc="Load haplotypes"): + ht = gt.to_haplotypes().compute() + + with self._spinner(desc="Compute IHS"): + ac = ht.count_alleles(max_allele=1) + pos = ds_haps["variant_position"].values + + if filter_min_maf > 0: + af = ac.to_frequencies() + maf = np.min(af, axis=1) + maf_filter = maf > filter_min_maf + ht = ht.compress(maf_filter, axis=0) + pos = pos[maf_filter] + ac = ac[maf_filter] + + # compute iHS + ihs = allel.ihs( + h=ht, + pos=pos, + min_maf=compute_min_maf, + min_ehh=min_ehh, + include_edges=include_edges, + max_gap=max_gap, + gap_scale=gap_scale, + use_threads=use_threads, + ) + + # remove any NaNs + na_mask = ~np.isnan(ihs) + ihs = 
ihs[na_mask] + pos = pos[na_mask] + ac = ac[na_mask] + + # take absolute value + ihs = np.fabs(ihs) + + if standardize: + ihs, _ = allel.standardize_by_allele_count( + score=ihs, + aac=ac[:, 1], + bins=standardization_bins, + n_bins=standardization_n_bins, + diagnostics=standardization_diagnostics, + ) + + if window_size: + ihs = allel.moving_statistic( + ihs, statistic=np.percentile, size=window_size, q=percentiles + ) + pos = allel.moving_statistic(pos, statistic=np.mean, size=window_size) + + results = dict(x=pos, ihs=ihs) + + return results + + @_check_types + @doc( + summary="Run and plot iHS GWSS data.", + ) + def plot_ihs_gwss_track( + self, + contig: base_params.contig, + analysis: hap_params.analysis = base_params.DEFAULT, + sample_sets: Optional[base_params.sample_sets] = None, + sample_query: Optional[base_params.sample_query] = None, + sample_query_options: Optional[base_params.sample_query_options] = None, + window_size: ihs_params.window_size = ihs_params.window_size_default, + percentiles: ihs_params.percentiles = ihs_params.percentiles_default, + standardize: ihs_params.standardize = True, + standardization_bins: Optional[ihs_params.standardization_bins] = None, + standardization_n_bins: ihs_params.standardization_n_bins = ihs_params.standardization_n_bins_default, + standardization_diagnostics: ihs_params.standardization_diagnostics = False, + filter_min_maf: ihs_params.filter_min_maf = ihs_params.filter_min_maf_default, + compute_min_maf: ihs_params.compute_min_maf = ihs_params.compute_min_maf_default, + min_ehh: ihs_params.min_ehh = ihs_params.min_ehh_default, + max_gap: ihs_params.max_gap = ihs_params.max_gap_default, + gap_scale: ihs_params.gap_scale = ihs_params.gap_scale_default, + include_edges: ihs_params.include_edges = True, + use_threads: ihs_params.use_threads = True, + min_cohort_size: Optional[ + base_params.min_cohort_size + ] = ihs_params.min_cohort_size_default, + max_cohort_size: Optional[ + base_params.max_cohort_size + ] = 
ihs_params.max_cohort_size_default, + random_seed: base_params.random_seed = 42, + palette: ihs_params.palette = ihs_params.palette_default, + title: Optional[gplt_params.title] = None, + sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, + width: gplt_params.width = gplt_params.width_default, + height: gplt_params.height = 200, + show: gplt_params.show = True, + x_range: Optional[gplt_params.x_range] = None, + output_backend: gplt_params.output_backend = gplt_params.output_backend_default, + chunks: base_params.chunks = base_params.native_chunks, + inline_array: base_params.inline_array = base_params.inline_array_default, + ) -> gplt_params.optional_figure: + # compute ihs + x, ihs = self.ihs_gwss( + contig=contig, + analysis=analysis, + window_size=window_size, + percentiles=percentiles, + standardize=standardize, + standardization_bins=standardization_bins, + standardization_n_bins=standardization_n_bins, + standardization_diagnostics=standardization_diagnostics, + filter_min_maf=filter_min_maf, + compute_min_maf=compute_min_maf, + min_ehh=min_ehh, + max_gap=max_gap, + gap_scale=gap_scale, + include_edges=include_edges, + use_threads=use_threads, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + sample_query=sample_query, + sample_query_options=sample_query_options, + sample_sets=sample_sets, + random_seed=random_seed, + chunks=chunks, + inline_array=inline_array, + ) + + # determine X axis range + x_min = x[0] + x_max = x[-1] + if x_range is None: + x_range = bokeh.models.Range1d(x_min, x_max, bounds="auto") + + # create a figure + xwheel_zoom = bokeh.models.WheelZoomTool( + dimensions="width", maintain_focus=False + ) + if title is None: + title = sample_query + fig = bokeh.plotting.figure( + title=title, + tools=[ + "xpan", + "xzoom_in", + "xzoom_out", + xwheel_zoom, + "reset", + "save", + "crosshair", + ], + active_inspect=None, + active_scroll=xwheel_zoom, + active_drag="xpan", + sizing_mode=sizing_mode, + 
width=width, + height=height, + toolbar_location="above", + x_range=x_range, + output_backend=output_backend, + ) + + if window_size: + if isinstance(percentiles, int): + percentiles = (percentiles,) + # Ensure percentiles are sorted so that colors make sense. + percentiles = tuple(sorted(percentiles)) + + # add an empty dimension to ihs array if 1D + ihs = np.reshape(ihs, (ihs.shape[0], -1)) + + # select the base color palette to work from + base_palette = bokeh.palettes.all_palettes[palette][8] + + # keep only enough colours to plot the IHS tracks + bokeh_palette = base_palette[: ihs.shape[1]] + + # reverse the colors so darkest is last + bokeh_palette = bokeh_palette[::-1] + + # plot IHS tracks + for i in range(ihs.shape[1]): + ihs_perc = ihs[:, i] + color = bokeh_palette[i] + + # plot ihs + fig.circle( + x=x, + y=ihs_perc, + size=4, + line_width=0, + line_color=color, + fill_color=color, + ) + + # tidy up the plot + fig.yaxis.axis_label = "ihs" + self._bokeh_style_genome_xaxis(fig, contig) + + if show: # pragma: no cover + bokeh.plotting.show(fig) + return None + else: + return fig + + @doc( + summary="Run and plot iHS GWSS data.", + ) + def plot_ihs_gwss( + self, + contig: base_params.contig, + analysis: hap_params.analysis = base_params.DEFAULT, + sample_sets: Optional[base_params.sample_sets] = None, + sample_query: Optional[base_params.sample_query] = None, + sample_query_options: Optional[base_params.sample_query_options] = None, + window_size: ihs_params.window_size = ihs_params.window_size_default, + percentiles: ihs_params.percentiles = ihs_params.percentiles_default, + standardize: ihs_params.standardize = True, + standardization_bins: Optional[ihs_params.standardization_bins] = None, + standardization_n_bins: ihs_params.standardization_n_bins = ihs_params.standardization_n_bins_default, + standardization_diagnostics: ihs_params.standardization_diagnostics = False, + filter_min_maf: ihs_params.filter_min_maf = ihs_params.filter_min_maf_default, + 
compute_min_maf: ihs_params.compute_min_maf = ihs_params.compute_min_maf_default, + min_ehh: ihs_params.min_ehh = ihs_params.min_ehh_default, + max_gap: ihs_params.max_gap = ihs_params.max_gap_default, + gap_scale: ihs_params.gap_scale = ihs_params.gap_scale_default, + include_edges: ihs_params.include_edges = True, + use_threads: ihs_params.use_threads = True, + min_cohort_size: Optional[ + base_params.min_cohort_size + ] = ihs_params.min_cohort_size_default, + max_cohort_size: Optional[ + base_params.max_cohort_size + ] = ihs_params.max_cohort_size_default, + random_seed: base_params.random_seed = 42, + palette: ihs_params.palette = ihs_params.palette_default, + title: Optional[gplt_params.title] = None, + sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, + width: gplt_params.width = gplt_params.width_default, + track_height: gplt_params.track_height = 170, + genes_height: gplt_params.genes_height = gplt_params.genes_height_default, + show: gplt_params.show = True, + output_backend: gplt_params.output_backend = gplt_params.output_backend_default, + chunks: base_params.chunks = base_params.native_chunks, + inline_array: base_params.inline_array = base_params.inline_array_default, + gene_labels: Optional[gplt_params.gene_labels] = None, + gene_labelset: Optional[gplt_params.gene_labelset] = None, + ) -> gplt_params.optional_figure: + # gwss track + fig1 = self.plot_ihs_gwss_track( + contig=contig, + analysis=analysis, + sample_sets=sample_sets, + sample_query=sample_query, + sample_query_options=sample_query_options, + window_size=window_size, + percentiles=percentiles, + palette=palette, + standardize=standardize, + standardization_bins=standardization_bins, + standardization_n_bins=standardization_n_bins, + standardization_diagnostics=standardization_diagnostics, + filter_min_maf=filter_min_maf, + compute_min_maf=compute_min_maf, + min_ehh=min_ehh, + max_gap=max_gap, + gap_scale=gap_scale, + include_edges=include_edges, + 
use_threads=use_threads, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + random_seed=random_seed, + title=title, + sizing_mode=sizing_mode, + width=width, + height=track_height, + show=False, + output_backend=output_backend, + chunks=chunks, + inline_array=inline_array, + ) + + fig1.xaxis.visible = False + + # plot genes + fig2 = self.plot_genes( + region=contig, + sizing_mode=sizing_mode, + width=width, + height=genes_height, + x_range=fig1.x_range, + show=False, + output_backend=output_backend, + gene_labels=gene_labels, + gene_labelset=gene_labelset, + ) + + # combine plots into a single figure + fig = bokeh.layouts.gridplot( + [fig1, fig2], + ncols=1, + toolbar_location="above", + merge_tools=True, + sizing_mode=sizing_mode, + ) + + if show: # pragma: no cover + bokeh.plotting.show(fig) + return None + else: + return fig diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py index 8342dbb88..d2c66f461 100644 --- a/malariagen_data/anopheles.py +++ b/malariagen_data/anopheles.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from typing import Any, Dict, Mapping, Optional, Tuple, Sequence import allel # type: ignore @@ -11,7 +12,6 @@ import plotly.graph_objects as go # type: ignore from numpydoc_decorator import doc # type: ignore -from .anoph.safe_query import validate_query from .anoph import ( aim_params, @@ -19,8 +19,9 @@ dash_params, gplt_params, hapnet_params, - ihs_params, + het_params, plotly_params, + xpehh_params, ) from .anoph.karyotype import AnophelesKaryotypeAnalysis from .anoph.aim_data import AnophelesAimData @@ -35,8 +36,6 @@ from .anoph.sample_metadata import AnophelesSampleMetadata from .anoph.snp_data import AnophelesSnpData from .anoph.to_plink import PlinkConverter -from .anoph.ld import AnophelesLdAnalysis -from .anoph.to_vcf import SnpVcfExporter from .anoph.g123 import AnophelesG123Analysis from .anoph.fst import AnophelesFstAnalysis from .anoph.h12 import AnophelesH12Analysis @@ -46,13 +45,13 @@ 
from .anoph.hapclust import AnophelesHapClustAnalysis from .anoph.describe import AnophelesDescribe from .anoph.dipclust import AnophelesDipClustAnalysis -from .anoph.heterozygosity import AnophelesHetAnalysis -from .anoph.xpehh import AnophelesXpehhAnalysis +from .anoph.ihs import AnophelesIhsAnalysis from .util import ( CacheMiss, - Region, # noqa: F401 (re-exported via __init__.py) + Region, _check_types, _jackknife_ci, + _parse_single_region, _plotly_discrete_legend, ) @@ -81,18 +80,15 @@ class AnophelesDataResource( AnophelesDipClustAnalysis, AnophelesHapClustAnalysis, - AnophelesXpehhAnalysis, + AnophelesIhsAnalysis, AnophelesH1XAnalysis, AnophelesH12Analysis, AnophelesG123Analysis, AnophelesFstAnalysis, - AnophelesHetAnalysis, AnophelesHapFrequencyAnalysis, AnophelesDistanceAnalysis, AnophelesPca, PlinkConverter, - AnophelesLdAnalysis, - SnpVcfExporter, AnophelesIgv, AnophelesKaryotypeAnalysis, AnophelesAimData, @@ -185,26 +181,10 @@ def __init__( surveillance_use_only=surveillance_use_only, ) - def _get_ihs_gwss_cache_name(self): - """Safely resolve the ihs gwss cache name. - - Supports class attribute, property, or legacy method override. - Falls back to the default "ihs_gwss_v1" if resolution fails. - - See also: https://github.com/malariagen/malariagen-data-python/issues/1151 - """ - try: - name = self._ihs_gwss_cache_name - # Handle legacy case where _ihs_gwss_cache_name might be a - # callable method rather than a property or class attribute. - if callable(name): - name = name() - if isinstance(name, str) and len(name) > 0: - return name - except NotImplementedError: - pass - # Fallback to default. 
- return "ihs_gwss_v1" + @property + @abstractmethod + def _xpehh_gwss_cache_name(self): + raise NotImplementedError("Must override _xpehh_gwss_cache_name") @staticmethod def _make_gene_cnv_label(gene_id, gene_name, cnv_type): @@ -214,6 +194,683 @@ def _make_gene_cnv_label(gene_id, gene_name, cnv_type): label += f" {cnv_type}" return label + @staticmethod + def _roh_hmm_predict( + *, + windows, + counts, + phet_roh, + phet_nonroh, + transition, + window_size, + sample_id, + contig, + ): + # This implementation is based on scikit-allel, but modified to use + # moving window computation of het counts. + from allel.stats.misc import tabulate_state_blocks # type: ignore + from allel.stats.roh import _hmm_derive_transition_matrix # type: ignore + + # Protopunica is pomegranate frozen at version 0.14.8, which is compatible + # with the code here. Also protopunica has binary wheels available from + # PyPI and so installs much faster. + from protopunica import HiddenMarkovModel, PoissonDistribution # type: ignore + + # het probabilities + het_px = np.concatenate([(phet_roh,), phet_nonroh]) + + # start probabilities (all equal) + start_prob = np.repeat(1 / het_px.size, het_px.size) + + # transition between underlying states + transition_mx = _hmm_derive_transition_matrix(transition, het_px.size) + + # emission probability distribution + dists = [PoissonDistribution(x * window_size) for x in het_px] + + # set up model + # noinspection PyArgumentList + model = HiddenMarkovModel.from_matrix( + transition_probabilities=transition_mx, + distributions=dists, + starts=start_prob, + ) + + # predict hidden states + prediction = np.array(model.predict(counts[:, None])) + + # tabulate runs of homozygosity (state 0) + # noinspection PyTypeChecker + df_blocks = tabulate_state_blocks(prediction, states=list(range(len(het_px)))) + df_roh = df_blocks[(df_blocks["state"] == 0)].reset_index(drop=True) + + # adapt the dataframe for ROH + df_roh["sample_id"] = sample_id + df_roh["contig"] =
contig + df_roh["roh_start"] = df_roh["start_ridx"].apply(lambda y: windows[y, 0]) + df_roh["roh_stop"] = df_roh["stop_lidx"].apply(lambda y: windows[y, 1]) + df_roh["roh_length"] = df_roh["roh_stop"] - df_roh["roh_start"] + df_roh.rename(columns={"is_marginal": "roh_is_marginal"}, inplace=True) + + return df_roh[ + [ + "sample_id", + "contig", + "roh_start", + "roh_stop", + "roh_length", + "roh_is_marginal", + ] + ] + + def _plot_heterozygosity_track( + self, + *, + sample_id, + sample_set, + windows, + counts, + region: Region, + window_size, + y_max, + sizing_mode, + width, + height, + circle_kwargs, + show, + x_range, + output_backend, + ): + debug = self._log.debug + + # pos axis + window_pos = windows.mean(axis=1) + + # het axis + window_het = counts / window_size + + # determine plotting limits + if x_range is None: + if region.start is not None: + x_min = region.start + else: + x_min = 0 + if region.end is not None: + x_max = region.end + else: + x_max = len(self.genome_sequence(region.contig)) + x_range = bokeh.models.Range1d(x_min, x_max, bounds="auto") + + debug("create a figure for plotting") + xwheel_zoom = bokeh.models.WheelZoomTool( + dimensions="width", maintain_focus=False + ) + fig = bokeh.plotting.figure( + title=f"{sample_id} ({sample_set})", + tools=["xpan", "xzoom_in", "xzoom_out", xwheel_zoom, "reset", "save"], + active_scroll=xwheel_zoom, + active_drag="xpan", + sizing_mode=sizing_mode, + width=width, + height=height, + toolbar_location="above", + x_range=x_range, + y_range=(0, y_max), + output_backend=output_backend, + ) + + debug("plot heterozygosity") + data = pd.DataFrame( + { + "position": window_pos, + "heterozygosity": window_het, + } + ) + if circle_kwargs is None: + circle_kwargs = dict() + circle_kwargs.setdefault("size", 4) + circle_kwargs.setdefault("line_width", 0) + fig.circle(x="position", y="heterozygosity", source=data, **circle_kwargs) + + debug("tidy up the plot") + fig.yaxis.axis_label = "Heterozygosity (bp⁻¹)" + 
self._bokeh_style_genome_xaxis(fig, region.contig) + + if show: # pragma: no cover + bokeh.plotting.show(fig) + + return fig + + @_check_types + @doc( + summary="Plot windowed heterozygosity for a single sample over a genome region.", + ) + def plot_heterozygosity_track( + self, + sample: base_params.sample, + region: base_params.region, + window_size: het_params.window_size = het_params.window_size_default, + y_max: het_params.y_max = het_params.y_max_default, + circle_kwargs: Optional[gplt_params.circle_kwargs] = None, + site_mask: Optional[base_params.site_mask] = base_params.DEFAULT, + sample_set: Optional[base_params.sample_set] = None, + sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, + width: gplt_params.width = gplt_params.width_default, + height: gplt_params.height = 200, + show: gplt_params.show = True, + x_range: Optional[gplt_params.x_range] = None, + output_backend: gplt_params.output_backend = gplt_params.output_backend_default, + chunks: base_params.chunks = base_params.native_chunks, + inline_array: base_params.inline_array = base_params.inline_array_default, + ) -> gplt_params.optional_figure: + debug = self._log.debug + + # Normalise parameters. 
+ region_prepped: Region = _parse_single_region(self, region) + del region + + debug("compute windowed heterozygosity") + sample_id, sample_set, windows, counts = self._sample_count_het( + sample=sample, + region=region_prepped, + site_mask=site_mask, + window_size=window_size, + sample_set=sample_set, + chunks=chunks, + inline_array=inline_array, + ) + + debug("plot heterozygosity") + fig = self._plot_heterozygosity_track( + sample_id=sample_id, + sample_set=sample_set, + windows=windows, + counts=counts, + region=region_prepped, + window_size=window_size, + y_max=y_max, + sizing_mode=sizing_mode, + width=width, + height=height, + circle_kwargs=circle_kwargs, + show=False, # the figure is shown once below; show=show here rendered it twice + x_range=x_range, + output_backend=output_backend, + ) + + if show: # pragma: no cover + bokeh.plotting.show(fig) + return None + else: + return fig + + @_check_types + @doc( + summary="Plot windowed heterozygosity for a single sample over a genome region.", + ) + def plot_heterozygosity( + self, + sample: base_params.samples, + region: base_params.region, + window_size: het_params.window_size = het_params.window_size_default, + y_max: het_params.y_max = het_params.y_max_default, + circle_kwargs: Optional[gplt_params.circle_kwargs] = None, + site_mask: Optional[base_params.site_mask] = base_params.DEFAULT, + sample_set: Optional[base_params.sample_set] = None, + sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, + width: gplt_params.width = gplt_params.width_default, + track_height: gplt_params.track_height = 170, + genes_height: gplt_params.genes_height = gplt_params.genes_height_default, + show: gplt_params.show = True, + output_backend: gplt_params.output_backend = gplt_params.output_backend_default, + chunks: base_params.chunks = base_params.native_chunks, + inline_array: base_params.inline_array = base_params.inline_array_default, + gene_labels: Optional[gplt_params.gene_labels] = None, + gene_labelset: Optional[gplt_params.gene_labelset] = None, + ) ->
gplt_params.optional_figure: + debug = self._log.debug + + # normalise to support multiple samples + if isinstance(sample, (list, tuple)): + samples = sample + else: + samples = [sample] + + debug("plot first sample track") + fig1 = self.plot_heterozygosity_track( + sample=samples[0], + sample_set=sample_set, + region=region, + site_mask=site_mask, + window_size=window_size, + y_max=y_max, + sizing_mode=sizing_mode, + width=width, + height=track_height, + circle_kwargs=circle_kwargs, + show=False, + output_backend=output_backend, + chunks=chunks, + inline_array=inline_array, + ) + fig1.xaxis.visible = False + figs = [fig1] + + debug("plot remaining sample tracks") + for sample in samples[1:]: + fig_het = self.plot_heterozygosity_track( + sample=sample, + sample_set=sample_set, + region=region, + site_mask=site_mask, + window_size=window_size, + y_max=y_max, + sizing_mode=sizing_mode, + width=width, + height=track_height, + circle_kwargs=circle_kwargs, + show=False, + x_range=fig1.x_range, + output_backend=output_backend, + chunks=chunks, + inline_array=inline_array, + ) + fig_het.xaxis.visible = False + figs.append(fig_het) + + debug("plot genes track") + fig_genes = self.plot_genes( + region=region, + sizing_mode=sizing_mode, + width=width, + height=genes_height, + x_range=fig1.x_range, + show=False, + output_backend=output_backend, + gene_labels=gene_labels, + gene_labelset=gene_labelset, + ) + figs.append(fig_genes) + + debug("combine plots into a single figure") + fig_all = bokeh.layouts.gridplot( + figs, + ncols=1, + toolbar_location="above", + merge_tools=True, + sizing_mode=sizing_mode, + ) + + if show: # pragma: no cover + bokeh.plotting.show(fig_all) + return None + else: + return fig_all + + def _sample_count_het( + self, + sample: base_params.sample, + region: Region, + site_mask: Optional[base_params.site_mask], + window_size: het_params.window_size, + sample_set: Optional[base_params.sample_set], + chunks: base_params.chunks, + inline_array: 
base_params.inline_array, + ): + debug = self._log.debug + + debug("access sample metadata, look up sample") + sample_rec = self.lookup_sample(sample=sample, sample_set=sample_set) + sample_id = sample_rec.name # sample_id + sample_set = sample_rec["sample_set"] + + debug("access SNPs, select data for sample") + ds_snps = self.snp_calls( + region=region, + sample_sets=sample_set, + site_mask=site_mask, + chunks=chunks, + inline_array=inline_array, + ) + ds_snps_sample = ds_snps.set_index(samples="sample_id").sel(samples=sample_id) + + # snp positions + pos = ds_snps_sample["variant_position"].values + + # access genotypes + gt = allel.GenotypeDaskVector(ds_snps_sample["call_genotype"].data) + + # compute het + with self._dask_progress(desc="Compute heterozygous genotypes"): + is_het = gt.is_het().compute() + + # compute window coordinates + windows = allel.moving_statistic( + values=pos, + statistic=lambda x: [x[0], x[-1]], + size=window_size, + ) + + # compute windowed heterozygosity + counts = allel.moving_statistic( + values=is_het, + statistic=np.sum, + size=window_size, + ) + + return sample_id, sample_set, windows, counts + + @property + @abstractmethod + def _roh_hmm_cache_name(self): + raise NotImplementedError("Must override _roh_hmm_cache_name") + + @_check_types + @doc( + summary="Infer runs of homozygosity for a single sample over a genome region.", + ) + def roh_hmm( + self, + sample: base_params.sample, + region: base_params.region, + window_size: het_params.window_size = het_params.window_size_default, + site_mask: Optional[base_params.site_mask] = base_params.DEFAULT, + sample_set: Optional[base_params.sample_set] = None, + phet_roh: het_params.phet_roh = het_params.phet_roh_default, + phet_nonroh: het_params.phet_nonroh = het_params.phet_nonroh_default, + transition: het_params.transition = het_params.transition_default, + chunks: base_params.chunks = base_params.native_chunks, + inline_array: base_params.inline_array = 
base_params.inline_array_default, + ) -> het_params.df_roh: + debug = self._log.debug + + resolved_region: Region = _parse_single_region(self, region) + + name = self._roh_hmm_cache_name + + params = dict( + sample=sample, + region=region, + window_size=window_size, + site_mask=site_mask, + sample_set=sample_set, + phet_roh=phet_roh, + phet_nonroh=phet_nonroh, + transition=transition, + chunks=chunks, + inline_array=inline_array, + ) + + del region + + try: + # Load cached numeric data; the str / obj columns are re-attached below. + results = self.results_cache_get(name=name, params=params) + + # Reconstruct dataframe, matching the columns and order of a fresh computation. + df_roh = pd.DataFrame( + { + "roh_start": results["roh_start"], + "roh_stop": results["roh_stop"], + "roh_length": results["roh_length"], + "roh_is_marginal": results["roh_is_marginal"], + } + ) + sample_rec = self.lookup_sample(sample=sample, sample_set=sample_set) + df_roh.insert(0, "sample_id", sample_rec.name) # resolved ID, not the raw argument + df_roh.insert(1, "contig", resolved_region.contig) + + except CacheMiss: + debug("compute windowed heterozygosity") + sample_id, sample_set, windows, counts = self._sample_count_het( + sample=sample, + region=resolved_region, + site_mask=site_mask, + window_size=window_size, + sample_set=sample_set, + chunks=chunks, + inline_array=inline_array, + ) + + debug("compute runs of homozygosity") + df_roh = self._roh_hmm_predict( + windows=windows, + counts=counts, + phet_roh=phet_roh, + phet_nonroh=phet_nonroh, + transition=transition, + window_size=window_size, + sample_id=sample_id, + contig=resolved_region.contig, + ) + + # Specify numeric columns to save (saving obj - sample ID and contig - breaks the save).
+ columns_to_save = [ + "roh_start", + "roh_stop", + "roh_length", + "roh_is_marginal", + ] + + self.results_cache_set( + name=name, + params=params, + results={col: df_roh[col].to_numpy() for col in columns_to_save}, + ) + + return df_roh + + @_check_types + @doc( + summary="Plot a runs of homozygosity track.", + ) + def plot_roh_track( + self, + df_roh: het_params.df_roh, + region: base_params.region, + sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, + width: gplt_params.width = gplt_params.width_default, + height: gplt_params.height = 80, + show: gplt_params.show = True, + x_range: Optional[gplt_params.x_range] = None, + title: Optional[gplt_params.title] = None, + output_backend: gplt_params.output_backend = gplt_params.output_backend_default, + ) -> gplt_params.optional_figure: + debug = self._log.debug + + debug("handle region parameter - this determines the genome region to plot") + resolved_region: Region = _parse_single_region(self, region) + del region + contig = resolved_region.contig + start = resolved_region.start + end = resolved_region.end + if start is None: + start = 0 + if end is None: + end = len(self.genome_sequence(contig)) + + debug("define x axis range") + if x_range is None: + x_range = bokeh.models.Range1d(start, end, bounds="auto") + + debug( + "we're going to plot each gene as a rectangle, so add some additional columns" + ) + data = df_roh.copy() + data["bottom"] = 0.2 + data["top"] = 0.8 + + debug("make a figure") + xwheel_zoom = bokeh.models.WheelZoomTool( + dimensions="width", maintain_focus=False + ) + fig = bokeh.plotting.figure( + title=title, + sizing_mode=sizing_mode, + width=width, + height=height, + tools=[ + "xpan", + "xzoom_in", + "xzoom_out", + xwheel_zoom, + "reset", + "tap", + "hover", + "save", + ], + active_scroll=xwheel_zoom, + active_drag="xpan", + x_range=x_range, + y_range=bokeh.models.Range1d(0, 1), + output_backend=output_backend, + ) + + debug("now plot the ROH as rectangles") + fig.quad( + 
bottom="bottom", + top="top", + left="roh_start", + right="roh_stop", + source=data, + line_width=1, + fill_alpha=0.5, + ) + + debug("tidy up the plot") + fig.ygrid.visible = False + fig.yaxis.ticker = [] + fig.yaxis.axis_label = "RoH" + self._bokeh_style_genome_xaxis(fig, resolved_region.contig) + + if show: # pragma: no cover + bokeh.plotting.show(fig) + return None + else: + return fig + + @_check_types + @doc( + summary=""" + Plot windowed heterozygosity and inferred runs of homozygosity for a + single sample over a genome region. + """, + ) + def plot_roh( + self, + sample: base_params.sample, + region: base_params.region, + window_size: het_params.window_size = het_params.window_size_default, + site_mask: Optional[base_params.site_mask] = base_params.DEFAULT, + sample_set: Optional[base_params.sample_set] = None, + phet_roh: het_params.phet_roh = het_params.phet_roh_default, + phet_nonroh: het_params.phet_nonroh = het_params.phet_nonroh_default, + transition: het_params.transition = het_params.transition_default, + y_max: het_params.y_max = het_params.y_max_default, + sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, + width: gplt_params.width = gplt_params.width_default, + heterozygosity_height: gplt_params.height = 170, + roh_height: gplt_params.height = 40, + genes_height: gplt_params.genes_height = gplt_params.genes_height_default, + circle_kwargs: Optional[gplt_params.circle_kwargs] = None, + show: gplt_params.show = True, + output_backend: gplt_params.output_backend = gplt_params.output_backend_default, + chunks: base_params.chunks = base_params.native_chunks, + inline_array: base_params.inline_array = base_params.inline_array_default, + gene_labels: Optional[gplt_params.gene_labels] = None, + gene_labelset: Optional[gplt_params.gene_labelset] = None, + ) -> gplt_params.optional_figure: + debug = self._log.debug + + resolved_region: Region = _parse_single_region(self, region) + del region + + debug("compute windowed 
heterozygosity") + sample_id, sample_set, windows, counts = self._sample_count_het( + sample=sample, + region=resolved_region, + site_mask=site_mask, + window_size=window_size, + sample_set=sample_set, + chunks=chunks, + inline_array=inline_array, + ) + + debug("plot_heterozygosity track") + fig_het = self._plot_heterozygosity_track( + sample_id=sample_id, + sample_set=sample_set, + windows=windows, + counts=counts, + region=resolved_region, + window_size=window_size, + y_max=y_max, + sizing_mode=sizing_mode, + width=width, + height=heterozygosity_height, + circle_kwargs=circle_kwargs, + show=False, + x_range=None, + output_backend=output_backend, + ) + fig_het.xaxis.visible = False + figs = [fig_het] + + debug("compute runs of homozygosity") + df_roh = self._roh_hmm_predict( + windows=windows, + counts=counts, + phet_roh=phet_roh, + phet_nonroh=phet_nonroh, + transition=transition, + window_size=window_size, + sample_id=sample_id, + contig=resolved_region.contig, + ) + + debug("plot roh track") + fig_roh = self.plot_roh_track( + df_roh, + region=resolved_region, + sizing_mode=sizing_mode, + width=width, + height=roh_height, + show=False, + x_range=fig_het.x_range, + output_backend=output_backend, + ) + fig_roh.xaxis.visible = False + figs.append(fig_roh) + + debug("plot genes track") + fig_genes = self.plot_genes( + region=resolved_region, + sizing_mode=sizing_mode, + width=width, + height=genes_height, + x_range=fig_het.x_range, + show=False, + output_backend=output_backend, + gene_labels=gene_labels, + gene_labelset=gene_labelset, + ) + figs.append(fig_genes) + + debug("combine plots into a single figure") + fig_all = bokeh.layouts.gridplot( + figs, + ncols=1, + toolbar_location="above", + merge_tools=True, + sizing_mode=sizing_mode, + ) + + if show: # pragma: no cover + bokeh.plotting.show(fig_all) + return None + else: + return fig_all + def _block_jackknife_cohort_diversity_stats( self, *, cohort_label, ac, n_jack, confidence_level ): @@ -268,11 +925,7 @@ def 
_block_jackknife_cohort_diversity_stats( block_stop = block_start + block_length loc_j = np.ones(n_sites, dtype=bool) loc_j[block_start:block_stop] = False - if np.count_nonzero(loc_j) != n_sites_j: - raise RuntimeError( - f"Internal error in jackknife resampling: expected {n_sites_j} " - f"sites after block deletion, got {np.count_nonzero(loc_j)}" - ) + assert np.count_nonzero(loc_j) == n_sites_j # resample data and compute statistics @@ -391,10 +1044,6 @@ def cohort_diversity_stats( ) -> pd.Series: debug = self._log.debug - # Change this name if you ever change the behaviour of this function, to - # invalidate any previously cached data. - name = "cohort_diversity_stats_v1" - debug("process cohort parameter") cohort_query = None if isinstance(cohort, str): @@ -413,11 +1062,14 @@ def cohort_diversity_stats( cohort_label, cohort_query = cohort else: - raise TypeError(f"invalid cohort parameter: {cohort!r}") + raise TypeError(r"invalid cohort parameter: {cohort!r}") + + # Change this name if you ever change the behaviour of this function, + # to invalidate any previously cached data. + name = "cohort_diversity_stats_v1" params = dict( - cohort_label=cohort_label, - cohort_query=cohort_query, + cohort=cohort, cohort_size=cohort_size, region=region, min_cohort_size=min_cohort_size, @@ -432,15 +1084,11 @@ def cohort_diversity_stats( inline_array=inline_array, ) - # Try to retrieve results from the cache. try: - results = self.results_cache_get(name=name, params=params) - stats = { - key: value.item() - if isinstance(value, np.ndarray) and value.shape == () - else value - for key, value in results.items() - } + stats = self.results_cache_get(name=name, params=params) + + # Reconstruct series from cached numeric results. 
+ stats = pd.Series({key: values[0] for key, values in stats.items()}) except CacheMiss: debug("access allele counts") @@ -466,8 +1114,17 @@ def cohort_diversity_stats( confidence_level=confidence_level, ) - cache_results = {key: np.asarray(value) for key, value in stats.items()} - self.results_cache_set(name=name, params=params, results=cache_results) + self.results_cache_set( + name=name, + params=params, + results={ + key: np.asarray([value]) + for key, value in stats.items() + if key != "cohort" + }, + ) + + stats["cohort"] = cohort_label debug("compute some extra cohort variables") df_samples = self.sample_metadata( @@ -703,61 +1360,159 @@ def plot_diversity_stats( fig2.show(renderer=renderer) fig3.show(renderer=renderer) fig4.show(renderer=renderer) - return (fig1, fig2, fig3, fig4) + return None + else: + return (fig1, fig2, fig3, fig4) @_check_types @doc( - summary="Run iHS GWSS.", + summary="Run and plot XP-EHH GWSS data.", + ) + def plot_xpehh_gwss( + self, + contig: base_params.contig, + analysis: hap_params.analysis = base_params.DEFAULT, + sample_sets: Optional[base_params.sample_sets] = None, + cohort1_query: Optional[base_params.sample_query] = None, + cohort2_query: Optional[base_params.sample_query] = None, + sample_query_options: Optional[base_params.sample_query_options] = None, + window_size: xpehh_params.window_size = xpehh_params.window_size_default, + percentiles: xpehh_params.percentiles = xpehh_params.percentiles_default, + filter_min_maf: xpehh_params.filter_min_maf = xpehh_params.filter_min_maf_default, + map_pos: Optional[xpehh_params.map_pos] = None, + min_ehh: xpehh_params.min_ehh = xpehh_params.min_ehh_default, + max_gap: xpehh_params.max_gap = xpehh_params.max_gap_default, + gap_scale: xpehh_params.gap_scale = xpehh_params.gap_scale_default, + include_edges: xpehh_params.include_edges = True, + use_threads: xpehh_params.use_threads = True, + min_cohort_size: Optional[ + base_params.min_cohort_size + ] = 
xpehh_params.min_cohort_size_default, + max_cohort_size: Optional[ + base_params.max_cohort_size + ] = xpehh_params.max_cohort_size_default, + random_seed: base_params.random_seed = 42, + palette: xpehh_params.palette = xpehh_params.palette_default, + title: Optional[gplt_params.title] = None, + sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, + width: gplt_params.width = gplt_params.width_default, + track_height: gplt_params.track_height = 170, + genes_height: gplt_params.genes_height = gplt_params.genes_height_default, + show: gplt_params.show = True, + output_backend: gplt_params.output_backend = gplt_params.output_backend_default, + chunks: base_params.chunks = base_params.native_chunks, + inline_array: base_params.inline_array = base_params.inline_array_default, + gene_labels: Optional[gplt_params.gene_labels] = None, + gene_labelset: Optional[gplt_params.gene_labelset] = None, + ) -> gplt_params.optional_figure: + # gwss track + fig1 = self.plot_xpehh_gwss_track( + contig=contig, + analysis=analysis, + sample_sets=sample_sets, + cohort1_query=cohort1_query, + cohort2_query=cohort2_query, + sample_query_options=sample_query_options, + window_size=window_size, + percentiles=percentiles, + palette=palette, + filter_min_maf=filter_min_maf, + map_pos=map_pos, + min_ehh=min_ehh, + max_gap=max_gap, + gap_scale=gap_scale, + include_edges=include_edges, + use_threads=use_threads, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + random_seed=random_seed, + title=title, + sizing_mode=sizing_mode, + width=width, + height=track_height, + show=False, + x_range=None, + output_backend=output_backend, + chunks=chunks, + inline_array=inline_array, + ) + + fig1.xaxis.visible = False + + # plot genes + fig2 = self.plot_genes( + region=contig, + sizing_mode=sizing_mode, + width=width, + height=genes_height, + x_range=fig1.x_range, + show=False, + output_backend=output_backend, + gene_labels=gene_labels, + gene_labelset=gene_labelset, + 
) + + # combine plots into a single figure + fig = bokeh.layouts.gridplot( + [fig1, fig2], + ncols=1, + toolbar_location="above", + merge_tools=True, + sizing_mode=sizing_mode, + ) + + if show: # pragma: no cover + bokeh.plotting.show(fig) + return None + else: + return fig + + @_check_types + @doc( + summary="Run XP-EHH GWSS.", returns=dict( x="An array containing the window centre point genomic positions.", - ihs="An array with iHS statistic values for each window.", + xpehh="An array with XP-EHH statistic values for each window.", ), ) - def ihs_gwss( + def xpehh_gwss( self, contig: base_params.contig, analysis: hap_params.analysis = base_params.DEFAULT, sample_sets: Optional[base_params.sample_sets] = None, - sample_query: Optional[base_params.sample_query] = None, + cohort1_query: Optional[base_params.sample_query] = None, + cohort2_query: Optional[base_params.sample_query] = None, sample_query_options: Optional[base_params.sample_query_options] = None, - window_size: ihs_params.window_size = ihs_params.window_size_default, - percentiles: ihs_params.percentiles = ihs_params.percentiles_default, - standardize: ihs_params.standardize = True, - standardization_bins: Optional[ihs_params.standardization_bins] = None, - standardization_n_bins: ihs_params.standardization_n_bins = ihs_params.standardization_n_bins_default, - standardization_diagnostics: ihs_params.standardization_diagnostics = False, - filter_min_maf: ihs_params.filter_min_maf = ihs_params.filter_min_maf_default, - compute_min_maf: ihs_params.compute_min_maf = ihs_params.compute_min_maf_default, - min_ehh: ihs_params.min_ehh = ihs_params.min_ehh_default, - max_gap: ihs_params.max_gap = ihs_params.max_gap_default, - gap_scale: ihs_params.gap_scale = ihs_params.gap_scale_default, - include_edges: ihs_params.include_edges = True, - use_threads: ihs_params.use_threads = True, + window_size: xpehh_params.window_size = xpehh_params.window_size_default, + percentiles: xpehh_params.percentiles = 
xpehh_params.percentiles_default, + filter_min_maf: xpehh_params.filter_min_maf = xpehh_params.filter_min_maf_default, + map_pos: Optional[xpehh_params.map_pos] = None, + min_ehh: xpehh_params.min_ehh = xpehh_params.min_ehh_default, + max_gap: xpehh_params.max_gap = xpehh_params.max_gap_default, + gap_scale: xpehh_params.gap_scale = xpehh_params.gap_scale_default, + include_edges: xpehh_params.include_edges = True, + use_threads: xpehh_params.use_threads = True, min_cohort_size: Optional[ base_params.min_cohort_size - ] = ihs_params.min_cohort_size_default, + ] = xpehh_params.min_cohort_size_default, max_cohort_size: Optional[ base_params.max_cohort_size - ] = ihs_params.max_cohort_size_default, + ] = xpehh_params.max_cohort_size_default, random_seed: base_params.random_seed = 42, chunks: base_params.chunks = base_params.native_chunks, inline_array: base_params.inline_array = base_params.inline_array_default, ) -> Tuple[np.ndarray, np.ndarray]: # change this name if you ever change the behaviour of this function, to # invalidate any previously cached data - name = self._get_ihs_gwss_cache_name() + name = self._xpehh_gwss_cache_name params = dict( contig=contig, analysis=self._prep_phasing_analysis_param(analysis=analysis), window_size=window_size, percentiles=percentiles, - standardize=standardize, - standardization_bins=standardization_bins, - standardization_n_bins=standardization_n_bins, - standardization_diagnostics=standardization_diagnostics, filter_min_maf=filter_min_maf, - compute_min_maf=compute_min_maf, + map_pos=map_pos, min_ehh=min_ehh, include_edges=include_edges, max_gap=max_gap, @@ -767,7 +1522,8 @@ def ihs_gwss( # N.B., do not be tempted to convert this sample query into integer # indices using _prep_sample_selection_params, because the indices # are different in the haplotype data. 
- sample_query=self._prep_sample_query_param(sample_query=sample_query), + cohort1_query=self._prep_sample_query_param(sample_query=cohort1_query), + cohort2_query=self._prep_sample_query_param(sample_query=cohort2_query), sample_query_options=sample_query_options, min_cohort_size=min_cohort_size, max_cohort_size=max_cohort_size, @@ -778,30 +1534,29 @@ def ihs_gwss( results = self.results_cache_get(name=name, params=params) except CacheMiss: - results = self._ihs_gwss(chunks=chunks, inline_array=inline_array, **params) + results = self._xpehh_gwss( + chunks=chunks, inline_array=inline_array, **params + ) self.results_cache_set(name=name, params=params, results=results) x = results["x"] - ihs = results["ihs"] + xpehh = results["xpehh"] - return x, ihs + return x, xpehh - def _ihs_gwss( + def _xpehh_gwss( self, *, contig, analysis, sample_sets, - sample_query, + cohort1_query, + cohort2_query, sample_query_options, window_size, percentiles, - standardize, - standardization_bins, - standardization_n_bins, - standardization_diagnostics, filter_min_maf, - compute_min_maf, + map_pos, min_ehh, max_gap, gap_scale, @@ -813,10 +1568,10 @@ def _ihs_gwss( chunks, inline_array, ): - ds_haps = self.haplotypes( + ds_haps1 = self.haplotypes( region=contig, analysis=analysis, - sample_query=sample_query, + sample_query=cohort1_query, sample_query_options=sample_query_options, sample_sets=sample_sets, min_cohort_size=min_cohort_size, @@ -826,27 +1581,49 @@ def _ihs_gwss( inline_array=inline_array, ) - gt = allel.GenotypeDaskArray(ds_haps["call_genotype"].data) - with self._dask_progress(desc="Load haplotypes"): - ht = gt.to_haplotypes().compute() + ds_haps2 = self.haplotypes( + region=contig, + analysis=analysis, + sample_query=cohort2_query, + sample_query_options=sample_query_options, + sample_sets=sample_sets, + min_cohort_size=min_cohort_size, + max_cohort_size=max_cohort_size, + random_seed=random_seed, + chunks=chunks, + inline_array=inline_array, + ) - with 
self._spinner(desc="Compute IHS"): - ac = ht.count_alleles(max_allele=1) - pos = ds_haps["variant_position"].values + gt1 = allel.GenotypeDaskArray(ds_haps1["call_genotype"].data) + gt2 = allel.GenotypeDaskArray(ds_haps2["call_genotype"].data) + with self._dask_progress(desc="Load haplotypes for cohort 1"): + ht1 = gt1.to_haplotypes().compute() + with self._dask_progress(desc="Load haplotypes for cohort 2"): + ht2 = gt2.to_haplotypes().compute() + + with self._spinner("Compute XPEHH"): + ac1 = ht1.count_alleles(max_allele=1) + ac2 = ht2.count_alleles(max_allele=1) + pos = ds_haps1["variant_position"].values if filter_min_maf > 0: + ac = ac1 + ac2 af = ac.to_frequencies() maf = np.min(af, axis=1) maf_filter = maf > filter_min_maf - ht = ht.compress(maf_filter, axis=0) + + ht1 = ht1.compress(maf_filter, axis=0) + ht2 = ht2.compress(maf_filter, axis=0) pos = pos[maf_filter] - ac = ac[maf_filter] + ac1 = ac1[maf_filter] + ac2 = ac2[maf_filter] - # compute iHS - ihs = allel.ihs( - h=ht, + # compute XP-EHH + xp = allel.xpehh( + h1=ht1, + h2=ht2, pos=pos, - min_maf=compute_min_maf, + map_pos=map_pos, min_ehh=min_ehh, include_edges=include_edges, max_gap=max_gap, @@ -855,65 +1632,50 @@ def _ihs_gwss( ) # remove any NaNs - na_mask = ~np.isnan(ihs) - ihs = ihs[na_mask] + na_mask = ~np.isnan(xp) + xp = xp[na_mask] pos = pos[na_mask] - ac = ac[na_mask] - - # take absolute value - ihs = np.fabs(ihs) - - if standardize: - ihs, _ = allel.standardize_by_allele_count( - score=ihs, - aac=ac[:, 1], - bins=standardization_bins, - n_bins=standardization_n_bins, - diagnostics=standardization_diagnostics, - ) + ac1 = ac1[na_mask] + ac2 = ac2[na_mask] if window_size: - ihs = allel.moving_statistic( - ihs, statistic=np.percentile, size=window_size, q=percentiles + xp = allel.moving_statistic( + xp, statistic=np.percentile, size=window_size, q=percentiles ) pos = allel.moving_statistic(pos, statistic=np.mean, size=window_size) - results = dict(x=pos, ihs=ihs) + results = dict(x=pos, 
xpehh=xp) return results - @_check_types @doc( - summary="Run and plot iHS GWSS data.", + summary="Run and plot XP-EHH GWSS data.", ) - def plot_ihs_gwss_track( + def plot_xpehh_gwss_track( self, contig: base_params.contig, analysis: hap_params.analysis = base_params.DEFAULT, sample_sets: Optional[base_params.sample_sets] = None, - sample_query: Optional[base_params.sample_query] = None, + cohort1_query: Optional[base_params.sample_query] = None, + cohort2_query: Optional[base_params.sample_query] = None, sample_query_options: Optional[base_params.sample_query_options] = None, - window_size: ihs_params.window_size = ihs_params.window_size_default, - percentiles: ihs_params.percentiles = ihs_params.percentiles_default, - standardize: ihs_params.standardize = True, - standardization_bins: Optional[ihs_params.standardization_bins] = None, - standardization_n_bins: ihs_params.standardization_n_bins = ihs_params.standardization_n_bins_default, - standardization_diagnostics: ihs_params.standardization_diagnostics = False, - filter_min_maf: ihs_params.filter_min_maf = ihs_params.filter_min_maf_default, - compute_min_maf: ihs_params.compute_min_maf = ihs_params.compute_min_maf_default, - min_ehh: ihs_params.min_ehh = ihs_params.min_ehh_default, - max_gap: ihs_params.max_gap = ihs_params.max_gap_default, - gap_scale: ihs_params.gap_scale = ihs_params.gap_scale_default, - include_edges: ihs_params.include_edges = True, - use_threads: ihs_params.use_threads = True, + window_size: xpehh_params.window_size = xpehh_params.window_size_default, + percentiles: xpehh_params.percentiles = xpehh_params.percentiles_default, + filter_min_maf: xpehh_params.filter_min_maf = xpehh_params.filter_min_maf_default, + map_pos: Optional[xpehh_params.map_pos] = None, + min_ehh: xpehh_params.min_ehh = xpehh_params.min_ehh_default, + max_gap: xpehh_params.max_gap = xpehh_params.max_gap_default, + gap_scale: xpehh_params.gap_scale = xpehh_params.gap_scale_default, + include_edges: 
xpehh_params.include_edges = True, + use_threads: xpehh_params.use_threads = True, min_cohort_size: Optional[ base_params.min_cohort_size - ] = ihs_params.min_cohort_size_default, + ] = xpehh_params.min_cohort_size_default, max_cohort_size: Optional[ base_params.max_cohort_size - ] = ihs_params.max_cohort_size_default, + ] = xpehh_params.max_cohort_size_default, random_seed: base_params.random_seed = 42, - palette: ihs_params.palette = ihs_params.palette_default, + palette: xpehh_params.palette = xpehh_params.palette_default, title: Optional[gplt_params.title] = None, sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, width: gplt_params.width = gplt_params.width_default, @@ -924,18 +1686,14 @@ def plot_ihs_gwss_track( chunks: base_params.chunks = base_params.native_chunks, inline_array: base_params.inline_array = base_params.inline_array_default, ) -> gplt_params.optional_figure: - # compute ihs - x, ihs = self.ihs_gwss( + # compute xpehh + x, xpehh = self.xpehh_gwss( contig=contig, analysis=analysis, window_size=window_size, percentiles=percentiles, - standardize=standardize, - standardization_bins=standardization_bins, - standardization_n_bins=standardization_n_bins, - standardization_diagnostics=standardization_diagnostics, filter_min_maf=filter_min_maf, - compute_min_maf=compute_min_maf, + map_pos=map_pos, min_ehh=min_ehh, max_gap=max_gap, gap_scale=gap_scale, @@ -943,7 +1701,8 @@ def plot_ihs_gwss_track( use_threads=use_threads, min_cohort_size=min_cohort_size, max_cohort_size=max_cohort_size, - sample_query=sample_query, + cohort1_query=cohort1_query, + cohort2_query=cohort2_query, sample_query_options=sample_query_options, sample_sets=sample_sets, random_seed=random_seed, @@ -951,12 +1710,6 @@ def plot_ihs_gwss_track( inline_array=inline_array, ) - if len(x) == 0: - raise ValueError( - "No iHS values remain after filtering. " - "Try relaxing filter_min_maf or min_ehh parameters." 
- ) - # determine X axis range x_min = x[0] x_max = x[-1] @@ -968,7 +1721,10 @@ def plot_ihs_gwss_track( dimensions="width", maintain_focus=False ) if title is None: - title = sample_query + if cohort1_query is None or cohort2_query is None: + title = "XP-EHH" + else: + title = f"Cohort 1: {cohort1_query}\nCohort 2: {cohort2_query}" fig = bokeh.plotting.figure( title=title, tools=[ @@ -997,27 +1753,26 @@ def plot_ihs_gwss_track( # Ensure percentiles are sorted so that colors make sense. percentiles = tuple(sorted(percentiles)) - # add an empty dimension to ihs array if 1D - ihs = np.reshape(ihs, (ihs.shape[0], -1)) + # add an empty dimension to XP-EHH array if 1D + xpehh = np.reshape(xpehh, (xpehh.shape[0], -1)) # select the base color palette to work from base_palette = bokeh.palettes.all_palettes[palette][8] - # keep only enough colours to plot the IHS tracks - bokeh_palette = base_palette[: ihs.shape[1]] + # keep only enough colours to plot the XP-EHH tracks + bokeh_palette = base_palette[: xpehh.shape[1]] # reverse the colors so darkest is last bokeh_palette = bokeh_palette[::-1] - # plot IHS tracks - for i in range(ihs.shape[1]): - ihs_perc = ihs[:, i] + for i in range(xpehh.shape[1]): + xpehh_perc = xpehh[:, i] color = bokeh_palette[i] - # plot ihs + # plot XP-EHH fig.circle( x=x, - y=ihs_perc, + y=xpehh_perc, size=4, line_width=0, line_color=color, @@ -1025,117 +1780,14 @@ def plot_ihs_gwss_track( ) # tidy up the plot - fig.yaxis.axis_label = "ihs" + fig.yaxis.axis_label = "XP-EHH" self._bokeh_style_genome_xaxis(fig, contig) if show: # pragma: no cover bokeh.plotting.show(fig) - return fig - - @doc( - summary="Run and plot iHS GWSS data.", - ) - def plot_ihs_gwss( - self, - contig: base_params.contig, - analysis: hap_params.analysis = base_params.DEFAULT, - sample_sets: Optional[base_params.sample_sets] = None, - sample_query: Optional[base_params.sample_query] = None, - sample_query_options: Optional[base_params.sample_query_options] = None, - window_size: 
ihs_params.window_size = ihs_params.window_size_default, - percentiles: ihs_params.percentiles = ihs_params.percentiles_default, - standardize: ihs_params.standardize = True, - standardization_bins: Optional[ihs_params.standardization_bins] = None, - standardization_n_bins: ihs_params.standardization_n_bins = ihs_params.standardization_n_bins_default, - standardization_diagnostics: ihs_params.standardization_diagnostics = False, - filter_min_maf: ihs_params.filter_min_maf = ihs_params.filter_min_maf_default, - compute_min_maf: ihs_params.compute_min_maf = ihs_params.compute_min_maf_default, - min_ehh: ihs_params.min_ehh = ihs_params.min_ehh_default, - max_gap: ihs_params.max_gap = ihs_params.max_gap_default, - gap_scale: ihs_params.gap_scale = ihs_params.gap_scale_default, - include_edges: ihs_params.include_edges = True, - use_threads: ihs_params.use_threads = True, - min_cohort_size: Optional[ - base_params.min_cohort_size - ] = ihs_params.min_cohort_size_default, - max_cohort_size: Optional[ - base_params.max_cohort_size - ] = ihs_params.max_cohort_size_default, - random_seed: base_params.random_seed = 42, - palette: ihs_params.palette = ihs_params.palette_default, - title: Optional[gplt_params.title] = None, - sizing_mode: gplt_params.sizing_mode = gplt_params.sizing_mode_default, - width: gplt_params.width = gplt_params.width_default, - track_height: gplt_params.track_height = 170, - genes_height: gplt_params.genes_height = gplt_params.genes_height_default, - show: gplt_params.show = True, - output_backend: gplt_params.output_backend = gplt_params.output_backend_default, - chunks: base_params.chunks = base_params.native_chunks, - inline_array: base_params.inline_array = base_params.inline_array_default, - gene_labels: Optional[gplt_params.gene_labels] = None, - gene_labelset: Optional[gplt_params.gene_labelset] = None, - ) -> gplt_params.optional_figure: - # gwss track - fig1 = self.plot_ihs_gwss_track( - contig=contig, - analysis=analysis, - 
sample_sets=sample_sets, - sample_query=sample_query, - sample_query_options=sample_query_options, - window_size=window_size, - percentiles=percentiles, - palette=palette, - standardize=standardize, - standardization_bins=standardization_bins, - standardization_n_bins=standardization_n_bins, - standardization_diagnostics=standardization_diagnostics, - filter_min_maf=filter_min_maf, - compute_min_maf=compute_min_maf, - min_ehh=min_ehh, - max_gap=max_gap, - gap_scale=gap_scale, - include_edges=include_edges, - use_threads=use_threads, - min_cohort_size=min_cohort_size, - max_cohort_size=max_cohort_size, - random_seed=random_seed, - title=title, - sizing_mode=sizing_mode, - width=width, - height=track_height, - show=False, - output_backend=output_backend, - chunks=chunks, - inline_array=inline_array, - ) - - fig1.xaxis.visible = False - - # plot genes - fig2 = self.plot_genes( - region=contig, - sizing_mode=sizing_mode, - width=width, - height=genes_height, - x_range=fig1.x_range, - show=False, - output_backend=output_backend, - gene_labels=gene_labels, - gene_labelset=gene_labelset, - ) - - # combine plots into a single figure - fig = bokeh.layouts.gridplot( - [fig1, fig2], - ncols=1, - toolbar_location="above", - merge_tools=True, - sizing_mode=sizing_mode, - ) - - if show: # pragma: no cover - bokeh.plotting.show(fig) - return fig + return None + else: + return fig @_check_types @doc( @@ -1298,8 +1950,7 @@ def plot_haplotype_network( # Apply each query in the mapping to create the _partition column for label, query in color.items(): - # Validate and apply the query to matching rows - validate_query(query) + # Apply the query and assign the label to matching rows mask = df_haps.eval(query) df_haps.loc[mask, "_partition"] = label @@ -1453,8 +2104,6 @@ def plot_haplotype_network( boxSelectionEnabled=True, # prevent accidentally zooming out to oblivion minZoom=0.1, - # lower scroll wheel zoom sensitivity to prevent accidental zooming when trying to navigate large 
graphs - wheelSensitivity=0.1, ) debug("create dash app") diff --git a/tests/anoph/test_ihs.py b/tests/anoph/test_ihs.py new file mode 100644 index 000000000..5bbc96807 --- /dev/null +++ b/tests/anoph/test_ihs.py @@ -0,0 +1,341 @@ +"""Tests for AnophelesIhsAnalysis using simulated data.""" + +import random + +import bokeh.models +import numpy as np +import pytest +from pytest_cases import parametrize_with_cases + +from malariagen_data import af1 as _af1 +from malariagen_data import ag3 as _ag3 +from malariagen_data.anoph.ihs import AnophelesIhsAnalysis + + +@pytest.fixture +def ag3_sim_api(ag3_sim_fixture): + return AnophelesIhsAnalysis( + url=ag3_sim_fixture.url, + public_url=ag3_sim_fixture.url, + config_path=_ag3.CONFIG_PATH, + major_version_number=_ag3.MAJOR_VERSION_NUMBER, + major_version_path=_ag3.MAJOR_VERSION_PATH, + pre=True, + aim_metadata_dtype={ + "aim_species_fraction_arab": "float64", + "aim_species_fraction_colu": "float64", + "aim_species_fraction_colu_no2l": "float64", + "aim_species_gambcolu_arabiensis": object, + "aim_species_gambiae_coluzzii": object, + "aim_species": object, + }, + gff_gene_type="gene", + gff_gene_name_attribute="Name", + gff_default_attributes=("ID", "Parent", "Name", "description"), + results_cache=ag3_sim_fixture.results_cache_path.as_posix(), + taxon_colors=_ag3.TAXON_COLORS, + default_phasing_analysis="gamb_colu_arab", + virtual_contigs=_ag3.VIRTUAL_CONTIGS, + ) + + +@pytest.fixture +def af1_sim_api(af1_sim_fixture): + return AnophelesIhsAnalysis( + url=af1_sim_fixture.url, + public_url=af1_sim_fixture.url, + config_path=_af1.CONFIG_PATH, + major_version_number=_af1.MAJOR_VERSION_NUMBER, + major_version_path=_af1.MAJOR_VERSION_PATH, + pre=False, + gff_gene_type="protein_coding_gene", + gff_gene_name_attribute="Note", + gff_default_attributes=("ID", "Parent", "Note", "description"), + results_cache=af1_sim_fixture.results_cache_path.as_posix(), + taxon_colors=_af1.TAXON_COLORS, + default_phasing_analysis="funestus", + ) + + 
+# N.B., here we use pytest_cases to parametrize tests. Each +# function whose name begins with "case_" defines a set of +# inputs to the test functions. See the documentation for +# pytest_cases for more information, e.g.: +# +# https://smarie.github.io/python-pytest-cases/#basic-usage +# +# We use this approach here because we want to use fixtures +# as test parameters, which is otherwise hard to do with +# pytest alone. + + +def case_ag3_sim(ag3_sim_fixture, ag3_sim_api): + return ag3_sim_fixture, ag3_sim_api + + +def case_af1_sim(af1_sim_fixture, af1_sim_api): + return af1_sim_fixture, af1_sim_api + + +def check_ihs_gwss(*, api, ihs_params): + """Core check for ihs_gwss results and plots.""" + # Run main gwss function under test. + x, ihs = api.ihs_gwss(**ihs_params) + + # Check types and shapes. + assert isinstance(x, np.ndarray) + assert isinstance(ihs, np.ndarray) + assert x.ndim == 1 + assert x.dtype.kind == "f" + + # When window_size is set (default), ihs can be 1D (single percentile) + # or 2D (multiple percentiles). Either way the leading dimension matches x. + assert ihs.shape[0] == x.shape[0] + + if len(x) == 0: + # With very sparse simulated data, all variants may be filtered + # and there are no windows to plot; skip plotting checks. + return + + # Check plotting functions. + fig = api.plot_ihs_gwss_track(**ihs_params, show=False) + assert isinstance(fig, bokeh.models.Plot) + + fig = api.plot_ihs_gwss(**ihs_params, show=False) + assert isinstance(fig, bokeh.models.GridPlot) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_ihs_gwss_with_default_analysis(fixture, api: AnophelesIhsAnalysis): + # Skip datasets with no phasing analyses (IHS requires haplotype data). + if not api.phasing_analysis_ids: + pytest.skip("No phasing analyses available for this dataset.") + + # Set up test parameters using small cohort sizes for simulated data. 
+ all_sample_sets = api.sample_sets()["sample_set"].to_list() + ihs_params = dict( + contig=random.choice(api.contigs), + sample_sets=[random.choice(all_sample_sets)], + window_size=random.randint(20, 100), + min_cohort_size=1, + max_cohort_size=None, + # Disable standardization for simulated data: too few variants to build + # reliable standardization bins. + standardize=False, + filter_min_maf=0.0, + compute_min_maf=0.0, + ) + + # Run checks. + check_ihs_gwss(api=api, ihs_params=ihs_params) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_ihs_gwss_with_specific_analysis(fixture, api: AnophelesIhsAnalysis): + # Skip datasets with no phasing analyses. + if not api.phasing_analysis_ids: + pytest.skip("No phasing analyses available for this dataset.") + + # Test with each available phasing analysis. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + contig = random.choice(api.contigs) + + for analysis in api.phasing_analysis_ids: + # Check whether any samples are available for this analysis. + try: + ds_hap = api.haplotypes( + sample_sets=all_sample_sets, + analysis=analysis, + region=contig, + ) + except ValueError: + # No samples available for this analysis on this contig. + continue + + n_samples = ds_hap.sizes["samples"] + ihs_params = dict( + contig=contig, + analysis=analysis, + sample_sets=all_sample_sets, + window_size=random.randint(20, 100), + min_cohort_size=n_samples, + max_cohort_size=None, + # Disable standardization to avoid failures with small simulated + # datasets where there may be too few variants per bin. + standardize=False, + filter_min_maf=0.0, + compute_min_maf=0.0, + ) + + check_ihs_gwss(api=api, ihs_params=ihs_params) + + # Check that requesting more samples than available raises ValueError. 
+ with pytest.raises(ValueError): + api.ihs_gwss( + contig=contig, + analysis=analysis, + sample_sets=all_sample_sets, + window_size=random.randint(20, 100), + min_cohort_size=n_samples + 1, + ) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_ihs_gwss_without_windowing(fixture, api: AnophelesIhsAnalysis): + """Test per-variant iHS (window_size=0) returns 1-D arrays.""" + # Skip datasets with no phasing analyses. + if not api.phasing_analysis_ids: + pytest.skip("No phasing analyses available for this dataset.") + + all_sample_sets = api.sample_sets()["sample_set"].to_list() + ihs_params = dict( + contig=random.choice(api.contigs), + sample_sets=[random.choice(all_sample_sets)], + # Use window_size=0 for per-variant iHS (falsy value skips windowing). + window_size=0, + min_cohort_size=1, + max_cohort_size=None, + # Turn off standardization and MAF filtering so that simulated data + # with few variants still produces results. + standardize=False, + filter_min_maf=0.0, + compute_min_maf=0.0, + ) + + x, ihs = api.ihs_gwss(**ihs_params) + + # Both arrays must be 1-D and have the same length. + assert isinstance(x, np.ndarray) + assert isinstance(ihs, np.ndarray) + assert x.ndim == 1 + assert ihs.ndim == 1 + assert x.shape == ihs.shape + # Positions are raw integers when no windowing is applied (no mean is taken). + assert x.dtype.kind in ("i", "u", "f") + + +@parametrize_with_cases("fixture,api", cases=".") +def test_ihs_gwss_with_sample_query(fixture, api: AnophelesIhsAnalysis): + """Test that ihs_gwss accepts a sample_query parameter.""" + # Skip datasets with no phasing analyses. + if not api.phasing_analysis_ids: + pytest.skip("No phasing analyses available for this dataset.") + + all_sample_sets = api.sample_sets()["sample_set"].to_list() + + # Pick a country that has samples. 
+ all_countries = ( + api.sample_metadata(sample_sets=all_sample_sets)["country"].unique().tolist() + ) + country = random.choice(all_countries) + sample_query = f"country == '{country}'" + + try: + x, ihs = api.ihs_gwss( + contig=random.choice(api.contigs), + sample_sets=all_sample_sets, + sample_query=sample_query, + window_size=random.randint(20, 100), + min_cohort_size=1, + max_cohort_size=None, + ) + assert isinstance(x, np.ndarray) + assert isinstance(ihs, np.ndarray) + assert x.ndim == 1 + except ValueError: + # It's OK if there's no haplotype data for the selected country. + pass + + +@parametrize_with_cases("fixture,api", cases=".") +def test_ihs_gwss_caching(fixture, api: AnophelesIhsAnalysis): + """Test that calling ihs_gwss twice returns consistent cached results.""" + # Skip datasets with no phasing analyses. + if not api.phasing_analysis_ids: + pytest.skip("No phasing analyses available for this dataset.") + + all_sample_sets = api.sample_sets()["sample_set"].to_list() + ihs_params = dict( + contig=random.choice(api.contigs), + sample_sets=[random.choice(all_sample_sets)], + window_size=random.randint(20, 100), + min_cohort_size=1, + max_cohort_size=None, + # Disable standardization to avoid failures with small simulated + # datasets where there may be too few variants per standardization bin. + standardize=False, + filter_min_maf=0.0, + compute_min_maf=0.0, + ) + + # Call twice - second call should use cache. + x1, ihs1 = api.ihs_gwss(**ihs_params) + x2, ihs2 = api.ihs_gwss(**ihs_params) + + np.testing.assert_array_equal(x1, x2) + np.testing.assert_array_equal(ihs1, ihs2) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_ihs_gwss_multiple_percentiles(fixture, api: AnophelesIhsAnalysis): + """Test ihs_gwss with multiple percentiles returns correctly shaped output.""" + # Skip datasets with no phasing analyses. 
+ if not api.phasing_analysis_ids: + pytest.skip("No phasing analyses available for this dataset.") + + all_sample_sets = api.sample_sets()["sample_set"].to_list() + percentiles = (50, 75, 100) + ihs_params = dict( + contig=random.choice(api.contigs), + sample_sets=[random.choice(all_sample_sets)], + window_size=random.randint(20, 100), + percentiles=percentiles, + min_cohort_size=1, + max_cohort_size=None, + # Disable standardization to avoid failures with small simulated + # datasets where there may be too few variants per standardization bin. + standardize=False, + filter_min_maf=0.0, + compute_min_maf=0.0, + ) + + x, ihs = api.ihs_gwss(**ihs_params) + + assert isinstance(x, np.ndarray) + assert isinstance(ihs, np.ndarray) + assert x.ndim == 1 + # With multiple percentiles, ihs should be 2-D: (n_windows, n_percentiles) + assert ihs.ndim == 2 + assert ihs.shape[0] == x.shape[0] + assert ihs.shape[1] == len(percentiles) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_ihs_gwss_single_percentile(fixture, api: AnophelesIhsAnalysis): + """Test ihs_gwss with a single integer percentile.""" + # Skip datasets with no phasing analyses. + if not api.phasing_analysis_ids: + pytest.skip("No phasing analyses available for this dataset.") + + all_sample_sets = api.sample_sets()["sample_set"].to_list() + ihs_params = dict( + contig=random.choice(api.contigs), + sample_sets=[random.choice(all_sample_sets)], + window_size=random.randint(20, 100), + percentiles=50, # single integer + min_cohort_size=1, + max_cohort_size=None, + # Disable standardization to avoid failures with small simulated data. + standardize=False, + filter_min_maf=0.0, + compute_min_maf=0.0, + ) + + x, ihs = api.ihs_gwss(**ihs_params) + + assert isinstance(x, np.ndarray) + assert isinstance(ihs, np.ndarray) + assert x.ndim == 1 + # Single percentile: ihs is 1-D + assert ihs.ndim == 1 + assert ihs.shape == x.shape