Skip to content
Merged
6 changes: 6 additions & 0 deletions pySEQTarget/SEQopts.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,12 @@ class SEQopts:
cense_eligible_colname: Optional[str] = None
compevent_colname: Optional[str] = None
covariates: Optional[str] = None
cox_package: Literal["lifelines", "scikit-survival"] = "lifelines"
denominator: Optional[str] = None
excused: bool = False
excused_colnames: List[str] = field(default_factory=lambda: [])
expand_only: bool = False
glm_package: Literal["statsmodels", "glum"] = "statsmodels"
followup_class: bool = False
followup_include: bool = True
followup_max: int = None
Expand Down Expand Up @@ -231,6 +233,10 @@ def _validate_choices(self):
)
if self.bootstrap_CI_method not in ["se", "percentile"]:
raise ValueError("bootstrap_CI_method must be one of 'se' or 'percentile'")
if self.glm_package not in ["statsmodels", "glum"]:
raise ValueError("glm_package must be 'statsmodels' or 'glum'")
if self.cox_package not in ["lifelines", "scikit-survival"]:
raise ValueError("cox_package must be 'lifelines' or 'scikit-survival'")

def _normalize_formulas(self):
for i in (
Expand Down
16 changes: 15 additions & 1 deletion pySEQTarget/SEQoutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def retrieve_data(
"risk_difference",
"unique_outcomes",
"nonunique_outcomes",
"unique_followup",
"nonunique_followup",
"unique_compevent",
"nonunique_compevent",
"unique_switches",
"nonunique_switches",
]
Expand All @@ -112,7 +116,9 @@ def retrieve_data(

:param type: Data which you would like to access, ['km_data', 'hazard',
'risk_ratio', 'risk_difference', 'unique_outcomes',
'nonunique_outcomes', 'unique_switches', 'nonunique_switches']
'nonunique_outcomes', 'unique_followup', 'nonunique_followup',
'unique_compevent', 'nonunique_compevent',
'unique_switches', 'nonunique_switches']
:type type: str
"""
match type:
Expand All @@ -126,6 +132,14 @@ def retrieve_data(
data = self.diagnostic_tables["unique_outcomes"]
case "nonunique_outcomes":
data = self.diagnostic_tables["nonunique_outcomes"]
case "unique_followup":
data = self.diagnostic_tables["unique_followup"]
case "nonunique_followup":
data = self.diagnostic_tables["nonunique_followup"]
case "unique_compevent":
data = self.diagnostic_tables.get("unique_compevent")
case "nonunique_compevent":
data = self.diagnostic_tables.get("nonunique_compevent")
case "unique_switches":
if self.diagnostic_tables.has_key("unique_switches"):
data = self.diagnostic_tables["unique_switches"]
Expand Down
67 changes: 44 additions & 23 deletions pySEQTarget/analysis/_hazard.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,17 @@ def _calculate_hazard_single(self, data, idx=None, val=None):
if self.bootstrap_nboot > 0:
boot_log_hrs = []

for boot_idx in range(len(self._boot_samples)):
# outcome_model[model_pos + 1] was fit on _boot_samples[sample_idx];
# skipped replicates make this mapping non-identity, so iterate it
# explicitly rather than assuming model index == sample index.
boot_sample_idx = getattr(self, "_boot_sample_idx", None)
if boot_sample_idx is None:
boot_sample_idx = list(range(len(self._boot_samples)))

for model_pos, sample_idx in enumerate(boot_sample_idx):
if self.seed is not None:
self._rng = np.random.RandomState(self.seed + boot_idx + 1)
id_counts = self._boot_samples[boot_idx]
self._rng = np.random.RandomState(self.seed + sample_idx + 1)
id_counts = self._boot_samples[sample_idx]

counts = pl.DataFrame(
{
Expand All @@ -55,7 +62,9 @@ def _calculate_hazard_single(self, data, idx=None, val=None):
.collect()
)

boot_log_hr = _hazard_handler(self, boot_data, idx, boot_idx + 1, self._rng)
boot_log_hr = _hazard_handler(
self, boot_data, idx, model_pos + 1, self._rng
)
if boot_log_hr is not None and not np.isnan(boot_log_hr):
boot_log_hrs.append(boot_log_hr)

Expand Down Expand Up @@ -180,37 +189,49 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
sim_data = sim_data.with_columns([pl.col("outcome").alias("event")])

sim_data_pd = sim_data.to_pandas()
tx_bas = f"{self.treatment_col}{self.indicator_baseline}"

try:
# COXPHFITTER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow()
warnings.filterwarnings("ignore", message=".*datetime.datetime.utcnow.*")
if ce_model is not None:
cox_data = sim_data_pd[sim_data_pd["event"].isin([0, 1])].copy()
cox_data["event_binary"] = (cox_data["event"] == 1).astype(int)

cph = CoxPHFitter()
cph.fit(
cox_data,
duration_col="followup",
event_col="event_binary",
formula=f"`{self.treatment_col}{self.indicator_baseline}`",
)
else:
cph = CoxPHFitter()
cph.fit(
sim_data_pd,
duration_col="followup",
event_col="event",
formula=f"`{self.treatment_col}{self.indicator_baseline}`",
)

log_hr = cph.params_.values[0]
return log_hr
return _cox_log_hr(self, cox_data, "followup", "event_binary", tx_bas)
return _cox_log_hr(self, sim_data_pd, "followup", "event", tx_bas)
except Exception as e:
print(f"Cox model fitting failed: {e}")
return None


def _cox_log_hr(self, data_pd, duration_col, event_col, covariate_col):
"""
Fit a univariate Cox model (single covariate = baseline treatment) and
return the log hazard ratio, dispatching on self.cox_package. scikit-survival
uses Efron tie handling to match lifelines, which matters here because
integer follow-up produces many tied event times.
"""
if getattr(self, "cox_package", "lifelines") == "scikit-survival":
from sksurv.linear_model import CoxPHSurvivalAnalysis

y = np.empty(len(data_pd), dtype=[("event", bool), ("time", "float64")])
y["event"] = data_pd[event_col].to_numpy().astype(bool)
y["time"] = data_pd[duration_col].to_numpy().astype(float)
X = data_pd[[covariate_col]].to_numpy().astype(float)
cox = CoxPHSurvivalAnalysis(ties="efron")
cox.fit(X, y)
return float(cox.coef_[0])

cph = CoxPHFitter()
cph.fit(
data_pd,
duration_col=duration_col,
event_col=event_col,
formula=f"`{covariate_col}`",
)
return cph.params_.values[0]


def _create_hazard_output(hr, lci, uci, val, self):
if lci is not None and uci is not None:
output = pl.DataFrame(
Expand Down
9 changes: 9 additions & 0 deletions pySEQTarget/analysis/_outcome_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ def _outcome_fit(

full_formula = f"{outcome} ~ {formula}"

if getattr(self, "glm_package", "statsmodels") == "glum":
from ..helpers._glum_fit import _fit_glum

return _fit_glum(
full_formula,
df_pd,
var_weights=df_pd[weight_col] if weighted else None,
)

glm_kwargs = {
"formula": full_formula,
"data": df_pd,
Expand Down
48 changes: 48 additions & 0 deletions pySEQTarget/expansion/_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ def _diagnostics(self):
nonunique_out = _outcome_diag(self, unique=False)
out = {"unique_outcomes": unique_out, "nonunique_outcomes": nonunique_out}

unique_fu = _followup_diag(self, unique=True)
nonunique_fu = _followup_diag(self, unique=False)
out.update({"unique_followup": unique_fu, "nonunique_followup": nonunique_fu})

if self.compevent_colname is not None:
unique_ce = _compevent_diag(self, unique=True)
nonunique_ce = _compevent_diag(self, unique=False)
out.update({"unique_compevent": unique_ce, "nonunique_compevent": nonunique_ce})

if self.method == "censoring":
unique_switch = _switch_diag(self, unique=True)
nonunique_switch = _switch_diag(self, unique=False)
Expand All @@ -30,6 +39,45 @@ def _outcome_diag(self, unique):
return out


def _followup_diag(self, unique):
"""
Follow-up per treatment arm, grouped by the baseline treatment value, over
the rows the outcome model is fit on (under method="censoring" the switched
rows are dropped, matching _outcome_fit). ``unique`` counts distinct subjects
contributing follow-up to the arm; otherwise counts follow-up intervals
(rows / person-time), so non-unique outcome counts divided by these give
per-arm event rates. A subject appearing in both arms is counted in each.
"""
tx_bas = f"{self.treatment_col}{self.indicator_baseline}"
data = self.DT
if self.method == "censoring":
data = data.filter(pl.col("switch") != 1)

if unique:
out = data.group_by(tx_bas).agg(pl.col(self.id_col).n_unique().alias("len"))
else:
out = data.group_by(tx_bas).len()

return out.sort(tx_bas)


def _compevent_diag(self, unique):
"""
Competing-event counts per treatment arm. Counts rows where the configured
compevent column is 1, grouped by the baseline treatment value. ``unique``
counts distinct subjects who experience the competing event in each arm;
otherwise counts the intervals (rows). A subject appearing with a competing
event in both arms is counted in each.
"""
tx_bas = f"{self.treatment_col}{self.indicator_baseline}"
data = self.DT.filter(pl.col(self.compevent_colname) == 1)
if unique:
out = data.group_by(tx_bas).agg(pl.col(self.id_col).n_unique().alias("len"))
else:
out = data.group_by(tx_bas).len()
return out.sort(tx_bas)


def _switch_diag(self, unique):
if not self.excused:
data = self.DT.with_columns(pl.lit(False).alias("isExcused"))
Expand Down
9 changes: 9 additions & 0 deletions pySEQTarget/helpers/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,14 @@ def wrapper(self, *args, **kwargs):
for i in range(nboot)
}
skipped = 0
boot_sample_idx = []
for j in tqdm(
as_completed(futures), total=nboot, desc="Bootstrapping..."
):
boot_idx = futures[j]
try:
results.append(j.result())
boot_sample_idx.append(boot_idx)
except np.linalg.LinAlgError as e:
skipped += 1
warnings.warn(
Expand All @@ -144,6 +146,7 @@ def wrapper(self, *args, **kwargs):
original_DT_ref = original_DT

skipped = 0
boot_sample_idx = []
for i in tqdm(range(nboot), desc="Bootstrapping..."):
self._current_boot_idx = i + 1
if seed is not None:
Expand All @@ -156,6 +159,7 @@ def wrapper(self, *args, **kwargs):
try:
boot_fit = method(self, *args, **kwargs)
results.append(boot_fit)
boot_sample_idx.append(i)
except np.linalg.LinAlgError as e:
skipped += 1
warnings.warn(
Expand All @@ -167,6 +171,11 @@ def wrapper(self, *args, **kwargs):

self.DT = self._offloader.load_dataframe(original_DT_ref)

# Maps each fitted bootstrap model (results[1:]) back to its
# original _boot_samples index. Skipped replicates leave gaps, so
# downstream consumers (e.g. hazard) must use this to pair a
# resample with its model rather than assuming a 1:1 ordering.
self._boot_sample_idx = boot_sample_idx
self.bootstrap_nboot = len(results) - 1
if skipped > 0:
warnings.warn(
Expand Down
Loading