diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 4392b7e..3f12a91 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -129,10 +129,12 @@ class SEQopts: cense_eligible_colname: Optional[str] = None compevent_colname: Optional[str] = None covariates: Optional[str] = None + cox_package: Literal["lifelines", "scikit-survival"] = "lifelines" denominator: Optional[str] = None excused: bool = False excused_colnames: List[str] = field(default_factory=lambda: []) expand_only: bool = False + glm_package: Literal["statsmodels", "glum"] = "statsmodels" followup_class: bool = False followup_include: bool = True followup_max: int = None @@ -231,6 +233,10 @@ def _validate_choices(self): ) if self.bootstrap_CI_method not in ["se", "percentile"]: raise ValueError("bootstrap_CI_method must be one of 'se' or 'percentile'") + if self.glm_package not in ["statsmodels", "glum"]: + raise ValueError("glm_package must be 'statsmodels' or 'glum'") + if self.cox_package not in ["lifelines", "scikit-survival"]: + raise ValueError("cox_package must be 'lifelines' or 'scikit-survival'") def _normalize_formulas(self): for i in ( diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py index 0bff491..cecf28c 100644 --- a/pySEQTarget/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -102,6 +102,10 @@ def retrieve_data( "risk_difference", "unique_outcomes", "nonunique_outcomes", + "unique_followup", + "nonunique_followup", + "unique_compevent", + "nonunique_compevent", "unique_switches", "nonunique_switches", ] @@ -112,7 +116,9 @@ def retrieve_data( :param type: Data which you would like to access, ['km_data', 'hazard', 'risk_ratio', 'risk_difference', 'unique_outcomes', - 'nonunique_outcomes', 'unique_switches', 'nonunique_switches'] + 'nonunique_outcomes', 'unique_followup', 'nonunique_followup', + 'unique_compevent', 'nonunique_compevent', + 'unique_switches', 'nonunique_switches'] :type type: str """ match type: @@ -126,6 +132,14 @@ def retrieve_data( data = self.diagnostic_tables["unique_outcomes"] case "nonunique_outcomes": data = self.diagnostic_tables["nonunique_outcomes"] + case "unique_followup": + data = self.diagnostic_tables["unique_followup"] + case "nonunique_followup": + data = self.diagnostic_tables["nonunique_followup"] + case "unique_compevent": + data = self.diagnostic_tables.get("unique_compevent") + case "nonunique_compevent": + data = self.diagnostic_tables.get("nonunique_compevent") case "unique_switches": if self.diagnostic_tables.has_key("unique_switches"): data = self.diagnostic_tables["unique_switches"] diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index 64613fd..a8331bf 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -35,10 +35,17 @@ def _calculate_hazard_single(self, data, idx=None, val=None): if self.bootstrap_nboot > 0: boot_log_hrs = [] - for boot_idx in range(len(self._boot_samples)): + # outcome_model[model_pos + 1] was fit on _boot_samples[sample_idx]; + # skipped replicates make this mapping non-identity, so iterate it + # explicitly rather than assuming model index == sample index. + boot_sample_idx = getattr(self, "_boot_sample_idx", None) + if boot_sample_idx is None: + boot_sample_idx = list(range(len(self._boot_samples))) + + for model_pos, sample_idx in enumerate(boot_sample_idx): if self.seed is not None: - self._rng = np.random.RandomState(self.seed + boot_idx + 1) - id_counts = self._boot_samples[boot_idx] + self._rng = np.random.RandomState(self.seed + sample_idx + 1) + id_counts = self._boot_samples[sample_idx] counts = pl.DataFrame( { @@ -55,7 +62,9 @@ def _calculate_hazard_single(self, data, idx=None, val=None): .collect() ) - boot_log_hr = _hazard_handler(self, boot_data, idx, boot_idx + 1, self._rng) + boot_log_hr = _hazard_handler( + self, boot_data, idx, model_pos + 1, self._rng + ) if boot_log_hr is not None and not np.isnan(boot_log_hr): boot_log_hrs.append(boot_log_hr) @@ -180,6 +189,7 @@ def _hazard_handler(self, data, idx, boot_idx, rng): sim_data = sim_data.with_columns([pl.col("outcome").alias("event")]) sim_data_pd = sim_data.to_pandas() + tx_bas = f"{self.treatment_col}{self.indicator_baseline}" try: # COXPHFITTER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow() @@ -187,30 +197,41 @@ def _hazard_handler(self, data, idx, boot_idx, rng): if ce_model is not None: cox_data = sim_data_pd[sim_data_pd["event"].isin([0, 1])].copy() cox_data["event_binary"] = (cox_data["event"] == 1).astype(int) - - cph = CoxPHFitter() - cph.fit( - cox_data, - duration_col="followup", - event_col="event_binary", - formula=f"`{self.treatment_col}{self.indicator_baseline}`", - ) - else: - cph = CoxPHFitter() - cph.fit( - sim_data_pd, - duration_col="followup", - event_col="event", - formula=f"`{self.treatment_col}{self.indicator_baseline}`", - ) - - log_hr = cph.params_.values[0] - return log_hr + return _cox_log_hr(self, cox_data, "followup", "event_binary", tx_bas) + return _cox_log_hr(self, sim_data_pd, "followup", "event", tx_bas) except Exception as e: print(f"Cox model fitting failed: {e}") return None +def _cox_log_hr(self, data_pd, duration_col, event_col, covariate_col): + """ + Fit a univariate Cox model (single covariate = baseline treatment) and + return the log hazard ratio, dispatching on self.cox_package. scikit-survival + uses Efron tie handling to match lifelines, which matters here because + integer follow-up produces many tied event times. + """ + if getattr(self, "cox_package", "lifelines") == "scikit-survival": + from sksurv.linear_model import CoxPHSurvivalAnalysis + + y = np.empty(len(data_pd), dtype=[("event", bool), ("time", "float64")]) + y["event"] = data_pd[event_col].to_numpy().astype(bool) + y["time"] = data_pd[duration_col].to_numpy().astype(float) + X = data_pd[[covariate_col]].to_numpy().astype(float) + cox = CoxPHSurvivalAnalysis(ties="efron") + cox.fit(X, y) + return float(cox.coef_[0]) + + cph = CoxPHFitter() + cph.fit( + data_pd, + duration_col=duration_col, + event_col=event_col, + formula=f"`{covariate_col}`", + ) + return cph.params_.values[0] + + def _create_hazard_output(hr, lci, uci, val, self): if lci is not None and uci is not None: output = pl.DataFrame( diff --git a/pySEQTarget/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py index af91176..f63e1e3 100644 --- a/pySEQTarget/analysis/_outcome_fit.py +++ b/pySEQTarget/analysis/_outcome_fit.py @@ -125,6 +125,15 @@ def _outcome_fit( full_formula = f"{outcome} ~ {formula}" + if getattr(self, "glm_package", "statsmodels") == "glum": + from ..helpers._glum_fit import _fit_glum + + return _fit_glum( + full_formula, + df_pd, + var_weights=df_pd[weight_col] if weighted else None, + ) + glm_kwargs = { "formula": full_formula, "data": df_pd, diff --git a/pySEQTarget/expansion/_diagnostics.py b/pySEQTarget/expansion/_diagnostics.py index 178062a..56dcb0c 100644 --- a/pySEQTarget/expansion/_diagnostics.py +++ b/pySEQTarget/expansion/_diagnostics.py @@ -6,6 +6,15 @@ def _diagnostics(self): nonunique_out = _outcome_diag(self, unique=False) out = {"unique_outcomes": unique_out, "nonunique_outcomes": nonunique_out} + unique_fu = _followup_diag(self, unique=True) + nonunique_fu = _followup_diag(self, unique=False) + out.update({"unique_followup": unique_fu, "nonunique_followup": nonunique_fu}) + + if self.compevent_colname is not None: + unique_ce = _compevent_diag(self, unique=True) + nonunique_ce = _compevent_diag(self, unique=False) + out.update({"unique_compevent": unique_ce, "nonunique_compevent": nonunique_ce}) + if self.method == "censoring": unique_switch = _switch_diag(self, unique=True) nonunique_switch = _switch_diag(self, unique=False) @@ -30,6 +39,45 @@ def _outcome_diag(self, unique): return out +def _followup_diag(self, unique): + """ + Follow-up per treatment arm, grouped by the baseline treatment value, over + the rows the outcome model is fit on (under method="censoring" the switched + rows are dropped, matching _outcome_fit). ``unique`` counts distinct subjects + contributing follow-up to the arm; otherwise counts follow-up intervals + (rows / person-time), so non-unique outcome counts divided by these give + per-arm event rates. A subject appearing in both arms is counted in each. + """ + tx_bas = f"{self.treatment_col}{self.indicator_baseline}" + data = self.DT + if self.method == "censoring": + data = data.filter(pl.col("switch") != 1) + + if unique: + out = data.group_by(tx_bas).agg(pl.col(self.id_col).n_unique().alias("len")) + else: + out = data.group_by(tx_bas).len() + + return out.sort(tx_bas) + + +def _compevent_diag(self, unique): + """ + Competing-event counts per treatment arm. Counts rows where the configured + compevent column is 1, grouped by the baseline treatment value. ``unique`` + counts distinct subjects who experience the competing event in each arm; + otherwise counts the intervals (rows). A subject appearing with a competing + event in both arms is counted in each. + """ + tx_bas = f"{self.treatment_col}{self.indicator_baseline}" + data = self.DT.filter(pl.col(self.compevent_colname) == 1) + if unique: + out = data.group_by(tx_bas).agg(pl.col(self.id_col).n_unique().alias("len")) + else: + out = data.group_by(tx_bas).len() + return out.sort(tx_bas) + + def _switch_diag(self, unique): if not self.excused: data = self.DT.with_columns(pl.lit(False).alias("isExcused")) diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index b59fb61..a8dd844 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -118,12 +118,14 @@ def wrapper(self, *args, **kwargs): for i in range(nboot) } skipped = 0 + boot_sample_idx = [] for j in tqdm( as_completed(futures), total=nboot, desc="Bootstrapping..." ): boot_idx = futures[j] try: results.append(j.result()) + boot_sample_idx.append(boot_idx) except np.linalg.LinAlgError as e: skipped += 1 warnings.warn( @@ -144,6 +146,7 @@ def wrapper(self, *args, **kwargs): original_DT_ref = original_DT skipped = 0 + boot_sample_idx = [] for i in tqdm(range(nboot), desc="Bootstrapping..."): self._current_boot_idx = i + 1 if seed is not None: @@ -156,6 +159,7 @@ def wrapper(self, *args, **kwargs): try: boot_fit = method(self, *args, **kwargs) results.append(boot_fit) + boot_sample_idx.append(i) except np.linalg.LinAlgError as e: skipped += 1 warnings.warn( @@ -167,6 +171,11 @@ def wrapper(self, *args, **kwargs): self.DT = self._offloader.load_dataframe(original_DT_ref) + # Maps each fitted bootstrap model (results[1:]) back to its + # original _boot_samples index. Skipped replicates leave gaps, so + # downstream consumers (e.g. hazard) must use this to pair a + # resample with its model rather than assuming a 1:1 ordering. + self._boot_sample_idx = boot_sample_idx self.bootstrap_nboot = len(results) - 1 if skipped > 0: warnings.warn( diff --git a/pySEQTarget/helpers/_glum_fit.py b/pySEQTarget/helpers/_glum_fit.py new file mode 100644 index 0000000..f8894b1 --- /dev/null +++ b/pySEQTarget/helpers/_glum_fit.py @@ -0,0 +1,126 @@ +import types + +import numpy as np +import pandas as pd +import patsy +from glum import GeneralizedLinearRegressor + + +class _GlumFit: + """ + Wraps a fitted glum model exposing the statsmodels interface the rest of + the codebase (and users) expect: + .params (Series), .model.exog_names, .model.data.design_info, + .predict(df) / .predict(X_numpy, transform=False), + .bse, .summary(), .summary2(). + + Standard errors are derived lazily from the stored design matrix using the + GLM asymptotic covariance (X' W X)^-1, which matches statsmodels for the + binomial/logit family (incl. var_weights). The design matrix is retained + just like statsmodels keeps model.exog, so memory use is comparable. + """ + + def __init__(self, glum_model, design_info, feature_names, X_design, sample_weight): + self._glum = glum_model + self._design_info = design_info + self._X_design = X_design # includes the intercept column + self._sample_weight = sample_weight + + self.model = types.SimpleNamespace( + exog_names=feature_names, + data=types.SimpleNamespace(design_info=design_info), + ) + self.exog_names = feature_names + + # statsmodels convention: intercept first + all_coefs = np.concatenate([[glum_model.intercept_], glum_model.coef_]) + self.params = pd.Series(all_coefs, index=feature_names) + + def predict(self, data, transform=True): + if transform: + # data is a pandas DataFrame — build design matrix via stored patsy info + X = patsy.build_design_matrices( + [self._design_info], data, return_type="dataframe" + )[0] + X_arr = X.drop(columns=["Intercept"], errors="ignore").values + else: + # data is a pre-built numpy design matrix (includes intercept col — drop it) + X_arr = np.asarray(data)[:, 1:] + return self._glum.predict(X_arr) + + def cov_params(self): + X = self._X_design + mu = self._glum.predict(X[:, 1:]) + w = mu * (1.0 - mu) + if self._sample_weight is not None: + w = w * np.asarray(self._sample_weight) + return np.linalg.pinv(X.T @ (w[:, None] * X)) + + @property + def bse(self): + return pd.Series(np.sqrt(np.diag(self.cov_params())), index=self.params.index) + + def _coef_table(self): + from scipy import stats + + coef = self.params.values + se = self.bse.values + with np.errstate(divide="ignore", invalid="ignore"): + z = coef / se + pvals = 2.0 * stats.norm.sf(np.abs(z)) + crit = stats.norm.ppf(0.975) + return pd.DataFrame( + { + "Coef.": coef, + "Std.Err.": se, + "z": z, + "P>|z|": pvals, + "[0.025": coef - crit * se, + "0.975]": coef + crit * se, + }, + index=list(self.params.index), + ) + + def summary2(self): + from statsmodels.iolib.summary2 import Summary + + info = pd.DataFrame( + { + " ": [ + "GLM (glum backend)", + "Binomial", + "logit", + str(self._X_design.shape[0]), + ] + }, + index=["Model:", "Family:", "Link:", "No. Observations:"], + ) + smry = Summary() + smry.add_title("Generalized Linear Model Regression Results") + smry.add_df(info, header=False) + smry.add_df(self._coef_table()) + return smry + + # statsmodels exposes both; the codebase/practical use either, so alias them. + summary = summary2 + + +def _fit_glum(formula, data, var_weights=None): + """Fit a binomial GLM with glum and return a _GlumFit wrapper.""" + y_mat, X_mat = patsy.dmatrices(formula, data, return_type="dataframe") + y_arr = y_mat.values.ravel() + design_info = X_mat.design_info + feature_names = list(X_mat.columns) # "Intercept" first, then predictors + X_design = X_mat.values # includes intercept column (for covariance) + X_arr = X_mat.drop(columns=["Intercept"]).values + + glm = GeneralizedLinearRegressor(family="binomial", fit_intercept=True) + + sample_weight = None + fit_kwargs = {} + if var_weights is not None: + sample_weight = np.asarray(var_weights) + fit_kwargs["sample_weight"] = sample_weight + + glm.fit(X_arr, y_arr, **fit_kwargs) + return _GlumFit(glm, design_info, feature_names, X_design, sample_weight) diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index e6b39d0..8fcb0bf 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -2,6 +2,7 @@ import statsmodels.formula.api as smf from ..error._check_separation import _check_separation +from ..helpers._glum_fit import _fit_glum def _get_subset_for_level( @@ -47,8 +48,11 @@ def _fit_pair( setattr(self, out, None) continue formula = f"{outcome}~{rhs}" - model = smf.glm(formula, WDT, family=sm.families.Binomial()) - fitted = model.fit(disp=0, method=self.weight_fit_method) + if getattr(self, "glm_package", "statsmodels") == "glum": + fitted = _fit_glum(formula, WDT) + else: + model = smf.glm(formula, WDT, family=sm.families.Binomial()) + fitted = model.fit(disp=0, method=self.weight_fit_method) _check_separation(fitted, label=out.replace("_model", "").replace("_", " ")) setattr(self, out, fitted) @@ -102,11 +106,16 @@ def _fit_numerator(self, WDT): fits.append(None) continue # Use logit for binary 0/1 censoring, mnlogit otherwise - if is_binary: - model = smf.logit(formula, DT_subset) + if is_binary and getattr(self, "glm_package", "statsmodels") == "glum": + model_fit = _fit_glum(formula, DT_subset) + elif is_binary: + model_fit = smf.logit(formula, DT_subset).fit( + disp=0, method=self.weight_fit_method + ) else: - model = smf.mnlogit(formula, DT_subset) - model_fit = model.fit(disp=0, method=self.weight_fit_method) + model_fit = smf.mnlogit(formula, DT_subset).fit( + disp=0, method=self.weight_fit_method + ) _check_separation(model_fit, label=f"numerator (level {level})") fits.append(model_fit) @@ -140,11 +149,16 @@ def _fit_denominator(self, WDT): fits.append(None) continue # Use logit for binary 0/1 censoring, mnlogit otherwise - if is_binary: - model = smf.logit(formula, DT_subset) + if is_binary and getattr(self, "glm_package", "statsmodels") == "glum": + model_fit = _fit_glum(formula, DT_subset) + elif is_binary: + model_fit = smf.logit(formula, DT_subset).fit( + disp=0, method=self.weight_fit_method + ) else: - model = smf.mnlogit(formula, DT_subset) - model_fit = model.fit(disp=0, method=self.weight_fit_method) + model_fit = smf.mnlogit(formula, DT_subset).fit( + disp=0, method=self.weight_fit_method + ) _check_separation(model_fit, label=f"denominator (level {level})") fits.append(model_fit) diff --git a/pyproject.toml b/pyproject.toml index 0f4c31f..adfcb10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.13.5" +version = "0.13.6" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} @@ -41,7 +41,9 @@ dependencies = [ "matplotlib", "pyarrow", "lifelines", - "joblib" + "joblib", + "glum>=3.4.1", + "scikit-survival>=0.25.0", ] [project.optional-dependencies] diff --git a/tests/test_compevent_diagnostics.py b/tests/test_compevent_diagnostics.py new file mode 100644 index 0000000..6df158f --- /dev/null +++ b/tests/test_compevent_diagnostics.py @@ -0,0 +1,89 @@ +import numpy as np +import polars as pl +import pytest + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _data_with_compevent(p=0.05, seed=42): + rng = np.random.default_rng(seed) + data = load_data("SEQdata") + return data.with_columns( + pl.Series("compevent", (rng.random(data.height) < p).astype(int)) + ) + + +def _build(with_ce=True, **opts): + data = _data_with_compevent() if with_ce else load_data("SEQdata") + params = dict(**opts) + if with_ce: + params["compevent_colname"] = "compevent" + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(**params), + ) + s.expand() + return s + + +def test_compevent_tables_present_and_accessible(): + s = _build() + for key in ("unique_compevent", "nonunique_compevent"): + assert key in s.diagnostics + + s.fit() + out = s.collect() + for key in ("unique_compevent", "nonunique_compevent"): + tbl = out.retrieve_data(key) + assert "tx_init_bas" in tbl.columns + assert "len" in tbl.columns + + +def test_nonunique_compevent_counts_intervals_per_arm(): + # Non-unique counts intervals (rows) with compevent == 1, per baseline arm. + s = _build() + expected = ( + s.DT.filter(pl.col("compevent") == 1) + .group_by("tx_init_bas") + .len() + .sort("tx_init_bas") + ) + assert s.diagnostics["nonunique_compevent"].equals(expected) + + +def test_unique_compevent_counts_distinct_subjects_per_arm(): + s = _build() + expected = ( + s.DT.filter(pl.col("compevent") == 1) + .group_by("tx_init_bas") + .agg(pl.col("ID").n_unique().alias("len")) + .sort("tx_init_bas") + ) + assert s.diagnostics["unique_compevent"].equals(expected) + u = s.diagnostics["unique_compevent"].sort("tx_init_bas") + nn = s.diagnostics["nonunique_compevent"].sort("tx_init_bas") + assert (u["len"] <= nn["len"]).all() + + +def test_compevent_tables_absent_when_no_compevent_configured(): + # When compevent_colname is None the diagnostics dict has no compevent keys + # and retrieve_data raises (data is None -> ValueError). + s = _build(with_ce=False) + assert "unique_compevent" not in s.diagnostics + assert "nonunique_compevent" not in s.diagnostics + + s.fit() + out = s.collect() + with pytest.raises(ValueError): + out.retrieve_data("unique_compevent") + with pytest.raises(ValueError): + out.retrieve_data("nonunique_compevent") diff --git a/tests/test_cox.py b/tests/test_cox.py new file mode 100644 index 0000000..a83a8c0 --- /dev/null +++ b/tests/test_cox.py @@ -0,0 +1,76 @@ +import pytest +from pytest import approx + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _hazard(cox_package, nboot=0, **opts): + data = load_data("SEQdata") + # seed always set: the hazard step simulates outcomes, so a fixed seed is + # required to isolate the backend difference from the random simulation. + params = dict(hazard_estimate=True, cox_package=cox_package, seed=42, **opts) + if nboot: + params["bootstrap_nboot"] = nboot + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(**params), + ) + s.expand() + if nboot: + s.bootstrap() + s.fit() + s.hazard() + return s + + +def test_cox_package_invalid_raises(): + with pytest.raises(ValueError, match="cox_package"): + SEQopts(cox_package="survival") + + +def test_sksurv_matches_lifelines_ITT(): + # Both backends fit the same univariate Cox partial likelihood with Efron + # tie handling, so the point hazard ratio should agree closely. + ll = _hazard("lifelines").hazard_ratio + sk = _hazard("scikit-survival").hazard_ratio + assert sk["Hazard ratio"][0] == approx(ll["Hazard ratio"][0], rel=1e-3, abs=1e-3) + + +def test_sksurv_matches_lifelines_bootstrap(): + ll = _hazard("lifelines", nboot=5).hazard_ratio + sk = _hazard("scikit-survival", nboot=5).hazard_ratio + for col in ("Hazard ratio", "LCI", "UCI"): + assert sk[col][0] == approx(ll[col][0], rel=1e-3, abs=1e-3) + + +def test_sksurv_subgroup_hazard_runs(): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + hazard_estimate=True, + cox_package="scikit-survival", + subgroup_colname="sex", + ), + ) + s.expand() + s.fit() + s.hazard() + assert s.hazard_ratio["Hazard ratio"].is_finite().all() diff --git a/tests/test_followup_diagnostics.py b/tests/test_followup_diagnostics.py new file mode 100644 index 0000000..d709082 --- /dev/null +++ b/tests/test_followup_diagnostics.py @@ -0,0 +1,70 @@ +import polars as pl + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _build(method, **opts): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method=method, + parameters=SEQopts(**opts), + ) + s.expand() + return s + + +def test_followup_tables_present_and_accessible(): + s = _build("ITT") + for key in ("unique_followup", "nonunique_followup"): + assert key in s.diagnostics + + s.fit() + out = s.collect() + for key in ("unique_followup", "nonunique_followup"): + tbl = out.retrieve_data(key) + assert "tx_init_bas" in tbl.columns + assert "len" in tbl.columns + + +def test_nonunique_followup_counts_outcome_fit_rows_per_arm(): + # Non-unique follow-up == follow-up intervals (rows) the outcome model is + # fit on, grouped by baseline treatment. + s = _build("ITT") + expected = s.DT.group_by("tx_init_bas").len().sort("tx_init_bas") + assert s.diagnostics["nonunique_followup"].equals(expected) + + +def test_censoring_followup_excludes_switched_rows(): + # Under method="censoring" the outcome model drops switch == 1 rows, so the + # follow-up tables must too. + s = _build("censoring", weighted=True, weight_preexpansion=True) + expected = ( + s.DT.filter(pl.col("switch") != 1) + .group_by("tx_init_bas") + .len() + .sort("tx_init_bas") + ) + assert s.diagnostics["nonunique_followup"].equals(expected) + + +def test_unique_followup_counts_distinct_subjects_per_arm(): + s = _build("ITT") + expected = ( + s.DT.group_by("tx_init_bas") + .agg(pl.col("ID").n_unique().alias("len")) + .sort("tx_init_bas") + ) + assert s.diagnostics["unique_followup"].equals(expected) + # never more unique subjects than follow-up intervals + u = s.diagnostics["unique_followup"].sort("tx_init_bas") + nn = s.diagnostics["nonunique_followup"].sort("tx_init_bas") + assert (u["len"] <= nn["len"]).all() diff --git a/tests/test_glum.py b/tests/test_glum.py new file mode 100644 index 0000000..48a0b48 --- /dev/null +++ b/tests/test_glum.py @@ -0,0 +1,143 @@ +import numpy as np +import pytest +from pytest import approx + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _fit(method, glm_package, dataset="SEQdata", **opts): + data = load_data(dataset) + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method=method, + parameters=SEQopts(glm_package=glm_package, **opts), + ) + s.expand() + s.fit() + return s + + +def _outcome_coefs(s): + return list(s.outcome_model[0]["outcome"].params) + + +def test_glm_package_invalid_raises(): + with pytest.raises(ValueError, match="glm_package"): + SEQopts(glm_package="sklearn") + + +def test_glum_matches_statsmodels_ITT(): + sm = _outcome_coefs(_fit("ITT", "statsmodels")) + gl = _outcome_coefs(_fit("ITT", "glum")) + assert gl == approx(sm, rel=1e-2, abs=2e-3) + + +def test_glum_matches_statsmodels_censoring_preexpansion(): + opts = dict(weighted=True, weight_preexpansion=True) + sm = _outcome_coefs(_fit("censoring", "statsmodels", **opts)) + gl = _outcome_coefs(_fit("censoring", "glum", **opts)) + assert gl == approx(sm, rel=1e-2, abs=2e-3) + + +def test_glum_matches_statsmodels_censoring_postexpansion(): + opts = dict(weighted=True, weight_preexpansion=False) + sm = _outcome_coefs(_fit("censoring", "statsmodels", **opts)) + gl = _outcome_coefs(_fit("censoring", "glum", **opts)) + assert gl == approx(sm, rel=1e-2, abs=2e-3) + + +def test_glum_weight_models_are_glum_backed(): + # The numerator/denominator binary logit models should be fit by glum, + # not statsmodels, when glm_package="glum". + from pySEQTarget.helpers._glum_fit import _GlumFit + + s = _fit("censoring", "glum", weighted=True, weight_preexpansion=True) + assert all(isinstance(m, _GlumFit) for m in s.numerator_model if m is not None) + assert all(isinstance(m, _GlumFit) for m in s.denominator_model if m is not None) + + +def test_glum_LTFU_runs(): + # Near-separation model (large coefficients): the two optimizers diverge, + # so this is a smoke test that the glum censoring-weight path runs and + # yields finite coefficients rather than an exact-equivalence check. + s = _fit( + "ITT", + "glum", + dataset="SEQdata_LTFU", + weighted=True, + weight_preexpansion=True, + cense_colname="LTFU", + ) + coefs = np.array(_outcome_coefs(s)) + assert np.all(np.isfinite(coefs)) + assert s.cense_numerator_model is not None + assert s.cense_denominator_model is not None + + +def test_glum_summary_is_printable_and_consistent(): + # SEQoutput.summary() calls model.summary(); the glum wrapper must support + # it (regression for the short-course render). tables[1] should be the coef + # table, matching statsmodels' summary2 layout. + s = _fit("ITT", "glum") + model = s.outcome_model[0]["outcome"] + + smry = model.summary() + assert str(smry) # renders without error + + coef_col = model.summary2().tables[1]["Coef."].to_list() + assert coef_col == approx(list(model.params), rel=1e-9, abs=1e-9) + + +def test_glum_standard_errors_match_statsmodels(): + sm_model = _fit("ITT", "statsmodels").outcome_model[0]["outcome"] + gl_model = _fit("ITT", "glum").outcome_model[0]["outcome"] + assert list(gl_model.bse) == approx(list(sm_model.bse), rel=1e-2, abs=1e-3) + + +def test_glum_summary_via_seqoutput(): + # Mirrors the short-course usage: results.summary("numerator"/"outcome"). + s = _fit("censoring", "glum", weighted=True, weight_preexpansion=True) + out = s.collect() + for kind in ("numerator", "denominator", "outcome"): + summaries = out.summary(kind) + assert len(summaries) >= 1 + assert all(str(sm) for sm in summaries) + + +def test_glum_bootstrap_survival_matches_statsmodels(): + # Exercises the prediction-caching path in _survival_pred (transform=False + # and .model.data.design_info on the glum wrapper) and confirms the point + # risk estimates match the statsmodels backend. + common = dict(bootstrap_nboot=3, seed=42, km_curves=True) + + def risk_diff(pkg): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(glm_package=pkg, **common), + ) + s.expand() + s.bootstrap() + s.fit() + s.survival() + rd = s.risk_estimates["risk_difference"] + assert rd["RD 95% LCI"].null_count() == 0 + return rd["Risk Difference"].to_list() + + assert risk_diff("glum") == approx(risk_diff("statsmodels"), rel=1e-2, abs=2e-3) diff --git a/tests/test_hazard.py b/tests/test_hazard.py index 4ba98b1..f3f1854 100644 --- a/tests/test_hazard.py +++ b/tests/test_hazard.py @@ -1,6 +1,15 @@ +import importlib + +import numpy as np +import pytest + from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data +# the package __init__ rebinds the name "SEQuential" to the class, so import +# the module object explicitly to monkeypatch its module-level _outcome_fit. +seqmod = importlib.import_module("pySEQTarget.SEQuential") + def test_ITT_hazard(): data = load_data("SEQdata") @@ -62,3 +71,51 @@ def test_subgroup_hazard(): s.bootstrap() s.fit() s.hazard() + + +def test_hazard_survives_skipped_bootstrap_replicate(monkeypatch): + # When a bootstrap replicate fails to fit (singular matrix -> LinAlgError), + # fit() skips it, so outcome_model has fewer entries than _boot_samples. + # hazard() must still pair each resample with its own model and not crash + # with an IndexError. Regression for the glum-backend short-course failure. + data = load_data("SEQdata") + + real_outcome_fit = seqmod._outcome_fit + fail_on = 2 # _current_boot_idx of the replicate to fail (sample index 1) + + def flaky_outcome_fit(seq_self, *args, **kwargs): + if getattr(seq_self, "_current_boot_idx", None) == fail_on: + raise np.linalg.LinAlgError("A singular matrix detected: injected for test") + return real_outcome_fit(seq_self, *args, **kwargs) + + monkeypatch.setattr(seqmod, "_outcome_fit", flaky_outcome_fit) + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(bootstrap_nboot=4, seed=42, hazard_estimate=True), + ) + s.expand() + s.bootstrap() + with pytest.warns(UserWarning): + s.fit() + + # One replicate skipped: effective nboot drops and the failed sample index + # (fail_on - 1) is absent from the success map. + assert s.bootstrap_nboot == 3 + assert len(s.outcome_model) == 4 # main fit + 3 successful replicates + assert s._boot_sample_idx == [0, 2, 3] + + # Must not raise IndexError and must produce a finite HR with CI. + s.hazard() + hr = s.hazard_ratio + assert hr["Hazard ratio"][0] is not None and np.isfinite(hr["Hazard ratio"][0]) + assert hr["LCI"][0] is not None + assert hr["UCI"][0] is not None