diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b13b732..8dc1491 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ repos: # Ruff - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.12 + rev: v0.15.13 hooks: - id: ruff args: ["--fix"] diff --git a/docs/changes/61.feature.md b/docs/changes/61.feature.md new file mode 100644 index 0000000..7a05ab0 --- /dev/null +++ b/docs/changes/61.feature.md @@ -0,0 +1 @@ +Add calculation and plotting of zenith-angle dependent signal and background efficiencies (classification mode). diff --git a/src/eventdisplay_ml/evaluate.py b/src/eventdisplay_ml/evaluate.py index fe93b17..8e76bcb 100644 --- a/src/eventdisplay_ml/evaluate.py +++ b/src/eventdisplay_ml/evaluate.py @@ -15,11 +15,8 @@ _logger = logging.getLogger(__name__) -def evaluation_efficiency(name, model, x_test, y_test): - """Calculate signal and background efficiency as a function of threshold.""" - y_pred_proba = model.predict_proba(x_test)[:, 1] - thresholds = np.linspace(0, 1, 101) - +def _efficiency_dataframe(name, y_pred_proba, y_test, thresholds, context_label=""): + """Compute efficiency dataframe for a prediction/label slice.""" n_signal = (y_test == 1).sum() n_background = (y_test == 0).sum() @@ -31,7 +28,7 @@ def evaluation_efficiency(name, model, x_test, y_test): eff_signal.append(((pred) & (y_test == 1)).sum() / n_signal if n_signal else 0) eff_background.append(((pred) & (y_test == 0)).sum() / n_background if n_background else 0) _logger.info( - f"{name} Threshold: {t:.2f} | " + f"{name}{context_label} Threshold: {t:.2f} | " f"Signal Efficiency: {eff_signal[-1]:.4f} | " f"Background Efficiency: {eff_background[-1]:.4f}" ) @@ -50,6 +47,44 @@ def evaluation_efficiency(name, model, x_test, y_test): ) +def evaluation_efficiency(name, model, x_test, y_test, return_by_zenith=False): + """Calculate signal/background efficiency for all events and optionally by zenith bin.""" + y_pred_proba = model.predict_proba(x_test)[:, 1] + thresholds = np.linspace(0, 1, 101) + + efficiency_all = _efficiency_dataframe(name, y_pred_proba, y_test, thresholds) + if not return_by_zenith: + return efficiency_all + + efficiencies_by_zenith = {} + if "ze_bin" not in x_test.columns: + _logger.warning("Column 'ze_bin' missing in x_test; per-zenith efficiencies not computed.") + return efficiency_all, efficiencies_by_zenith + + ze_bins = pd.Series(x_test["ze_bin"]).dropna().unique().tolist() + ze_bins = sorted(ze_bins) + for ze_bin in ze_bins: + mask = x_test["ze_bin"] == ze_bin + if not np.any(mask): + continue + try: + key = int(ze_bin) + except (TypeError, ValueError): + _logger.warning( + "Skipping non-integer ze_bin value in efficiency calculation: %s", ze_bin + ) + continue + efficiencies_by_zenith[key] = _efficiency_dataframe( + name, + y_pred_proba[mask], + y_test[mask], + thresholds, + context_label=f" [ze{key}]", + ) + + return efficiency_all, efficiencies_by_zenith + + def evaluate_classification_model(model, x_test, y_test, df, x_cols, name): """Evaluate the trained model on the test set and log performance metrics. diff --git a/src/eventdisplay_ml/models.py b/src/eventdisplay_ml/models.py index 059468b..51c38c8 100644 --- a/src/eventdisplay_ml/models.py +++ b/src/eventdisplay_ml/models.py @@ -773,7 +773,12 @@ def train_classification(df, model_configs): ) cfg["model"] = model cfg["features"] = x_data.columns.tolist() # Store feature names for diagnostics - cfg["efficiency"] = evaluation_efficiency(name, model, x_test, y_test) + efficiency_all, efficiencies_by_zenith = evaluation_efficiency( + name, model, x_test, y_test, return_by_zenith=True + ) + cfg["efficiency"] = efficiency_all + for ze_bin, ze_efficiency in efficiencies_by_zenith.items(): + cfg[f"efficiency_ze{ze_bin}"] = ze_efficiency cfg["shap_importance"] = shap_importance return model_configs diff --git a/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py b/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py index 23cd5ae..5656402 100644 --- a/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py +++ b/src/eventdisplay_ml/scripts/plot_classification_performance_metrics.py @@ -14,6 +14,7 @@ import argparse import logging +import re from pathlib import Path import joblib @@ -27,10 +28,11 @@ _logger = logging.getLogger(__name__) -def plot_efficiencies(ax, x_root, y_effs, y_effb, x_joblib, y_effs_xgb, y_effb_xgb): +def plot_efficiencies(ax, x_joblib, y_effs_xgb, y_effb_xgb, x_root=None, y_effs=None, y_effb=None): """Plot Signal and Background efficiencies vs. cut value (threshold).""" - ax.plot(x_root, y_effs, label="TMVA BDT Eff S", color="blue", linestyle="-", linewidth=2) - ax.plot(x_root, y_effb, label="TMVA BDT Eff B", color="red", linestyle="-", linewidth=2) + if x_root is not None and y_effs is not None and y_effb is not None: + ax.plot(x_root, y_effs, label="TMVA BDT Eff S", color="blue", linestyle="-", linewidth=2) + ax.plot(x_root, y_effb, label="TMVA BDT Eff B", color="red", linestyle="-", linewidth=2) ax.plot(x_joblib, y_effs_xgb, label="XGB Eff S", color="cyan", linestyle="--", linewidth=2) ax.plot( x_joblib, y_effb_xgb, label="XGB Eff B", color="darkorange", linestyle="--", linewidth=4 @@ -42,14 +44,15 @@ def plot_efficiencies(ax, x_root, y_effs, y_effb, x_joblib, y_effs_xgb, y_effb_x ax.set_ylim(0, 1.05) -def plot_qfactor(ax, y_effs, y_effb, y_effs_xgb, y_effb_xgb): +def plot_qfactor(ax, y_effs_xgb, y_effb_xgb, y_effs=None, y_effb=None): """Plot Q-factor: Signal efficiency / sqrt(Background efficiency).""" - q_tmva = np.divide(y_effs, np.sqrt(y_effb), out=np.zeros_like(y_effs), where=y_effb != 0) q_xgb = np.divide( y_effs_xgb, np.sqrt(y_effb_xgb), out=np.zeros_like(y_effs_xgb), where=y_effb_xgb != 0 ) - ax.plot(y_effs, q_tmva, label=f"TMVA (Max Q: {np.max(q_tmva):.2f})", color="blue") + if y_effs is not None and y_effb is not None: + q_tmva = np.divide(y_effs, np.sqrt(y_effb), out=np.zeros_like(y_effs), where=y_effb != 0) + ax.plot(y_effs, q_tmva, label=f"TMVA (Max Q: {np.max(q_tmva):.2f})", color="blue") ax.plot( y_effs_xgb, q_xgb, @@ -64,11 +67,12 @@ def plot_qfactor(ax, y_effs, y_effb, y_effs_xgb, y_effb_xgb): ax.set_title("Q-Factor") -def plot_roc(ax, y_effs, y_effb, y_effs_xgb, y_effb_xgb): +def plot_roc(ax, y_effs_xgb, y_effb_xgb, y_effs=None, y_effb=None): """Plot ROC curve: Signal efficiency vs. 1 - Background efficiency.""" - auc_tmva = -np.trapezoid(1 - y_effb, y_effs) auc_xgb = -np.trapezoid(1 - y_effb_xgb, y_effs_xgb) - ax.plot(y_effs, 1 - y_effb, label=f"TMVA (AUC: {auc_tmva:.2f})", color="blue") + if y_effs is not None and y_effb is not None: + auc_tmva = -np.trapezoid(1 - y_effb, y_effs) + ax.plot(y_effs, 1 - y_effb, label=f"TMVA (AUC: {auc_tmva:.2f})", color="blue") ax.plot( y_effs_xgb, 1 - y_effb_xgb, @@ -84,18 +88,20 @@ def plot_roc(ax, y_effs, y_effb, y_effs_xgb, y_effb_xgb): ax.set_title("ROC") -def plot_score_distributions(ax, x_root, y_effs, y_effb, x_joblib, y_effs_xgb, y_effb_xgb): +def plot_score_distributions( + ax, x_joblib, y_effs_xgb, y_effb_xgb, x_root=None, y_effs=None, y_effb=None +): """Reconstructs and plots the probability density of the MVA scores.""" # The derivative of the efficiency curve is the probability density function (PDF) # We use negative gradient because efficiency decreases as threshold increases - pdf_s_tmva = -np.gradient(y_effs, x_root) - pdf_b_tmva = -np.gradient(y_effb, x_root) - pdf_s_xgb = -np.gradient(y_effs_xgb, x_joblib) pdf_b_xgb = -np.gradient(y_effb_xgb, x_joblib) - ax.fill_between(x_root, pdf_s_tmva, alpha=0.2, color="blue", label="TMVA Signal") - ax.fill_between(x_root, pdf_b_tmva, alpha=0.2, color="red", label="TMVA Background") + if x_root is not None and y_effs is not None and y_effb is not None: + pdf_s_tmva = -np.gradient(y_effs, x_root) + pdf_b_tmva = -np.gradient(y_effb, x_root) + ax.fill_between(x_root, pdf_s_tmva, alpha=0.2, color="blue", label="TMVA Signal") + ax.fill_between(x_root, pdf_b_tmva, alpha=0.2, color="red", label="TMVA Background") ax.plot(x_joblib, pdf_s_xgb, color="cyan", linestyle="--", label="XGB Signal", linewidth=4) ax.plot( @@ -110,28 +116,72 @@ def plot_score_distributions(ax, x_root, y_effs, y_effb, x_joblib, y_effs_xgb, y def load_efficiency_tmva(path, ebin, zebin=0): """Load efficiencies from TMVA root files.""" file_path = Path(path) / f"BDT_{ebin}_{zebin}.root" - with uproot.open(file_path) as rf: - base_path = "Method_BDT/BDT_0" - effs_rt = rf[f"{base_path}/MVA_BDT_0_effS"] - effb_rt = rf[f"{base_path}/MVA_BDT_0_effB"] - x_root_raw = ( - effs_rt.axis().centers() if hasattr(effs_rt, "axis") else effs_rt.values(axis=0) + try: + with uproot.open(file_path) as rf: + base_path = "Method_BDT/BDT_0" + effs_rt = rf[f"{base_path}/MVA_BDT_0_effS"] + effb_rt = rf[f"{base_path}/MVA_BDT_0_effB"] + x_root_raw = ( + effs_rt.axis().centers() if hasattr(effs_rt, "axis") else effs_rt.values(axis=0) + ) + x_min = np.min(x_root_raw) + x_max = np.max(x_root_raw) + if x_max == x_min: + _logger.warning( + "TMVA efficiency axis is degenerate in %s (ebin=%s, zebin=%s); skipping TMVA overlay.", + file_path, + ebin, + zebin, + ) + return None + # map [x_min, x_max] -> [0, 1] + x_root = (x_root_raw - x_min) / (x_max - x_min) + y_effs = effs_rt.values() + y_effb = effb_rt.values() + except (OSError, KeyError) as exc: + _logger.warning( + "TMVA efficiency histograms unavailable in %s (ebin=%s, zebin=%s): %s. " + "Plotting XGB only for this bin.", + file_path, + ebin, + zebin, + exc, ) - x_min = np.min(x_root_raw) - x_max = np.max(x_root_raw) - # map [-x_min, x_max] -> [0, 1] - x_root = (x_root_raw - x_min) / (x_max - x_min) - y_effs = effs_rt.values() - y_effb = effb_rt.values() + return None return x_root, y_effs, y_effb -def load_efficiency_xgb(path, ebin): - """Load efficiencies from XGB files.""" +def load_xgb_model_data(path, ebin): + """Load XGB joblib model payload for one energy bin.""" model_file = utils.resolve_joblib_path(Path(path) / f"gammahadron_bdt_ebin{ebin}") - data_joblib = joblib.load(model_file) - df_xgboost = data_joblib["models"]["xgboost"]["efficiency"] + return joblib.load(model_file) + + +def load_efficiency_xgb(data_joblib, ebin, zebin=-1): + """Load efficiencies from XGB model payload.""" + model_data = data_joblib["models"]["xgboost"] + + if zebin < 0: + efficiency_key = "efficiency" + else: + efficiency_key = f"efficiency_ze{zebin}" + if efficiency_key not in model_data: + available_ze_bins = [] + for key in model_data: + match = re.fullmatch(r"efficiency_ze(\d+)", key) + if match: + available_ze_bins.append(int(match.group(1))) + available_ze_bins = sorted(set(available_ze_bins)) + raise KeyError( + f"Efficiency key '{efficiency_key}' not found for ebin {ebin}. " + f"Available zenith bins: {available_ze_bins or 'none'}." + ) + + if efficiency_key not in model_data: + raise KeyError(f"Efficiency key '{efficiency_key}' not found for ebin {ebin}.") + + df_xgboost = model_data[efficiency_key] x_joblib = df_xgboost["threshold"] y_effs_xgb = df_xgboost["signal_efficiency"] @@ -140,11 +190,109 @@ def load_efficiency_xgb(path, ebin): return x_joblib, y_effs_xgb, y_effb_xgb +def xgb_zenith_bins(data_joblib): + """Return available XGB zenith bins from loaded joblib model payload.""" + model_data = data_joblib["models"]["xgboost"] + ze_bins = [] + for key in model_data: + match = re.fullmatch(r"efficiency_ze(\d+)", key) + if match: + ze_bins.append(int(match.group(1))) + return sorted(set(ze_bins)) + + +def tmva_zenith_bins(path, ebin): + """Return available TMVA zenith bins from BDT__.root filenames.""" + ze_bins = [] + pattern = re.compile(rf"^BDT_{ebin}_(\d+)\.root$") + for file_path in Path(path).glob(f"BDT_{ebin}_*.root"): + match = pattern.match(file_path.name) + if match: + ze_bins.append(int(match.group(1))) + return sorted(set(ze_bins)) + + +def resolve_tmva_zebin(xgb_zebin, available_tmva_bins, fallback_tmva_bin): + """Resolve TMVA zenith bin aligned to XGB zenith bin where possible.""" + if xgb_zebin < 0: + return fallback_tmva_bin if fallback_tmva_bin in available_tmva_bins else None + if xgb_zebin in available_tmva_bins: + return xgb_zebin + if fallback_tmva_bin in available_tmva_bins: + return fallback_tmva_bin + return None + + +def zenith_plot_label(zebin): + """Return human-readable zenith label for plot/file naming.""" + return "overall" if zebin < 0 else f"ze{zebin}" + + +def style_axis(ax): + """Apply common style settings to a matplotlib axis.""" + ax.tick_params(labelsize=10) + ax.grid(True, alpha=0.2) + + +def make_figure(x_joblib, y_effs_xgb, y_effb_xgb, x_root=None, y_effs=None, y_effb=None): + """Build 2x2 diagnostics figure for XGB with optional TMVA overlays.""" + fig, axs = plt.subplots(2, 2, figsize=(16, 16), sharex=False) + fig.set_constrained_layout(True) + + for ax in axs.flatten(): + style_axis(ax) + + plot_efficiencies(axs[0, 0], x_joblib, y_effs_xgb, y_effb_xgb, x_root, y_effs, y_effb) + plot_qfactor(axs[0, 1], y_effs_xgb, y_effb_xgb, y_effs, y_effb) + plot_roc(axs[1, 0], y_effs_xgb, y_effb_xgb, y_effs, y_effb) + plot_score_distributions(axs[1, 1], x_joblib, y_effs_xgb, y_effb_xgb, x_root, y_effs, y_effb) + + for ax in axs.flatten(): + ax.legend(fontsize=9, frameon=False, loc="best") + + return fig + + +def selected_xgb_bins(zenith_bin_xgb, available_xgb_bins): + """Resolve which XGB zenith bins to plot.""" + return [-1, *available_xgb_bins] if zenith_bin_xgb is None else [zenith_bin_xgb] + + +def tmva_overlay_data(root_dir, ebin, xgb_zebin, tmva_zebin): + """Return TMVA overlay data tuple or None when TMVA is unavailable for this zenith bin.""" + if root_dir is None: + return None + if tmva_zebin is None: + _logger.warning( + "No TMVA zenith-bin match for XGB %s in ebin %s; plotting XGB only.", + zenith_plot_label(xgb_zebin), + ebin, + ) + return None + return load_efficiency_tmva(root_dir, ebin, tmva_zebin) + + def main(): """Plot TMVA and XGBoost performance metrics.""" parser = argparse.ArgumentParser(description="Plot TMVA and XGBoost metrics.") - parser.add_argument("root_dir", help="Path to the TMVA BDT .root file") - parser.add_argument("joblib_dir", help="Path to the XGB BDT .joblib file") + parser.add_argument( + "--tmva_dir", + type=str, + default=None, + help="Path to TMVA BDT ROOT files (optional).", + ) + parser.add_argument( + "--xgb_dir", + type=str, + required=True, + help="Path to XGB BDT joblib files (required).", + ) + parser.add_argument( + "--output_dir", + type=str, + default=".", + help="Output directory for PNG files (default: current directory).", + ) parser.add_argument( "--energy-bin", type=int, @@ -152,35 +300,51 @@ def main(): default=None, help="Plot only a single energy bin (0-8). If omitted, all bins are processed.", ) + parser.add_argument( + "--zenith-bin-tmva", + type=int, + default=0, + help="Zenith bin index for TMVA ROOT files (second digit in BDT__.root). Default: 0.", + ) + parser.add_argument( + "--zenith-bin-xgb", + type=int, + default=None, + help=( + "XGB zenith bin to plot. If omitted, plots overall (-1) and all available ze bins. " + "Use -1 for overall or >=0 for efficiency_zeN." + ), + ) args = parser.parse_args() + root_dir = args.tmva_dir + joblib_dir = args.xgb_dir + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) # assume energy binning is identical in XGB and TMVA files. energy_bins = [args.energy_bin] if args.energy_bin is not None else range(9) for ebin in energy_bins: - x_root, y_effs, y_effb = load_efficiency_tmva(args.root_dir, ebin) - x_joblib, y_effs_xgb, y_effb_xgb = load_efficiency_xgb(args.joblib_dir, ebin) - - fig, axs = plt.subplots(2, 2, figsize=(16, 16), sharex=False) - fig.set_constrained_layout(True) - - for ax in axs.flatten(): - ax.tick_params(labelsize=10) - ax.grid(True, alpha=0.2) - - plot_efficiencies(axs[0, 0], x_root, y_effs, y_effb, x_joblib, y_effs_xgb, y_effb_xgb) - plot_qfactor(axs[0, 1], y_effs, y_effb, y_effs_xgb, y_effb_xgb) - plot_roc(axs[1, 0], y_effs, y_effb, y_effs_xgb, y_effb_xgb) - plot_score_distributions( - axs[1, 1], x_root, y_effs, y_effb, x_joblib, y_effs_xgb, y_effb_xgb - ) - - for ax in axs.flatten(): - ax.legend(fontsize=9, frameon=False, loc="best") - - plt.tight_layout() - _logger.info(f"Plotting plot_performance_metrics for ebin {ebin}") - plt.savefig(f"plot_performance_metrics_ebin{ebin}.png", dpi=300, bbox_inches="tight") - plt.close(fig) + xgb_model_data = load_xgb_model_data(joblib_dir, ebin) + available_xgb_bins = xgb_zenith_bins(xgb_model_data) + available_tmva_bins = tmva_zenith_bins(root_dir, ebin) if root_dir else [] + xgb_bins_to_plot = selected_xgb_bins(args.zenith_bin_xgb, available_xgb_bins) + + for xgb_zebin in xgb_bins_to_plot: + x_joblib, y_effs_xgb, y_effb_xgb = load_efficiency_xgb(xgb_model_data, ebin, xgb_zebin) + tmva_zebin = resolve_tmva_zebin(xgb_zebin, available_tmva_bins, args.zenith_bin_tmva) + + tmva_data = tmva_overlay_data(root_dir, ebin, xgb_zebin, tmva_zebin) + if tmva_data is None: + fig = make_figure(x_joblib, y_effs_xgb, y_effb_xgb) + else: + x_root, y_effs, y_effb = tmva_data + fig = make_figure(x_joblib, y_effs_xgb, y_effb_xgb, x_root, y_effs, y_effb) + + ze_label = zenith_plot_label(xgb_zebin) + _logger.info(f"Plotting plot_performance_metrics for ebin {ebin}, {ze_label}") + output_path = output_dir / f"plot_performance_metrics_ebin{ebin}_{ze_label}.png" + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close(fig) if __name__ == "__main__": diff --git a/tests/test_train_classification_shap.py b/tests/test_train_classification_shap.py index ca288fe..f4b37b7 100644 --- a/tests/test_train_classification_shap.py +++ b/tests/test_train_classification_shap.py @@ -80,7 +80,7 @@ def test_train_classification_caches_shap_importance_and_features(): with patch("eventdisplay_ml.models.evaluate_classification_model") as mock_eval: with patch("eventdisplay_ml.models.evaluation_efficiency") as mock_eff: mock_eval.return_value = expected_shap - mock_eff.return_value = expected_efficiency + mock_eff.return_value = (expected_efficiency, {}) result = models.train_classification([signal_df, background_df], cfg)