From c68dcf9e2665d48d3ff16a41bd6e07b1a8fc5324 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 22 May 2026 13:07:27 +0100 Subject: [PATCH 01/11] Add glum backend for faster GLM fitting --- pySEQTarget/SEQopts.py | 3 ++ pySEQTarget/analysis/_outcome_fit.py | 9 ++++ pySEQTarget/helpers/_glum_fit.py | 72 ++++++++++++++++++++++++++++ pySEQTarget/weighting/_weight_fit.py | 26 ++++++---- pyproject.toml | 3 +- 5 files changed, 102 insertions(+), 11 deletions(-) create mode 100644 pySEQTarget/helpers/_glum_fit.py diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 4392b7e..ed7268e 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -133,6 +133,7 @@ class SEQopts: excused: bool = False excused_colnames: List[str] = field(default_factory=lambda: []) expand_only: bool = False + glm_package: Literal["statsmodels", "glum"] = "statsmodels" followup_class: bool = False followup_include: bool = True followup_max: int = None @@ -231,6 +232,8 @@ def _validate_choices(self): ) if self.bootstrap_CI_method not in ["se", "percentile"]: raise ValueError("bootstrap_CI_method must be one of 'se' or 'percentile'") + if self.glm_package not in ["statsmodels", "glum"]: + raise ValueError("glm_package must be 'statsmodels' or 'glum'") def _normalize_formulas(self): for i in ( diff --git a/pySEQTarget/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py index af91176..f63e1e3 100644 --- a/pySEQTarget/analysis/_outcome_fit.py +++ b/pySEQTarget/analysis/_outcome_fit.py @@ -125,6 +125,15 @@ def _outcome_fit( full_formula = f"{outcome} ~ {formula}" + if getattr(self, "glm_package", "statsmodels") == "glum": + from ..helpers._glum_fit import _fit_glum + + return _fit_glum( + full_formula, + df_pd, + var_weights=df_pd[weight_col] if weighted else None, + ) + glm_kwargs = { "formula": full_formula, "data": df_pd, diff --git a/pySEQTarget/helpers/_glum_fit.py b/pySEQTarget/helpers/_glum_fit.py new file mode 100644 index 0000000..eac370b --- /dev/null +++ b/pySEQTarget/helpers/_glum_fit.py @@ -0,0 +1,72 @@ +import numpy as np +import pandas as pd +import patsy +from glum import GeneralizedLinearRegressor + + +class _GlumFit: + """ + Wraps a fitted glum model exposing the statsmodels interface the rest of + the codebase expects: + .params (Series), .model.exog_names, .model.data.design_info, + .predict(df) and .predict(X_numpy, transform=False). + """ + + class _Data: + pass + + def __init__(self, glum_model, design_info, feature_names): + self._glum = glum_model + self._design_info = design_info + + # .model.data.design_info — used by _survival_pred and _fix_categories + _d = self._Data() + _d.design_info = design_info + self._data = _d + + # statsmodels convention: intercept first + all_coefs = np.concatenate([[glum_model.intercept_], glum_model.coef_]) + self.params = pd.Series(all_coefs, index=feature_names) + + @property + def model(self): + # makes .model.exog_names and .model.data.design_info work + return self + + @property + def data(self): + return self._data + + @property + def exog_names(self): + return list(self.params.index) + + def predict(self, data, transform=True): + if transform: + # data is a pandas DataFrame — build design matrix via stored patsy info + X = patsy.build_design_matrices( + [self._design_info], data, return_type="dataframe" + )[0] + X_arr = X.drop(columns=["Intercept"], errors="ignore").values + else: + # data is a pre-built numpy design matrix (includes intercept col — drop it) + X_arr = np.asarray(data)[:, 1:] + return self._glum.predict(X_arr) + + +def _fit_glum(formula, data, var_weights=None): + """Fit a binomial GLM with glum and return a _GlumFit wrapper.""" + y_mat, X_mat = patsy.dmatrices(formula, data, return_type="dataframe") + y_arr = y_mat.values.ravel() + design_info = X_mat.design_info + feature_names = list(X_mat.columns) # "Intercept" first, then predictors + X_arr = X_mat.drop(columns=["Intercept"]).values + + glm = GeneralizedLinearRegressor(family="binomial", fit_intercept=True) + + fit_kwargs = {} + if var_weights is not None: + fit_kwargs["sample_weight"] = np.asarray(var_weights) + + glm.fit(X_arr, y_arr, **fit_kwargs) + return _GlumFit(glm, design_info, feature_names) diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index e6b39d0..ea0a541 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -2,6 +2,7 @@ import statsmodels.formula.api as smf from ..error._check_separation import _check_separation +from ..helpers._glum_fit import _fit_glum def _get_subset_for_level( @@ -47,8 +48,11 @@ def _fit_pair( setattr(self, out, None) continue formula = f"{outcome}~{rhs}" - model = smf.glm(formula, WDT, family=sm.families.Binomial()) - fitted = model.fit(disp=0, method=self.weight_fit_method) + if getattr(self, "glm_package", "statsmodels") == "glum": + fitted = _fit_glum(formula, WDT) + else: + model = smf.glm(formula, WDT, family=sm.families.Binomial()) + fitted = model.fit(disp=0, method=self.weight_fit_method) _check_separation(fitted, label=out.replace("_model", "").replace("_", " ")) setattr(self, out, fitted) @@ -102,11 +106,12 @@ def _fit_numerator(self, WDT): fits.append(None) continue # Use logit for binary 0/1 censoring, mnlogit otherwise - if is_binary: - model = smf.logit(formula, DT_subset) + if is_binary and getattr(self, "glm_package", "statsmodels") == "glum": + model_fit = _fit_glum(formula, DT_subset) + elif is_binary: + model_fit = smf.logit(formula, DT_subset).fit(disp=0, method=self.weight_fit_method) else: - model = smf.mnlogit(formula, DT_subset) - model_fit = model.fit(disp=0, method=self.weight_fit_method) + model_fit = smf.mnlogit(formula, DT_subset).fit(disp=0, method=self.weight_fit_method) _check_separation(model_fit, label=f"numerator (level {level})") fits.append(model_fit) @@ -140,11 +145,12 @@ def _fit_denominator(self, WDT): fits.append(None) continue # Use logit for binary 0/1 censoring, mnlogit otherwise - if is_binary: - model = smf.logit(formula, DT_subset) + if is_binary and getattr(self, "glm_package", "statsmodels") == "glum": + model_fit = _fit_glum(formula, DT_subset) + elif is_binary: + model_fit = smf.logit(formula, DT_subset).fit(disp=0, method=self.weight_fit_method) else: - model = smf.mnlogit(formula, DT_subset) - model_fit = model.fit(disp=0, method=self.weight_fit_method) + model_fit = smf.mnlogit(formula, DT_subset).fit(disp=0, method=self.weight_fit_method) _check_separation(model_fit, label=f"denominator (level {level})") fits.append(model_fit) diff --git a/pyproject.toml b/pyproject.toml index 0f4c31f..5785ba3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,8 @@ dependencies = [ "matplotlib", "pyarrow", "lifelines", - "joblib" + "joblib", + "glum>=3.4.1", ] [project.optional-dependencies] From f4919d1a89c3db8c8cae4591c6f6380dd0abe9f7 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 22 May 2026 13:25:04 +0100 Subject: [PATCH 02/11] Add tests for glum backend --- tests/test_glum.py | 113 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 tests/test_glum.py diff --git a/tests/test_glum.py b/tests/test_glum.py new file mode 100644 index 0000000..93f9b94 --- /dev/null +++ b/tests/test_glum.py @@ -0,0 +1,113 @@ +import numpy as np +import pytest +from pytest import approx + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _fit(method, glm_package, dataset="SEQdata", **opts): + data = load_data(dataset) + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method=method, + parameters=SEQopts(glm_package=glm_package, **opts), + ) + s.expand() + s.fit() + return s + + +def _outcome_coefs(s): + return list(s.outcome_model[0]["outcome"].params) + + +def test_glm_package_invalid_raises(): + with pytest.raises(ValueError, match="glm_package"): + SEQopts(glm_package="sklearn") + + +def test_glum_matches_statsmodels_ITT(): + sm = _outcome_coefs(_fit("ITT", "statsmodels")) + gl = _outcome_coefs(_fit("ITT", "glum")) + assert gl == approx(sm, rel=1e-2, abs=2e-3) + + +def test_glum_matches_statsmodels_censoring_preexpansion(): + opts = dict(weighted=True, weight_preexpansion=True) + sm = _outcome_coefs(_fit("censoring", "statsmodels", **opts)) + gl = _outcome_coefs(_fit("censoring", "glum", **opts)) + assert gl == approx(sm, rel=1e-2, abs=2e-3) + + +def test_glum_matches_statsmodels_censoring_postexpansion(): + opts = dict(weighted=True, weight_preexpansion=False) + sm = _outcome_coefs(_fit("censoring", "statsmodels", **opts)) + gl = _outcome_coefs(_fit("censoring", "glum", **opts)) + assert gl == approx(sm, rel=1e-2, abs=2e-3) + + +def test_glum_weight_models_are_glum_backed(): + # The numerator/denominator binary logit models should be fit by glum, + # not statsmodels, when glm_package="glum". + from pySEQTarget.helpers._glum_fit import _GlumFit + + s = _fit("censoring", "glum", weighted=True, weight_preexpansion=True) + assert all(isinstance(m, _GlumFit) for m in s.numerator_model if m is not None) + assert all(isinstance(m, _GlumFit) for m in s.denominator_model if m is not None) + + +def test_glum_LTFU_runs(): + # Near-separation model (large coefficients): the two optimizers diverge, + # so this is a smoke test that the glum censoring-weight path runs and + # yields finite coefficients rather than an exact-equivalence check. + s = _fit( + "ITT", + "glum", + dataset="SEQdata_LTFU", + weighted=True, + weight_preexpansion=True, + cense_colname="LTFU", + ) + coefs = np.array(_outcome_coefs(s)) + assert np.all(np.isfinite(coefs)) + assert s.cense_numerator_model is not None + assert s.cense_denominator_model is not None + + +def test_glum_bootstrap_survival_matches_statsmodels(): + # Exercises the prediction-caching path in _survival_pred (transform=False + # and .model.data.design_info on the glum wrapper) and confirms the point + # risk estimates match the statsmodels backend. + common = dict(bootstrap_nboot=3, seed=42, km_curves=True) + + def risk_diff(pkg): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(glm_package=pkg, **common), + ) + s.expand() + s.bootstrap() + s.fit() + s.survival() + rd = s.risk_estimates["risk_difference"] + assert rd["RD 95% LCI"].null_count() == 0 + return rd["Risk Difference"].to_list() + + assert risk_diff("glum") == approx(risk_diff("statsmodels"), rel=1e-2, abs=2e-3) From b1a1f080e2d93d1047947e2836d555de344bfa29 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 22 May 2026 13:39:24 +0100 Subject: [PATCH 03/11] Fix hazard crash when a bootstrap replicate is skipped MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a bootstrap replicate fails to fit (e.g. a singular matrix raising LinAlgError, which the glum solver hits more readily than statsmodels' IRLS), fit() skips it, leaving outcome_model shorter than _boot_samples. hazard() previously assumed a 1:1 ordering between the two: it looped over range(len(_boot_samples)) and indexed outcome_model[boot_idx + 1], which raised IndexError once a skip had occurred and, before that, silently paired every replicate after the skipped one with the wrong resample's model — producing incorrect hazard CIs. bootstrap_loop now records _boot_sample_idx, the original _boot_samples index for each successfully fitted replicate (appended in lockstep with the models, so it is correct for both the serial and parallel paths). hazard() iterates that map, pairing outcome_model[model_pos + 1] with _boot_samples[sample_idx], which fixes both the crash and the misalignment. survival() was unaffected since it iterates outcome_model directly and never indexes _boot_samples. Adds a regression test that injects a LinAlgError on one replicate and asserts hazard() returns a finite HR with CI instead of crashing. --- pySEQTarget/analysis/_hazard.py | 17 ++++++--- pySEQTarget/helpers/_bootstrap.py | 9 +++++ tests/test_hazard.py | 59 +++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 4 deletions(-) diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index 64613fd..f1234f5 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -35,10 +35,17 @@ def _calculate_hazard_single(self, data, idx=None, val=None): if self.bootstrap_nboot > 0: boot_log_hrs = [] - for boot_idx in range(len(self._boot_samples)): + # outcome_model[model_pos + 1] was fit on _boot_samples[sample_idx]; + # skipped replicates make this mapping non-identity, so iterate it + # explicitly rather than assuming model index == sample index. + boot_sample_idx = getattr(self, "_boot_sample_idx", None) + if boot_sample_idx is None: + boot_sample_idx = list(range(len(self._boot_samples))) + + for model_pos, sample_idx in enumerate(boot_sample_idx): if self.seed is not None: - self._rng = np.random.RandomState(self.seed + boot_idx + 1) - id_counts = self._boot_samples[boot_idx] + self._rng = np.random.RandomState(self.seed + sample_idx + 1) + id_counts = self._boot_samples[sample_idx] counts = pl.DataFrame( { @@ -55,7 +62,9 @@ def _calculate_hazard_single(self, data, idx=None, val=None): .collect() ) - boot_log_hr = _hazard_handler(self, boot_data, idx, boot_idx + 1, self._rng) + boot_log_hr = _hazard_handler( + self, boot_data, idx, model_pos + 1, self._rng + ) if boot_log_hr is not None and not np.isnan(boot_log_hr): boot_log_hrs.append(boot_log_hr) diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index b59fb61..a8dd844 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -118,12 +118,14 @@ def wrapper(self, *args, **kwargs): for i in range(nboot) } skipped = 0 + boot_sample_idx = [] for j in tqdm( as_completed(futures), total=nboot, desc="Bootstrapping..." ): boot_idx = futures[j] try: results.append(j.result()) + boot_sample_idx.append(boot_idx) except np.linalg.LinAlgError as e: skipped += 1 warnings.warn( @@ -144,6 +146,7 @@ def wrapper(self, *args, **kwargs): original_DT_ref = original_DT skipped = 0 + boot_sample_idx = [] for i in tqdm(range(nboot), desc="Bootstrapping..."): self._current_boot_idx = i + 1 if seed is not None: @@ -156,6 +159,7 @@ def wrapper(self, *args, **kwargs): try: boot_fit = method(self, *args, **kwargs) results.append(boot_fit) + boot_sample_idx.append(i) except np.linalg.LinAlgError as e: skipped += 1 warnings.warn( @@ -167,6 +171,11 @@ def wrapper(self, *args, **kwargs): self.DT = self._offloader.load_dataframe(original_DT_ref) + # Maps each fitted bootstrap model (results[1:]) back to its + # original _boot_samples index. Skipped replicates leave gaps, so + # downstream consumers (e.g. hazard) must use this to pair a + # resample with its model rather than assuming a 1:1 ordering. + self._boot_sample_idx = boot_sample_idx self.bootstrap_nboot = len(results) - 1 if skipped > 0: warnings.warn( diff --git a/tests/test_hazard.py b/tests/test_hazard.py index 4ba98b1..5e7ade1 100644 --- a/tests/test_hazard.py +++ b/tests/test_hazard.py @@ -1,6 +1,15 @@ +import importlib + +import numpy as np +import pytest + from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data +# the package __init__ rebinds the name "SEQuential" to the class, so import +# the module object explicitly to monkeypatch its module-level _outcome_fit. +seqmod = importlib.import_module("pySEQTarget.SEQuential") + def test_ITT_hazard(): data = load_data("SEQdata") @@ -62,3 +71,53 @@ def test_subgroup_hazard(): s.bootstrap() s.fit() s.hazard() + + +def test_hazard_survives_skipped_bootstrap_replicate(monkeypatch): + # When a bootstrap replicate fails to fit (singular matrix -> LinAlgError), + # fit() skips it, so outcome_model has fewer entries than _boot_samples. + # hazard() must still pair each resample with its own model and not crash + # with an IndexError. Regression for the glum-backend short-course failure. + data = load_data("SEQdata") + + real_outcome_fit = seqmod._outcome_fit + fail_on = 2 # _current_boot_idx of the replicate to fail (sample index 1) + + def flaky_outcome_fit(seq_self, *args, **kwargs): + if getattr(seq_self, "_current_boot_idx", None) == fail_on: + raise np.linalg.LinAlgError( + "A singular matrix detected: injected for test" + ) + return real_outcome_fit(seq_self, *args, **kwargs) + + monkeypatch.setattr(seqmod, "_outcome_fit", flaky_outcome_fit) + + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(bootstrap_nboot=4, seed=42, hazard_estimate=True), + ) + s.expand() + s.bootstrap() + with pytest.warns(UserWarning): + s.fit() + + # One replicate skipped: effective nboot drops and the failed sample index + # (fail_on - 1) is absent from the success map. + assert s.bootstrap_nboot == 3 + assert len(s.outcome_model) == 4 # main fit + 3 successful replicates + assert s._boot_sample_idx == [0, 2, 3] + + # Must not raise IndexError and must produce a finite HR with CI. + s.hazard() + hr = s.hazard_ratio + assert hr["Hazard ratio"][0] is not None and np.isfinite(hr["Hazard ratio"][0]) + assert hr["LCI"][0] is not None + assert hr["UCI"][0] is not None From 0a6d4b1502a1d844513de6b3d4b2838abe8f20d3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 22 May 2026 12:39:49 +0000 Subject: [PATCH 04/11] Auto-format code --- pySEQTarget/weighting/_weight_fit.py | 16 ++++++++++++---- tests/test_hazard.py | 4 +--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index ea0a541..8fcb0bf 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -109,9 +109,13 @@ def _fit_numerator(self, WDT): if is_binary and getattr(self, "glm_package", "statsmodels") == "glum": model_fit = _fit_glum(formula, DT_subset) elif is_binary: - model_fit = smf.logit(formula, DT_subset).fit(disp=0, method=self.weight_fit_method) + model_fit = smf.logit(formula, DT_subset).fit( + disp=0, method=self.weight_fit_method + ) else: - model_fit = smf.mnlogit(formula, DT_subset).fit(disp=0, method=self.weight_fit_method) + model_fit = smf.mnlogit(formula, DT_subset).fit( + disp=0, method=self.weight_fit_method + ) _check_separation(model_fit, label=f"numerator (level {level})") fits.append(model_fit) @@ -148,9 +152,13 @@ def _fit_denominator(self, WDT): if is_binary and getattr(self, "glm_package", "statsmodels") == "glum": model_fit = _fit_glum(formula, DT_subset) elif is_binary: - model_fit = smf.logit(formula, DT_subset).fit(disp=0, method=self.weight_fit_method) + model_fit = smf.logit(formula, DT_subset).fit( + disp=0, method=self.weight_fit_method + ) else: - model_fit = smf.mnlogit(formula, DT_subset).fit(disp=0, method=self.weight_fit_method) + model_fit = smf.mnlogit(formula, DT_subset).fit( + disp=0, method=self.weight_fit_method + ) _check_separation(model_fit, label=f"denominator (level {level})") fits.append(model_fit) diff --git a/tests/test_hazard.py b/tests/test_hazard.py index 5e7ade1..f3f1854 100644 --- a/tests/test_hazard.py +++ b/tests/test_hazard.py @@ -85,9 +85,7 @@ def test_hazard_survives_skipped_bootstrap_replicate(monkeypatch): def flaky_outcome_fit(seq_self, *args, **kwargs): if getattr(seq_self, "_current_boot_idx", None) == fail_on: - raise np.linalg.LinAlgError( - "A singular matrix detected: injected for test" - ) + raise np.linalg.LinAlgError("A singular matrix detected: injected for test") return real_outcome_fit(seq_self, *args, **kwargs) monkeypatch.setattr(seqmod, "_outcome_fit", flaky_outcome_fit) From 733281f6e510b51b38914e58a5b4aa11c5b0c36a Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 22 May 2026 13:54:44 +0100 Subject: [PATCH 05/11] Add `.summary()`, `.summary2()`, `.bse`, and `.cov_params()` to `_GlumFit`: --- pySEQTarget/helpers/_glum_fit.py | 77 +++++++++++++++++++++++++++++--- 1 file changed, 72 insertions(+), 5 deletions(-) diff --git a/pySEQTarget/helpers/_glum_fit.py b/pySEQTarget/helpers/_glum_fit.py index eac370b..134a084 100644 --- a/pySEQTarget/helpers/_glum_fit.py +++ b/pySEQTarget/helpers/_glum_fit.py @@ -7,17 +7,25 @@ class _GlumFit: """ Wraps a fitted glum model exposing the statsmodels interface the rest of - the codebase expects: + the codebase (and users) expect: .params (Series), .model.exog_names, .model.data.design_info, - .predict(df) and .predict(X_numpy, transform=False). + .predict(df) / .predict(X_numpy, transform=False), + .bse, .summary(), .summary2(). + + Standard errors are derived lazily from the stored design matrix using the + GLM asymptotic covariance (X' W X)^-1, which matches statsmodels for the + binomial/logit family (incl. var_weights). The design matrix is retained + just like statsmodels keeps model.exog, so memory use is comparable. """ class _Data: pass - def __init__(self, glum_model, design_info, feature_names): + def __init__(self, glum_model, design_info, feature_names, X_design, sample_weight): self._glum = glum_model self._design_info = design_info + self._X_design = X_design # includes the intercept column + self._sample_weight = sample_weight # .model.data.design_info — used by _survival_pred and _fix_categories _d = self._Data() @@ -53,6 +61,62 @@ def predict(self, data, transform=True): X_arr = np.asarray(data)[:, 1:] return self._glum.predict(X_arr) + def cov_params(self): + X = self._X_design + mu = self._glum.predict(X[:, 1:]) + w = mu * (1.0 - mu) + if self._sample_weight is not None: + w = w * np.asarray(self._sample_weight) + return np.linalg.pinv(X.T @ (w[:, None] * X)) + + @property + def bse(self): + return pd.Series(np.sqrt(np.diag(self.cov_params())), index=self.params.index) + + def _coef_table(self): + from scipy import stats + + coef = self.params.values + se = self.bse.values + with np.errstate(divide="ignore", invalid="ignore"): + z = coef / se + pvals = 2.0 * stats.norm.sf(np.abs(z)) + crit = stats.norm.ppf(0.975) + return pd.DataFrame( + { + "Coef.": coef, + "Std.Err.": se, + "z": z, + "P>|z|": pvals, + "[0.025": coef - crit * se, + "0.975]": coef + crit * se, + }, + index=list(self.params.index), + ) + + def summary2(self): + from statsmodels.iolib.summary2 import Summary + + info = pd.DataFrame( + { + " ": [ + "GLM (glum backend)", + "Binomial", + "logit", + str(self._X_design.shape[0]), + ] + }, + index=["Model:", "Family:", "Link:", "No. Observations:"], + ) + smry = Summary() + smry.add_title("Generalized Linear Model Regression Results") + smry.add_df(info, header=False) + smry.add_df(self._coef_table()) + return smry + + # statsmodels exposes both; the codebase/practical use either, so alias them. + summary = summary2 + def _fit_glum(formula, data, var_weights=None): """Fit a binomial GLM with glum and return a _GlumFit wrapper.""" @@ -60,13 +124,16 @@ def _fit_glum(formula, data, var_weights=None): y_arr = y_mat.values.ravel() design_info = X_mat.design_info feature_names = list(X_mat.columns) # "Intercept" first, then predictors + X_design = X_mat.values # includes intercept column (for covariance) X_arr = X_mat.drop(columns=["Intercept"]).values glm = GeneralizedLinearRegressor(family="binomial", fit_intercept=True) + sample_weight = None fit_kwargs = {} if var_weights is not None: - fit_kwargs["sample_weight"] = np.asarray(var_weights) + sample_weight = np.asarray(var_weights) + fit_kwargs["sample_weight"] = sample_weight glm.fit(X_arr, y_arr, **fit_kwargs) - return _GlumFit(glm, design_info, feature_names) + return _GlumFit(glm, design_info, feature_names, X_design, sample_weight) From acefaab75301e6931f7e04a5e073ca757c28b122 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 22 May 2026 13:55:00 +0100 Subject: [PATCH 06/11] Add tests for glum summary --- tests/test_glum.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/test_glum.py b/tests/test_glum.py index 93f9b94..48a0b48 100644 --- a/tests/test_glum.py +++ b/tests/test_glum.py @@ -82,6 +82,36 @@ def test_glum_LTFU_runs(): assert s.cense_denominator_model is not None +def test_glum_summary_is_printable_and_consistent(): + # SEQoutput.summary() calls model.summary(); the glum wrapper must support + # it (regression for the short-course render). tables[1] should be the coef + # table, matching statsmodels' summary2 layout. + s = _fit("ITT", "glum") + model = s.outcome_model[0]["outcome"] + + smry = model.summary() + assert str(smry) # renders without error + + coef_col = model.summary2().tables[1]["Coef."].to_list() + assert coef_col == approx(list(model.params), rel=1e-9, abs=1e-9) + + +def test_glum_standard_errors_match_statsmodels(): + sm_model = _fit("ITT", "statsmodels").outcome_model[0]["outcome"] + gl_model = _fit("ITT", "glum").outcome_model[0]["outcome"] + assert list(gl_model.bse) == approx(list(sm_model.bse), rel=1e-2, abs=1e-3) + + +def test_glum_summary_via_seqoutput(): + # Mirrors the short-course usage: results.summary("numerator"/"outcome"). + s = _fit("censoring", "glum", weighted=True, weight_preexpansion=True) + out = s.collect() + for kind in ("numerator", "denominator", "outcome"): + summaries = out.summary(kind) + assert len(summaries) >= 1 + assert all(str(sm) for sm in summaries) + + def test_glum_bootstrap_survival_matches_statsmodels(): # Exercises the prediction-caching path in _survival_pred (transform=False # and .model.data.design_info on the glum wrapper) and confirms the point From e58bd03b324006f050f92a9ccf16b94cec542abc Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 22 May 2026 15:16:08 +0100 Subject: [PATCH 07/11] Add per-arm follow-up counts to diagnostics --- pySEQTarget/SEQoutput.py | 9 +++- pySEQTarget/expansion/_diagnostics.py | 26 ++++++++++ tests/test_followup_diagnostics.py | 70 +++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 tests/test_followup_diagnostics.py diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py index 0bff491..ce3a547 100644 --- a/pySEQTarget/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -102,6 +102,8 @@ def retrieve_data( "risk_difference", "unique_outcomes", "nonunique_outcomes", + "unique_followup", + "nonunique_followup", "unique_switches", "nonunique_switches", ] @@ -112,7 +114,8 @@ def retrieve_data( :param type: Data which you would like to access, ['km_data', 'hazard', 'risk_ratio', 'risk_difference', 'unique_outcomes', - 'nonunique_outcomes', 'unique_switches', 'nonunique_switches'] + 'nonunique_outcomes', 'unique_followup', 'nonunique_followup', + 'unique_switches', 'nonunique_switches'] :type type: str """ match type: @@ -126,6 +129,10 @@ def retrieve_data( data = self.diagnostic_tables["unique_outcomes"] case "nonunique_outcomes": data = self.diagnostic_tables["nonunique_outcomes"] + case "unique_followup": + data = self.diagnostic_tables["unique_followup"] + case "nonunique_followup": + data = self.diagnostic_tables["nonunique_followup"] case "unique_switches": if self.diagnostic_tables.has_key("unique_switches"): data = self.diagnostic_tables["unique_switches"] diff --git a/pySEQTarget/expansion/_diagnostics.py b/pySEQTarget/expansion/_diagnostics.py index 178062a..d04e898 100644 --- a/pySEQTarget/expansion/_diagnostics.py +++ b/pySEQTarget/expansion/_diagnostics.py @@ -6,6 +6,10 @@ def _diagnostics(self): nonunique_out = _outcome_diag(self, unique=False) out = {"unique_outcomes": unique_out, "nonunique_outcomes": nonunique_out} + unique_fu = _followup_diag(self, unique=True) + nonunique_fu = _followup_diag(self, unique=False) + out.update({"unique_followup": unique_fu, "nonunique_followup": nonunique_fu}) + if self.method == "censoring": unique_switch = _switch_diag(self, unique=True) nonunique_switch = _switch_diag(self, unique=False) @@ -30,6 +34,28 @@ def _outcome_diag(self, unique): return out +def _followup_diag(self, unique): + """ + Follow-up per treatment arm, grouped by the baseline treatment value, over + the rows the outcome model is fit on (under method="censoring" the switched + rows are dropped, matching _outcome_fit). ``unique`` counts distinct subjects + contributing follow-up to the arm; otherwise counts follow-up intervals + (rows / person-time), so non-unique outcome counts divided by these give + per-arm event rates. A subject appearing in both arms is counted in each. + """ + tx_bas = f"{self.treatment_col}{self.indicator_baseline}" + data = self.DT + if self.method == "censoring": + data = data.filter(pl.col("switch") != 1) + + if unique: + out = data.group_by(tx_bas).agg(pl.col(self.id_col).n_unique().alias("len")) + else: + out = data.group_by(tx_bas).len() + + return out.sort(tx_bas) + + def _switch_diag(self, unique): if not self.excused: data = self.DT.with_columns(pl.lit(False).alias("isExcused")) diff --git a/tests/test_followup_diagnostics.py b/tests/test_followup_diagnostics.py new file mode 100644 index 0000000..d709082 --- /dev/null +++ b/tests/test_followup_diagnostics.py @@ -0,0 +1,70 @@ +import polars as pl + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _build(method, **opts): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method=method, + parameters=SEQopts(**opts), + ) + s.expand() + return s + + +def test_followup_tables_present_and_accessible(): + s = _build("ITT") + for key in ("unique_followup", "nonunique_followup"): + assert key in s.diagnostics + + s.fit() + out = s.collect() + for key in ("unique_followup", "nonunique_followup"): + tbl = out.retrieve_data(key) + assert "tx_init_bas" in tbl.columns + assert "len" in tbl.columns + + +def test_nonunique_followup_counts_outcome_fit_rows_per_arm(): + # Non-unique follow-up == follow-up intervals (rows) the outcome model is + # fit on, grouped by baseline treatment. + s = _build("ITT") + expected = s.DT.group_by("tx_init_bas").len().sort("tx_init_bas") + assert s.diagnostics["nonunique_followup"].equals(expected) + + +def test_censoring_followup_excludes_switched_rows(): + # Under method="censoring" the outcome model drops switch == 1 rows, so the + # follow-up tables must too. + s = _build("censoring", weighted=True, weight_preexpansion=True) + expected = ( + s.DT.filter(pl.col("switch") != 1) + .group_by("tx_init_bas") + .len() + .sort("tx_init_bas") + ) + assert s.diagnostics["nonunique_followup"].equals(expected) + + +def test_unique_followup_counts_distinct_subjects_per_arm(): + s = _build("ITT") + expected = ( + s.DT.group_by("tx_init_bas") + .agg(pl.col("ID").n_unique().alias("len")) + .sort("tx_init_bas") + ) + assert s.diagnostics["unique_followup"].equals(expected) + # never more unique subjects than follow-up intervals + u = s.diagnostics["unique_followup"].sort("tx_init_bas") + nn = s.diagnostics["nonunique_followup"].sort("tx_init_bas") + assert (u["len"] <= nn["len"]).all() From 391c09eb1cffbbca37ec724ea73571b6b58710c6 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Sat, 23 May 2026 17:49:52 +0100 Subject: [PATCH 08/11] Add scikit-survival backend for Cox model fitting Add a cox_package option (default "lifelines") to fit the univariate Cox model in the hazard step with either lifelines or scikit-survival. The Cox fit is extracted into a _cox_log_hr dispatcher: the lifelines path is unchanged, while the scikit-survival path builds the structured survival array and fits CoxPHSurvivalAnalysis with ties="efron" to match lifelines - this matters because integer follow-up produces many tied event times and the default Breslow handling would diverge. With matching tie handling and a fixed seed the two backends agree to ~1e-9 on the point hazard ratio and the bootstrap CIs match. Re-adds scikit-survival as a dependency and adds tests covering validation, ITT and bootstrap equivalence with lifelines, and the subgroup path. --- pySEQTarget/SEQopts.py | 3 ++ pySEQTarget/analysis/_hazard.py | 50 +++++++++++++--------- pyproject.toml | 1 + tests/test_cox.py | 76 +++++++++++++++++++++++++++++++++ 4 files changed, 111 insertions(+), 19 deletions(-) create mode 100644 tests/test_cox.py diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index ed7268e..3f12a91 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -129,6 +129,7 @@ class SEQopts: cense_eligible_colname: Optional[str] = None compevent_colname: Optional[str] = None covariates: Optional[str] = None + cox_package: Literal["lifelines", "scikit-survival"] = "lifelines" denominator: Optional[str] = None excused: bool = False excused_colnames: List[str] = field(default_factory=lambda: []) @@ -234,6 +235,8 @@ def _validate_choices(self): raise ValueError("bootstrap_CI_method must be one of 'se' or 'percentile'") if self.glm_package not in ["statsmodels", "glum"]: raise ValueError("glm_package must be 'statsmodels' or 'glum'") + if self.cox_package not in ["lifelines", "scikit-survival"]: + raise ValueError("cox_package must be 'lifelines' or 'scikit-survival'") def _normalize_formulas(self): for i in ( diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index f1234f5..a8331bf 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -189,6 +189,7 @@ def _hazard_handler(self, data, idx, boot_idx, rng): sim_data = sim_data.with_columns([pl.col("outcome").alias("event")]) sim_data_pd = sim_data.to_pandas() + tx_bas = f"{self.treatment_col}{self.indicator_baseline}" try: # COXPHFITTER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow() @@ -196,30 +197,41 @@ def _hazard_handler(self, data, idx, boot_idx, rng): if ce_model is not None: cox_data = sim_data_pd[sim_data_pd["event"].isin([0, 1])].copy() cox_data["event_binary"] = (cox_data["event"] == 1).astype(int) - - cph = CoxPHFitter() - cph.fit( - cox_data, - duration_col="followup", - event_col="event_binary", - formula=f"`{self.treatment_col}{self.indicator_baseline}`", - ) - else: - cph = CoxPHFitter() - cph.fit( - sim_data_pd, - duration_col="followup", - event_col="event", - formula=f"`{self.treatment_col}{self.indicator_baseline}`", - ) - - log_hr = cph.params_.values[0] - return log_hr + return _cox_log_hr(self, cox_data, "followup", "event_binary", tx_bas) + return _cox_log_hr(self, sim_data_pd, "followup", "event", tx_bas) except Exception as e: print(f"Cox model fitting failed: {e}") return None +def _cox_log_hr(self, data_pd, duration_col, event_col, covariate_col): + """ + Fit a univariate Cox model (single covariate = baseline treatment) and + return the log hazard ratio, dispatching on self.cox_package. scikit-survival + uses Efron tie handling to match lifelines, which matters here because + integer follow-up produces many tied event times. + """ + if getattr(self, "cox_package", "lifelines") == "scikit-survival": + from sksurv.linear_model import CoxPHSurvivalAnalysis + + y = np.empty(len(data_pd), dtype=[("event", bool), ("time", "float64")]) + y["event"] = data_pd[event_col].to_numpy().astype(bool) + y["time"] = data_pd[duration_col].to_numpy().astype(float) + X = data_pd[[covariate_col]].to_numpy().astype(float) + cox = CoxPHSurvivalAnalysis(ties="efron") + cox.fit(X, y) + return float(cox.coef_[0]) + + cph = CoxPHFitter() + cph.fit( + data_pd, + duration_col=duration_col, + event_col=event_col, + formula=f"`{covariate_col}`", + ) + return cph.params_.values[0] + + def _create_hazard_output(hr, lci, uci, val, self): if lci is not None and uci is not None: output = pl.DataFrame( diff --git a/pyproject.toml b/pyproject.toml index 5785ba3..0f7c42d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "lifelines", "joblib", "glum>=3.4.1", + "scikit-survival>=0.25.0", ] [project.optional-dependencies] diff --git a/tests/test_cox.py b/tests/test_cox.py new file mode 100644 index 0000000..a83a8c0 --- /dev/null +++ b/tests/test_cox.py @@ -0,0 +1,76 @@ +import pytest +from pytest import approx + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _hazard(cox_package, nboot=0, **opts): + data = load_data("SEQdata") + # seed always set: the hazard step simulates outcomes, so a fixed seed is + # required to isolate the backend difference from the random simulation. + params = dict(hazard_estimate=True, cox_package=cox_package, seed=42, **opts) + if nboot: + params["bootstrap_nboot"] = nboot + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(**params), + ) + s.expand() + if nboot: + s.bootstrap() + s.fit() + s.hazard() + return s + + +def test_cox_package_invalid_raises(): + with pytest.raises(ValueError, match="cox_package"): + SEQopts(cox_package="survival") + + +def test_sksurv_matches_lifelines_ITT(): + # Both backends fit the same univariate Cox partial likelihood with Efron + # tie handling, so the point hazard ratio should agree closely. + ll = _hazard("lifelines").hazard_ratio + sk = _hazard("scikit-survival").hazard_ratio + assert sk["Hazard ratio"][0] == approx(ll["Hazard ratio"][0], rel=1e-3, abs=1e-3) + + +def test_sksurv_matches_lifelines_bootstrap(): + ll = _hazard("lifelines", nboot=5).hazard_ratio + sk = _hazard("scikit-survival", nboot=5).hazard_ratio + for col in ("Hazard ratio", "LCI", "UCI"): + assert sk[col][0] == approx(ll[col][0], rel=1e-3, abs=1e-3) + + +def test_sksurv_subgroup_hazard_runs(): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + hazard_estimate=True, + cox_package="scikit-survival", + subgroup_colname="sex", + ), + ) + s.expand() + s.fit() + s.hazard() + assert s.hazard_ratio["Hazard ratio"].is_finite().all() From 79a7e37dcf0bd03c8ebdf70dee06930e872dc07b Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 28 May 2026 13:22:41 +0100 Subject: [PATCH 09/11] Bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0f7c42d..adfcb10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.13.5" +version = "0.13.6" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} From 8b21411a17c62452c04cdc594eeda244b93af77f Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 28 May 2026 13:22:49 +0100 Subject: [PATCH 10/11] Add per-arm competing-event diagnostic tables --- pySEQTarget/SEQoutput.py | 7 +++ pySEQTarget/expansion/_diagnostics.py | 22 +++++++ tests/test_compevent_diagnostics.py | 89 +++++++++++++++++++++++++++ 3 files changed, 118 insertions(+) create mode 100644 tests/test_compevent_diagnostics.py diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py index ce3a547..cecf28c 100644 --- a/pySEQTarget/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -104,6 +104,8 @@ def retrieve_data( "nonunique_outcomes", "unique_followup", "nonunique_followup", + "unique_compevent", + "nonunique_compevent", "unique_switches", "nonunique_switches", ] @@ -115,6 +117,7 @@ def retrieve_data( :param type: Data which you would like to access, ['km_data', 'hazard', 'risk_ratio', 'risk_difference', 'unique_outcomes', 'nonunique_outcomes', 'unique_followup', 'nonunique_followup', + 'unique_compevent', 'nonunique_compevent', 'unique_switches', 'nonunique_switches'] :type type: str """ @@ -133,6 +136,10 @@ def retrieve_data( data = self.diagnostic_tables["unique_followup"] case "nonunique_followup": data = self.diagnostic_tables["nonunique_followup"] + case "unique_compevent": + data = self.diagnostic_tables.get("unique_compevent") + case "nonunique_compevent": + data = self.diagnostic_tables.get("nonunique_compevent") case "unique_switches": if self.diagnostic_tables.has_key("unique_switches"): data = self.diagnostic_tables["unique_switches"] diff --git a/pySEQTarget/expansion/_diagnostics.py b/pySEQTarget/expansion/_diagnostics.py index d04e898..56dcb0c 100644 --- a/pySEQTarget/expansion/_diagnostics.py +++ b/pySEQTarget/expansion/_diagnostics.py @@ -10,6 +10,11 @@ def _diagnostics(self): nonunique_fu = _followup_diag(self, unique=False) out.update({"unique_followup": unique_fu, "nonunique_followup": nonunique_fu}) + if self.compevent_colname is not None: + unique_ce = _compevent_diag(self, unique=True) + nonunique_ce = _compevent_diag(self, unique=False) + out.update({"unique_compevent": unique_ce, "nonunique_compevent": nonunique_ce}) + if self.method == "censoring": unique_switch = _switch_diag(self, unique=True) nonunique_switch = _switch_diag(self, unique=False) @@ -56,6 +61,23 @@ def _followup_diag(self, unique): return out.sort(tx_bas) +def _compevent_diag(self, unique): + """ + Competing-event counts per treatment arm. Counts rows where the configured + compevent column is 1, grouped by the baseline treatment value. ``unique`` + counts distinct subjects who experience the competing event in each arm; + otherwise counts the intervals (rows). A subject appearing with a competing + event in both arms is counted in each. + """ + tx_bas = f"{self.treatment_col}{self.indicator_baseline}" + data = self.DT.filter(pl.col(self.compevent_colname) == 1) + if unique: + out = data.group_by(tx_bas).agg(pl.col(self.id_col).n_unique().alias("len")) + else: + out = data.group_by(tx_bas).len() + return out.sort(tx_bas) + + def _switch_diag(self, unique): if not self.excused: data = self.DT.with_columns(pl.lit(False).alias("isExcused")) diff --git a/tests/test_compevent_diagnostics.py b/tests/test_compevent_diagnostics.py new file mode 100644 index 0000000..6df158f --- /dev/null +++ b/tests/test_compevent_diagnostics.py @@ -0,0 +1,89 @@ +import numpy as np +import polars as pl +import pytest + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _data_with_compevent(p=0.05, seed=42): + rng = np.random.default_rng(seed) + data = load_data("SEQdata") + return data.with_columns( + pl.Series("compevent", (rng.random(data.height) < p).astype(int)) + ) + + +def _build(with_ce=True, **opts): + data = _data_with_compevent() if with_ce else load_data("SEQdata") + params = dict(**opts) + if with_ce: + params["compevent_colname"] = "compevent" + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(**params), + ) + s.expand() + return s + + +def test_compevent_tables_present_and_accessible(): + s = _build() + for key in ("unique_compevent", "nonunique_compevent"): + assert key in s.diagnostics + + s.fit() + out = s.collect() + for key in ("unique_compevent", "nonunique_compevent"): + tbl = out.retrieve_data(key) + assert "tx_init_bas" in tbl.columns + assert "len" in tbl.columns + + +def test_nonunique_compevent_counts_intervals_per_arm(): + # Non-unique counts intervals (rows) with compevent == 1, per baseline arm. + s = _build() + expected = ( + s.DT.filter(pl.col("compevent") == 1) + .group_by("tx_init_bas") + .len() + .sort("tx_init_bas") + ) + assert s.diagnostics["nonunique_compevent"].equals(expected) + + +def test_unique_compevent_counts_distinct_subjects_per_arm(): + s = _build() + expected = ( + s.DT.filter(pl.col("compevent") == 1) + .group_by("tx_init_bas") + .agg(pl.col("ID").n_unique().alias("len")) + .sort("tx_init_bas") + ) + assert s.diagnostics["unique_compevent"].equals(expected) + u = s.diagnostics["unique_compevent"].sort("tx_init_bas") + nn = s.diagnostics["nonunique_compevent"].sort("tx_init_bas") + assert (u["len"] <= nn["len"]).all() + + +def test_compevent_tables_absent_when_no_compevent_configured(): + # When compevent_colname is None the diagnostics dict has no compevent keys + # and retrieve_data raises (data is None -> ValueError). + s = _build(with_ce=False) + assert "unique_compevent" not in s.diagnostics + assert "nonunique_compevent" not in s.diagnostics + + s.fit() + out = s.collect() + with pytest.raises(ValueError): + out.retrieve_data("unique_compevent") + with pytest.raises(ValueError): + out.retrieve_data("nonunique_compevent") From 9d5fef3bf9e8b20458c42d839cfe44fd70ca0f76 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Sun, 31 May 2026 12:50:17 +0200 Subject: [PATCH 11/11] Minimized some of the easy returned properties --- pySEQTarget/helpers/_glum_fit.py | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/pySEQTarget/helpers/_glum_fit.py b/pySEQTarget/helpers/_glum_fit.py index 134a084..f8894b1 100644 --- a/pySEQTarget/helpers/_glum_fit.py +++ b/pySEQTarget/helpers/_glum_fit.py @@ -1,3 +1,5 @@ +import types + import numpy as np import pandas as pd import patsy @@ -18,37 +20,22 @@ class _GlumFit: just like statsmodels keeps model.exog, so memory use is comparable. """ - class _Data: - pass - def __init__(self, glum_model, design_info, feature_names, X_design, sample_weight): self._glum = glum_model self._design_info = design_info self._X_design = X_design # includes the intercept column self._sample_weight = sample_weight - # .model.data.design_info — used by _survival_pred and _fix_categories - _d = self._Data() - _d.design_info = design_info - self._data = _d + self.model = types.SimpleNamespace( + exog_names=feature_names, + data=types.SimpleNamespace(design_info=design_info), + ) + self.exog_names = feature_names # statsmodels convention: intercept first all_coefs = np.concatenate([[glum_model.intercept_], glum_model.coef_]) self.params = pd.Series(all_coefs, index=feature_names) - @property - def model(self): - # makes .model.exog_names and .model.data.design_info work - return self - - @property - def data(self): - return self._data - - @property - def exog_names(self): - return list(self.params.index) - def predict(self, data, transform=True): if transform: # data is a pandas DataFrame — build design matrix via stored patsy info