From 06166d4778f149024b5675cfe1f05a31576f9571 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Fri, 5 Jun 2026 19:38:12 -0500 Subject: [PATCH 01/57] machinery to fit 1d app mag funcs --- .../experimental/data_loaders/load_feniks.py | 82 +++++++++ diffhtwo/experimental/defaults.py | 4 + diffhtwo/experimental/kernels/N_phot.py | 45 ++++- .../loss_kernels/loss_functions.py | 16 +- .../experimental/loss_kernels/phot_loss.py | 78 +++++++- .../optimizers/Np_specphot_opt.py | 107 ++++++++++- scripts/config_feniks.yaml | 4 +- scripts/config_feniks_lh.yaml | 23 +++ scripts/fit_feniks.py | 31 +--- scripts/fit_feniks_lh.py | 169 ++++++++++++++++++ 10 files changed, 518 insertions(+), 41 deletions(-) create mode 100644 scripts/config_feniks_lh.yaml create mode 100644 scripts/fit_feniks_lh.py diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index d632a6c4..4e70b476 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -18,6 +18,7 @@ FilterInfo, ) from ..latin_hypercube import latin_hypercube as lh +from ..lightcone_generators import generate_lc_data from ..utils import load_feniks_tcurve BASE_PATH = Path(__file__).resolve().parent.parent @@ -76,6 +77,26 @@ def get_mag_ab(phot_table, col_name, ZP=25): return mag_ab +def get_N_1d_mag_bins(mags, mag_bin_edges=None, dmag=0.1, sig_scale=0.5): + mags = mags.reshape(mags.size, 1) + if mag_bin_edges is None: + mag_bin_edges = np.arange(mags.min(), mags.max(), dmag) + + mag_lo = mag_bin_edges[:-1].reshape(mag_bin_edges[:-1].size, 1) + mag_hi = mag_bin_edges[1:].reshape(mag_bin_edges[1:].size, 1) + + sig = jnp.zeros_like(mag_lo) + (dmag * sig_scale) + + N_mags = diffndhist_lomem.tw_ndhist( + mags, + sig, + mag_lo, + mag_hi, + ) + + return mag_bin_edges, N_mags + + def refresh_lh_centroids(DATASET, lh_d_mag): lh_centroids, d_centroids = get_lh_centroids(DATASET.dataset, lh_d_mag) @@ -144,6 +165,11 @@ def get_feniks_data( lh_d_mag=0.6, phot=PHOT, zout=ZOUT, + num_halos=250, + lgmp_min=10.0, + lgmp_max=15.0, + lc_sky_area_degsq=100, + n_z_phot_table=30, ): # Transmission curves and filter mag thresholds @@ -348,6 +374,58 @@ def get_feniks_data( mags = mags[z_mask] zout = zout[z_mask] + # prepare 1D app mag functions in z-bins for fitting + + zbins = np.array( + [ + [0.2, 0.5], + [0.5, 0.8], + [0.8, 1.2], + [1.2, 1.6], + [1.6, 2.0], + ] + ) + n_gals, n_dim = mags.shape + n_bands = n_dim - 1 + n_z_bins = len(zbins) + + magbin_zbins_bands = [] + N_zbins_bands = [] + lc_data = [] + for zbin in range(n_z_bins): + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data.append(generate_lc_data(*lc_args)) + + z_sel = (mags[:, -1] > z_min) & (mags[:, -1] <= z_max) + + magbin_bands = [] + N_bands = [] + for band in range(0, n_bands): + magbin_edges, N_mags = get_N_1d_mag_bins(mags[:, band][z_sel]) + magbin_bands.append(magbin_edges) + N_bands.append(N_mags) + + magbin_zbins_bands.append(magbin_bands) + N_zbins_bands.append(N_bands) + lh_centroids, d_centroids = get_lh_centroids(dataset, lh_d_mag) # run initial diffndhist_lomem with fixed dmag @@ -367,6 +445,10 @@ def get_feniks_data( dataset_dim_labels, mags, mags_labels, + zbins, + magbin_zbins_bands, + N_zbins_bands, + lc_data, filter_info, frac_cat, lh_centroids, diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index 1c83c063..62ce707a 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -35,6 +35,10 @@ "dataset_dim_labels", "mags", "mags_labels", + "zbins", + "magbin_zbins_bands", + "N_zbins_bands", + "lc_data", "filter_info", "frac_cat", "lh_centroids", diff --git a/diffhtwo/experimental/kernels/N_phot.py b/diffhtwo/experimental/kernels/N_phot.py index 684b7176..0bdab7f1 100644 --- a/diffhtwo/experimental/kernels/N_phot.py +++ b/diffhtwo/experimental/kernels/N_phot.py @@ -5,7 +5,50 @@ from dsps.cosmology import DEFAULT_COSMOLOGY from jax import jit as jjit -from .phot_kern import get_colors_mags +from .phot_kern import get_colors_mags, mag_kern + + +@jjit +def N_mags_1d( + ran_key, + param_collection, + magbin_bands, + lc_data, + mag_thresh, + frac_cat, + sig_scale=0.5, +): + obs_mags, gal_weight, phot_kern_results = mag_kern( + ran_key, + param_collection, + lc_data, + mag_thresh, + frac_cat, + ) + + n_gals, n_bands = obs_mags.shape + N_bands = [] + for band in range(0, n_bands): + mags = obs_mags[:, band].reshape(obs_mags[:, band].size, 1) + + magbin_edges = magbin_bands[band] + + sig = jnp.diff(magbin_edges) * sig_scale + sig = sig.reshape(sig.size, 1) + + mag_lo = magbin_edges[:-1].reshape(magbin_edges[:-1].size, 1) + mag_hi = magbin_edges[1:].reshape(magbin_edges[1:].size, 1) + + N_mags = diffndhist_lomem.tw_ndhist_weighted( + mags, + sig, + gal_weight, + mag_lo, + mag_hi, + ) + N_bands.append(N_mags) + + return N_bands @partial(jjit, static_argnames=["redshift_as_last_dimension_in_lh"]) diff --git a/diffhtwo/experimental/loss_kernels/loss_functions.py b/diffhtwo/experimental/loss_kernels/loss_functions.py index 82ae48fa..25aaca59 100644 --- a/diffhtwo/experimental/loss_kernels/loss_functions.py +++ b/diffhtwo/experimental/loss_kernels/loss_functions.py @@ -14,14 +14,14 @@ def mse_w(lg_n_pred, lg_n_target, lg_n_target_err, lg_n_thresh=-10): return jnp.sum(chi2) / nbins -# @jjit -# def poisson_loss(N_pred, N_target, eps=1e-12): -# N_pred = jnp.clip(N_pred, eps, None) -# return jnp.sum(N_pred - N_target * jnp.log(N_pred)) - - @jjit def poisson_loss(N_pred, N_target, eps=1e-12): N_pred = jnp.clip(N_pred, eps, None) - N_eff = jnp.maximum(jnp.sum(N_target), eps) - return jnp.sum(N_pred - N_target * jnp.log(N_pred)) / N_eff + return jnp.sum(N_pred - N_target * jnp.log(N_pred)) + + +# @jjit +# def poisson_loss(N_pred, N_target, eps=1e-12): +# N_pred = jnp.clip(N_pred, eps, None) +# N_eff = jnp.maximum(jnp.sum(N_target), eps) +# return jnp.sum(N_pred - N_target * jnp.log(N_pred)) / N_eff diff --git a/diffhtwo/experimental/loss_kernels/phot_loss.py b/diffhtwo/experimental/loss_kernels/phot_loss.py index 0086f33d..17b74c32 100644 --- a/diffhtwo/experimental/loss_kernels/phot_loss.py +++ b/diffhtwo/experimental/loss_kernels/phot_loss.py @@ -1,10 +1,86 @@ from jax import jit as jjit -from ..kernels.N_phot import N_colors_mags_lh +from ..kernels.N_phot import N_colors_mags_lh, N_mags_1d from ..param_utils import get_param_collection_from_u_theta from .loss_functions import poisson_loss +@jjit +def get_phot_loss_1d( + ran_key, + param_collection, + magbin_bands, + N_bands_data, + lc_data, + mag_thresh, + frac_cat, + data_sky_area_degsq, +): + N_bands_model = N_mags_1d( + ran_key, param_collection, magbin_bands, lc_data, mag_thresh, frac_cat + ) + + n_bands = len(N_bands_data) + phot_loss_1d = 0.0 + for band in range(0, n_bands): + N_model = N_bands_model[band] * (data_sky_area_degsq / lc_data.sky_area_degsq) + phot_loss_1d += poisson_loss(N_model, N_bands_data[band]) + + return phot_loss_1d + + +@jjit +def _loss_phot_kern_1d( + u_theta, + ran_key, + magbin_bands, + N_bands_data, + lc_data, + mag_thresh, + frac_cat, + data_sky_area_degsq, +): + param_collection = get_param_collection_from_u_theta(u_theta) + + phot_loss_1d_args = ( + ran_key, + param_collection, + magbin_bands, + N_bands_data, + lc_data, + mag_thresh, + frac_cat, + data_sky_area_degsq, + ) + phot_loss_1d = get_phot_loss_1d(*phot_loss_1d_args) + + return phot_loss_1d + + +@jjit +def _loss_phot_kern_multiband_multiz( + u_theta, + ran_key, + fitting_data, +): + zbins = fitting_data.zbins + n_z_bins = len(zbins) + phot_loss_multiband_multiz = 0.0 + for zbin in range(n_z_bins): + phot_loss_multiband_multiz += _loss_phot_kern_1d( + u_theta, + ran_key, + fitting_data.magbin_zbins_bands[zbin], + fitting_data.N_zbins_bands[zbin], + fitting_data.lc_data[zbin], + fitting_data.filter_info.mag_thresh, + fitting_data.frac_cat, + fitting_data.data_sky_area_degsq, + ) + + return phot_loss_multiband_multiz + + @jjit def get_phot_loss( ran_key, diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 3652c594..28cb2993 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -11,11 +11,10 @@ import jax.numpy as jnp from jax import jit as jjit from jax import lax, value_and_grad, vmap -from jax.debug import print from jax.example_libraries import optimizers as jax_opt from ..loss_kernels.emline_loss import _loss_emline_kern_multi_line_multi_z -from ..loss_kernels.phot_loss import _loss_phot_kern +from ..loss_kernels.phot_loss import _loss_phot_kern, _loss_phot_kern_multiband_multiz _L_pk = ( None, @@ -34,6 +33,10 @@ value_and_grad(_loss_emline_kern_multi_line_multi_z) ) +_loss_and_grad_phot_kern_multiband_multiz = jjit( + value_and_grad(_loss_phot_kern_multiband_multiz) +) + @partial(jjit, static_argnames=["n_steps", "step_size"]) def fit_N_multi_z( @@ -77,6 +80,46 @@ def _opt_update(opt_state, i): return loss_hist, u_theta_fit +@partial(jjit, static_argnames=["n_steps", "step_size"]) +def fit_N_phot_1d( + u_theta_init, + trainable, + ran_key, + fitting_data, + n_steps=2, + step_size=1e-2, +): + opt_init, opt_update, get_params = jax_opt.adam(step_size) + opt_state = opt_init(u_theta_init) + + other = ( + ran_key, + fitting_data, + ) + + def _opt_update(opt_state, i): + u_theta = get_params(opt_state) + loss, grads = _loss_and_grad_phot_kern_multiband_multiz(u_theta, *other) + # set grads for untrainable params to 0.0 + grads = tuple( + jnp.where(train, grad, 0.0) for grad, train in zip(grads, trainable) + ) + + # clip gradients + # global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) + # tau = 1.0 + # scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + # grads = tuple(g * scale for g in grads) + + opt_state = opt_update(i, grads, opt_state) + return opt_state, loss + + opt_state, loss_hist = lax.scan(_opt_update, opt_state, jnp.arange(n_steps)) + u_theta_fit = get_params(opt_state) + + return loss_hist, u_theta_fit + + @jjit def pytree_norm(grads): leaves = jax.tree_util.tree_leaves(grads) @@ -88,7 +131,6 @@ def fit_feniks_hizels( u_theta_init, trainable, ran_key, - feniks_meta_data, feniks_fitting_data, hizels_fitting_data, n_steps=2, @@ -99,10 +141,9 @@ def fit_feniks_hizels( def _opt_update(opt_state, i): u_theta = get_params(opt_state) - loss_phot, grad_phot = _loss_and_grad_phot_kern_multi_z( + loss_phot, grad_phot = _loss_and_grad_phot_kern_multiband_multiz( u_theta, ran_key, - feniks_meta_data, feniks_fitting_data, ) loss_emline, grad_emline = _loss_and_grad_emline_kern_multi_line_multi_z( @@ -110,7 +151,7 @@ def _opt_update(opt_state, i): ran_key, hizels_fitting_data, ) - w_phot = 10.0 + w_phot = 1.0 w_emline = 1.0 loss = w_phot * loss_phot + w_emline * loss_emline grads = tuple( @@ -137,6 +178,60 @@ def _opt_update(opt_state, i): return loss_hist, loss_phot_hist, loss_emline_hist, u_theta_fit +# @partial(jjit, static_argnames=["n_steps", "step_size"]) +# def fit_feniks_hizels( +# u_theta_init, +# trainable, +# ran_key, +# feniks_meta_data, +# feniks_fitting_data, +# hizels_fitting_data, +# n_steps=2, +# step_size=1e-2, +# ): +# opt_init, opt_update, get_params = jax_opt.adam(step_size) +# opt_state = opt_init(u_theta_init) + +# def _opt_update(opt_state, i): +# u_theta = get_params(opt_state) +# loss_phot, grad_phot = _loss_and_grad_phot_kern_multi_z( +# u_theta, +# ran_key, +# feniks_meta_data, +# feniks_fitting_data, +# ) +# loss_emline, grad_emline = _loss_and_grad_emline_kern_multi_line_multi_z( +# u_theta, +# ran_key, +# hizels_fitting_data, +# ) +# w_phot = 10.0 +# w_emline = 1.0 +# loss = w_phot * loss_phot + w_emline * loss_emline +# grads = tuple( +# w_phot * gp + w_emline * ge for gp, ge in zip(grad_phot, grad_emline) +# ) +# # set grads for untrainable params to 0.0 +# grads = tuple( +# jnp.where(train, grad, 0.0) for grad, train in zip(grads, trainable) +# ) + +# # clip gradients +# global_norm = pytree_norm(grads) +# tau = 1.0 +# scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) +# grads = tuple(g * scale for g in grads) + +# opt_state = opt_update(i, grads, opt_state) +# return opt_state, (loss, loss_phot, loss_emline) + +# opt_state, (loss_hist, loss_phot_hist, loss_emline_hist) = lax.scan( +# _opt_update, opt_state, jnp.arange(n_steps) +# ) +# u_theta_fit = get_params(opt_state) +# return loss_hist, loss_phot_hist, loss_emline_hist, u_theta_fit + + @jjit def _loss_sdss_feniks_hizels( u_theta, diff --git a/scripts/config_feniks.yaml b/scripts/config_feniks.yaml index 51c5966c..5e94a961 100644 --- a/scripts/config_feniks.yaml +++ b/scripts/config_feniks.yaml @@ -7,14 +7,12 @@ fit_runid: "runtest" fit_type: "all" feniks: - lh_d_mag: 0.4 - N_centroids: 2000 + num_halos: 100 epoch: n_it: 1 n_steps: 2 step_size: 0.1 - num_halos: 200 defaults: diffstarpop: True diff --git a/scripts/config_feniks_lh.yaml b/scripts/config_feniks_lh.yaml new file mode 100644 index 00000000..bcc1cfef --- /dev/null +++ b/scripts/config_feniks_lh.yaml @@ -0,0 +1,23 @@ +base_path: "/Users/kumail/diffdir" + +start_runid: "run90" +start_fit_type: "all" + +fit_runid: "runtest" +fit_type: "all" + +feniks: + lh_d_mag: 0.4 + N_centroids: 2000 + +epoch: + n_it: 1 + n_steps: 2 + step_size: 0.1 + num_halos: 100 + +defaults: + diffstarpop: True + spspop: True + ssperr: True + merging: True diff --git a/scripts/fit_feniks.py b/scripts/fit_feniks.py index 0e3e7b64..66f24a2c 100644 --- a/scripts/fit_feniks.py +++ b/scripts/fit_feniks.py @@ -1,6 +1,7 @@ import argparse import os import time +from collections import namedtuple from datetime import datetime import jax @@ -21,8 +22,6 @@ from diffhtwo.experimental import param_utils as pu from diffhtwo.experimental.data_loaders import load_feniks -from diffhtwo.experimental.defaults import FENIKS_Z_MAX, FENIKS_Z_MIN -from diffhtwo.experimental.latin_hypercube import lh_utils as lhu from diffhtwo.experimental.optimizers import Np_specphot_opt DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ @@ -51,8 +50,13 @@ # load feniks data ran_key = jran.key(0) - FENIKS = load_feniks.get_feniks_data( - feniks_drn, ran_key, ssp_data, lh_d_mag=cfg["feniks"]["lh_d_mag"] + feniks = load_feniks.get_feniks_data( + feniks_drn, ran_key, ssp_data, num_halos=cfg["feniks"]["num_halos"] + ) + remove = {"dataset_dim_labels", "mags_labels"} + FeniksFitting = namedtuple("Feniks", [f for f in feniks._fields if f not in remove]) + feniks_fitting_data = FeniksFitting( + **{f: getattr(feniks, f) for f in FeniksFitting._fields} ) # start fit dirs @@ -95,32 +99,15 @@ os.system(f"cp {args.config} {fit_diagnostics_save_drn}") - feniks_z_min = [FENIKS_Z_MIN, 1] - feniks_z_max = [1, 2] - initial_pts = [] start = time.time() for epoch in range(0, cfg["epoch"]["n_it"]): print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') - FENIKS = load_feniks.refresh_lh_centroids(FENIKS, cfg["feniks"]["lh_d_mag"]) - - # FENIKS - feniks_meta_data, feniks_fitting_data = lhu.get_zbins_lh_lc( - ran_key, - FENIKS, - feniks_z_min, - feniks_z_max, - ssp_data, - cfg["feniks"]["N_centroids"], - lh_N_z_savedir=fit_diagnostics_save_drn + "/lh_N_z", - num_halos=cfg["epoch"]["num_halos"], - ) - loss_hist, u_theta_fit = Np_specphot_opt.fit_N_multi_z( + loss_hist, u_theta_fit = Np_specphot_opt.fit_N_phot_1d( u_theta_fit, trainable_params, ran_key, - feniks_meta_data, feniks_fitting_data, n_steps=cfg["epoch"]["n_steps"], step_size=cfg["epoch"]["step_size"], diff --git a/scripts/fit_feniks_lh.py b/scripts/fit_feniks_lh.py new file mode 100644 index 00000000..0e3e7b64 --- /dev/null +++ b/scripts/fit_feniks_lh.py @@ -0,0 +1,169 @@ +import argparse +import os +import time +from datetime import datetime + +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import yaml +from diffsky.data_loaders.hacc_utils import lc_mock +from diffsky.merging.merging_model import DEFAULT_MERGE_PARAMS +from diffsky.param_utils.spspop_param_utils import DEFAULT_SPSPOP_PARAMS +from diffsky.ssp_err_model.defaults import ZERO_SSPERR_PARAMS +from diffstar.diffstarpop.kernels.params.params_diffstarpopfits_mgash import ( + DiffstarPop_Params_Diffstarpopfits_mgash, +) +from dsps import load_ssp_templates +from dsps.data_loaders import load_emline_info as lemi +from jax import random as jran + +from diffhtwo.experimental import param_utils as pu +from diffhtwo.experimental.data_loaders import load_feniks +from diffhtwo.experimental.defaults import FENIKS_Z_MAX, FENIKS_Z_MIN +from diffhtwo.experimental.latin_hypercube import lh_utils as lhu +from diffhtwo.experimental.optimizers import Np_specphot_opt + +DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ + "galacticus_in_plus_ex_situ" +] + +if __name__ == "__main__": + p = argparse.ArgumentParser() + p.add_argument("--config", default="config_feniks.yaml") + args = p.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + feniks_drn = cfg["base_path"] + "/feniks" + ssp_filename = ( + cfg["base_path"] + + "/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5" + ) + + # get ssp data + ssp_data = load_ssp_templates(fn=ssp_filename) + ssp_data = lemi.get_subset_emline_data(ssp_data, ["Ba_alpha_6563"]) + emline_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) + emline_wave_table = jnp.array([emline_wave_aa]) + + # load feniks data + ran_key = jran.key(0) + FENIKS = load_feniks.get_feniks_data( + feniks_drn, ran_key, ssp_data, lh_d_mag=cfg["feniks"]["lh_d_mag"] + ) + + # start fit dirs + fit_start_drn = cfg["base_path"] + "/fits/" + cfg["start_runid"] + "/" + param_collection_fit = lc_mock.load_diffsky_param_collection_merging( + fit_start_drn, + cfg["start_runid"] + "_" + cfg["start_fit_type"], + ) + if cfg["defaults"]["diffstarpop"]: + param_collection_fit = param_collection_fit._replace( + diffstarpop_params=DIFFSTARPOP_GALACTICUS_exsitu + ) + if cfg["defaults"]["spspop"]: + param_collection_fit = param_collection_fit._replace( + spspop_params=DEFAULT_SPSPOP_PARAMS + ) + if cfg["defaults"]["ssperr"]: + param_collection_fit = param_collection_fit._replace( + ssperr_params=ZERO_SSPERR_PARAMS + ) + if cfg["defaults"]["merging"]: + param_collection_fit = param_collection_fit._replace( + merging_params=DEFAULT_MERGE_PARAMS + ) + + u_theta_fit = pu.get_u_theta_from_param_collection(param_collection_fit) + + # fit dirs + trainable_params = pu.get_trainable_params(fit_type=cfg["fit_type"]) + fit_save_drn = cfg["base_path"] + "/fits/" + cfg["fit_runid"] + "/" + fit_diagnostics_save_drn = ( + cfg["base_path"] + + "/fits/" + + cfg["fit_runid"] + + "/diagnostic_plots/" + + cfg["fit_type"] + ) + os.makedirs(fit_diagnostics_save_drn + "/loss", exist_ok=True) + os.makedirs(fit_diagnostics_save_drn + "/lh_N_z", exist_ok=True) + + os.system(f"cp {args.config} {fit_diagnostics_save_drn}") + + feniks_z_min = [FENIKS_Z_MIN, 1] + feniks_z_max = [1, 2] + + initial_pts = [] + start = time.time() + for epoch in range(0, cfg["epoch"]["n_it"]): + print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') + FENIKS = load_feniks.refresh_lh_centroids(FENIKS, cfg["feniks"]["lh_d_mag"]) + + # FENIKS + feniks_meta_data, feniks_fitting_data = lhu.get_zbins_lh_lc( + ran_key, + FENIKS, + feniks_z_min, + feniks_z_max, + ssp_data, + cfg["feniks"]["N_centroids"], + lh_N_z_savedir=fit_diagnostics_save_drn + "/lh_N_z", + num_halos=cfg["epoch"]["num_halos"], + ) + + loss_hist, u_theta_fit = Np_specphot_opt.fit_N_multi_z( + u_theta_fit, + trainable_params, + ran_key, + feniks_meta_data, + feniks_fitting_data, + n_steps=cfg["epoch"]["n_steps"], + step_size=cfg["epoch"]["step_size"], + ) + jax.clear_caches() + + param_collection_fit = pu.get_param_collection_from_u_theta(u_theta_fit) + lc_mock.write_diffsky_param_collection_merging( + fit_save_drn, + cfg["fit_runid"] + "_" + cfg["fit_type"], + param_collection_fit, + ) + + if epoch == 0: + STEPS = np.arange(1, cfg["epoch"]["n_steps"] + 1, 1) + + LOSS_HIST = loss_hist + + initial_pts.append((STEPS[0], LOSS_HIST[0])) + else: + steps = np.arange(STEPS[-1] + 1, STEPS[-1] + cfg["epoch"]["n_steps"] + 1, 1) + initial_pts.append((steps[0], loss_hist[0])) + + STEPS = np.concatenate((STEPS, steps)) + LOSS_HIST = np.concatenate((LOSS_HIST, loss_hist)) + + end = time.time() + elapsed = end - start + print( + f'Gradient descent took {elapsed/60:.3f} minutes for {cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]} steps.' + ) + print(f'speed: {elapsed/(cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]):.3f} s/it') + + # gradient descent figure + fig_loss, ax_loss = plt.subplots(1) + + start_step = [s[0] for s in initial_pts] + start_loss = [s[1] for s in initial_pts] + ax_loss.scatter(start_step, start_loss, s=50, c="deepskyblue") + + ax_loss.plot(STEPS, LOSS_HIST, c="deepskyblue") + ax_loss.set_ylabel("Poisson Loss") + ax_loss.set_xlabel("steps") + ts = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + plt.savefig(fit_diagnostics_save_drn + "/loss/feniks_loss_" + ts + ".png") + plt.close() From 8bf386003c02705d21680593729a07022bb303f0 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Fri, 5 Jun 2026 20:04:51 -0500 Subject: [PATCH 02/57] define mag_bin_edges for feniks in conftest --- diffhtwo/experimental/conftest.py | 4 ++++ diffhtwo/experimental/data_loaders/load_feniks.py | 7 +++++-- scripts/config_feniks.yaml | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/diffhtwo/experimental/conftest.py b/diffhtwo/experimental/conftest.py index 328d94ef..496af452 100644 --- a/diffhtwo/experimental/conftest.py +++ b/diffhtwo/experimental/conftest.py @@ -1,6 +1,7 @@ from pathlib import Path import jax.numpy as jnp +import numpy as np import pytest from dsps.data_loaders import load_emline_info as lemi from dsps.data_loaders import retrieve_fake_fsps_data @@ -41,12 +42,15 @@ def fake_subset_ssp_data(): def feniks(ran_key, fake_subset_ssp_data): ssp_data, emline_wave_aa = fake_subset_ssp_data + mag_bin_edges = np.array([18, 25]) + feniks = load_feniks.get_feniks_data( FENIKS_DRN, ran_key, ssp_data, phot=PHOT, zout=ZOUT, + mag_bin_edges=mag_bin_edges, ) return feniks diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 4e70b476..92ae9c3b 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -77,7 +77,7 @@ def get_mag_ab(phot_table, col_name, ZP=25): return mag_ab -def get_N_1d_mag_bins(mags, mag_bin_edges=None, dmag=0.1, sig_scale=0.5): +def get_N_1d_mag_bins(mags, mag_bin_edges=None, dmag=0.2, sig_scale=0.5): mags = mags.reshape(mags.size, 1) if mag_bin_edges is None: mag_bin_edges = np.arange(mags.min(), mags.max(), dmag) @@ -170,6 +170,7 @@ def get_feniks_data( lgmp_max=15.0, lc_sky_area_degsq=100, n_z_phot_table=30, + mag_bin_edges=None, ): # Transmission curves and filter mag thresholds @@ -419,7 +420,9 @@ def get_feniks_data( magbin_bands = [] N_bands = [] for band in range(0, n_bands): - magbin_edges, N_mags = get_N_1d_mag_bins(mags[:, band][z_sel]) + magbin_edges, N_mags = get_N_1d_mag_bins( + mags[:, band][z_sel], mag_bin_edges=mag_bin_edges + ) magbin_bands.append(magbin_edges) N_bands.append(N_mags) diff --git a/scripts/config_feniks.yaml b/scripts/config_feniks.yaml index 5e94a961..d80b5af7 100644 --- a/scripts/config_feniks.yaml +++ b/scripts/config_feniks.yaml @@ -4,7 +4,7 @@ start_runid: "run90" start_fit_type: "all" fit_runid: "runtest" -fit_type: "all" +fit_type: "diffstarpop+spspop+merging" feniks: num_halos: 100 From 653906e668d8b92bf2c0f16ff5765d8bfac92127 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 6 Jun 2026 14:39:25 -0500 Subject: [PATCH 03/57] 2D conditioned color spaces of Feniks --- .../experimental/data_loaders/load_feniks.py | 292 +++++++++++++----- diffhtwo/experimental/defaults.py | 8 +- scripts/config_diagnostics.yaml | 8 +- scripts/generate_diagnostic_plots.py | 11 +- 4 files changed, 218 insertions(+), 101 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 92ae9c3b..41a3dba5 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -77,24 +77,60 @@ def get_mag_ab(phot_table, col_name, ZP=25): return mag_ab -def get_N_1d_mag_bins(mags, mag_bin_edges=None, dmag=0.2, sig_scale=0.5): - mags = mags.reshape(mags.size, 1) - if mag_bin_edges is None: - mag_bin_edges = np.arange(mags.min(), mags.max(), dmag) +def get_N_1d(dim1, dim1_bin_edges=None, dmag=0.2, sig_scale=0.5): + dataset = dim1.reshape(dim1.size, 1) + if dim1_bin_edges is None: + dim1_bin_edges = np.arange(dim1.min(), dim1.max(), dmag) - mag_lo = mag_bin_edges[:-1].reshape(mag_bin_edges[:-1].size, 1) - mag_hi = mag_bin_edges[1:].reshape(mag_bin_edges[1:].size, 1) + bin_lo = dim1_bin_edges[:-1].reshape(dim1_bin_edges[:-1].size, 1) + bin_hi = dim1_bin_edges[1:].reshape(dim1_bin_edges[1:].size, 1) - sig = jnp.zeros_like(mag_lo) + (dmag * sig_scale) + sig = jnp.zeros_like(bin_lo) + (dmag * sig_scale) - N_mags = diffndhist_lomem.tw_ndhist( - mags, + N_1d = diffndhist_lomem.tw_ndhist( + dataset, sig, - mag_lo, - mag_hi, + bin_lo, + bin_hi, ) - return mag_bin_edges, N_mags + return ( + N_1d, + sig, + bin_lo, + bin_hi, + ) + + +def get_N_2d(dim1, dim2, sig_scale=0.5): + dataset = np.vstack((dim1, dim2)).T + + dim1_bin_edges = np.linspace(dim1.min(), dim1.max(), 11) + dim2_bin_edges = np.linspace(dim2.min(), dim2.max(), 11) + + dim1_lo = dim1_bin_edges[:-1] + dim2_lo = dim2_bin_edges[:-1] + bin_lo = np.meshgrid(dim1_lo, dim2_lo, indexing="ij") + bin_lo = np.array(bin_lo).T.reshape(-1, 2) + + dim1_hi = dim1_bin_edges[1:] + dim2_hi = dim2_bin_edges[1:] + bin_hi = np.meshgrid(dim1_hi, dim2_hi, indexing="ij") + bin_hi = np.array(bin_hi).T.reshape(-1, 2) + + sig1 = np.diff(dim1_bin_edges) * sig_scale + sig2 = np.diff(dim2_bin_edges) * sig_scale + sig = np.meshgrid(sig1, sig2, indexing="ij") + sig = np.array(sig).T.reshape(-1, 2) + + N_2d = diffndhist_lomem.tw_ndhist( + dataset, + sig, + bin_lo, + bin_hi, + ) + + return N_2d, sig, bin_lo, bin_hi def refresh_lh_centroids(DATASET, lh_d_mag): @@ -274,31 +310,6 @@ def get_feniks_data( uds_H = uds_H[clean] uds_K = uds_K[clean] - # mask nans - # nans = ( - # (megacam_uS == -99.0) - # | (hsc_g == -99.0) - # | (hsc_r == -99.0) - # | (hsc_i == -99.0) - # | (hsc_z == -99.0) - # | (video_Y == -99) - # | (uds_J == -99.0) - # | (uds_H == -99.0) - # | (uds_K == -99.0) - # ) - - # megacam_uS = megacam_uS[~nans] - # hsc_g = hsc_g[~nans] - # hsc_r = hsc_r[~nans] - # hsc_i = hsc_i[~nans] - # hsc_z = hsc_z[~nans] - # video_Y = video_Y[~nans] - # uds_J = uds_J[~nans] - # uds_H = uds_H[~nans] - # uds_K = uds_K[~nans] - - # zout = zout[~nans] - N_obj_post_cuts = len(zout) frac_cat = N_obj_post_cuts / N_obj_pre_cuts @@ -320,6 +331,7 @@ def get_feniks_data( # derive colors from mags megacam_hsc_uSg = megacam_uS - hsc_g hsc_gr = hsc_g - hsc_r + hsc_rz = hsc_r - hsc_z hsc_ri = hsc_r - hsc_i hsc_iz = hsc_i - hsc_z hsc_uds_zJ = hsc_z - uds_J @@ -370,64 +382,176 @@ def get_feniks_data( ] # mask redshift - z_mask = (zout["z_phot"] > FENIKS_Z_MIN) & (zout["z_phot"] <= FENIKS_Z_MAX) - dataset = dataset[z_mask] - mags = mags[z_mask] - zout = zout[z_mask] + # z_mask = (zout["z_phot"] > FENIKS_Z_MIN) & (zout["z_phot"] <= FENIKS_Z_MAX) + # dataset = dataset[z_mask] + # mags = mags[z_mask] + # zout = zout[z_mask] # prepare 1D app mag functions in z-bins for fitting zbins = np.array( [ - [0.2, 0.5], - [0.5, 0.8], - [0.8, 1.2], - [1.2, 1.6], - [1.6, 2.0], + [0.2, 0.7], + [0.7, 1.5], + [1.5, 2.5], ] ) - n_gals, n_dim = mags.shape - n_bands = n_dim - 1 - n_z_bins = len(zbins) - magbin_zbins_bands = [] - N_zbins_bands = [] lc_data = [] - for zbin in range(n_z_bins): - z_min = zbins[zbin][0] - z_max = zbins[zbin][1] - z_phot_table = 10 ** jnp.linspace( - jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table - ) - lc_args = ( - ran_key, - num_halos, - z_min, - z_max, - lgmp_min, - lgmp_max, - lc_sky_area_degsq, - ssp_data, - tcurves, - z_phot_table, - ) + # Z1 --> get spaces: 2D (g-r, r-i), 1D (u-g | K) + zbin = 0 + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data.append(generate_lc_data(*lc_args)) - lc_data.append(generate_lc_data(*lc_args)) + z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) + Z1 = namedtuple( + "Z1", + [ + "z_min", + "z_max", + "gr_ri", + "ug", + ], + ) - z_sel = (mags[:, -1] > z_min) & (mags[:, -1] <= z_max) + Gr_ri = namedtuple("Gr_ri", ["N", "sig", "bin_lo", "bin_hi"]) + N_gr_ri, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri = get_N_2d( + hsc_gr[z_sel], hsc_ri[z_sel] + ) + gr_ri = Gr_ri(N_gr_ri, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri) - magbin_bands = [] - N_bands = [] - for band in range(0, n_bands): - magbin_edges, N_mags = get_N_1d_mag_bins( - mags[:, band][z_sel], mag_bin_edges=mag_bin_edges - ) - magbin_bands.append(magbin_edges) - N_bands.append(N_mags) + Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) + ug = [] + Ug_condK = namedtuple( + "Ug_condK", ["K_min", "K_max", "N", "sig", "bin_lo", "bin_hi"] + ) + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d, sig, bin_lo, bin_hi = get_N_1d(megacam_hsc_uSg[z_sel][K_sel]) + ug.append(Ug_condK(Kbins[k], Kbins[k + 1], N_1d, sig, bin_lo, bin_hi)) + + z1 = Z1(z_min, z_max, gr_ri, ug) + + # Z2 --> get spaces: 2D (r-z, z-J), 1D (u-g | K) + zbin = 1 + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data.append(generate_lc_data(*lc_args)) + + z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) + Z2 = namedtuple( + "Z2", + [ + "z_min", + "z_max", + "rz_zJ", + "ug", + ], + ) + + Rz_zJ = namedtuple("Rz_zJ", ["N", "sig", "bin_lo", "bin_hi"]) + N_rz_zJ, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ = get_N_2d( + hsc_rz[z_sel], hsc_uds_zJ[z_sel] + ) + rz_zJ = Rz_zJ(N_rz_zJ, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ) + + Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) + ug = [] + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d, sig, bin_lo, bin_hi = get_N_1d(megacam_hsc_uSg[z_sel][K_sel]) + ug.append(Ug_condK(Kbins[k], Kbins[k + 1], N_1d, sig, bin_lo, bin_hi)) + + z2 = Z2(z_min, z_max, rz_zJ, ug) + + # Z3 --> get spaces: 2D (z-J, J-H), 1D (u-g | K), 1D (g-r | K) + zbin = 2 + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data.append(generate_lc_data(*lc_args)) + + z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) + Z3 = namedtuple( + "Z3", + ["z_min", "z_max", "zJ_JH", "ug", "gr"], + ) + + zJ_JH = namedtuple("zJ_JH", ["N", "sig", "bin_lo", "bin_hi"]) + N_zJ_JH, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH = get_N_2d( + hsc_uds_zJ[z_sel], uds_JH[z_sel] + ) + zJ_JH = zJ_JH(N_zJ_JH, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH) + + Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 4) + ug = [] + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d, sig, bin_lo, bin_hi = get_N_1d(megacam_hsc_uSg[z_sel][K_sel]) + ug.append(Ug_condK(Kbins[k], Kbins[k + 1], N_1d, sig, bin_lo, bin_hi)) + + gr = [] + Gr_condK = namedtuple( + "Gr_condK", ["K_min", "K_max", "N", "sig", "bin_lo", "bin_hi"] + ) + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d, sig, bin_lo, bin_hi = get_N_1d(hsc_gr[z_sel][K_sel]) + gr.append(Gr_condK(Kbins[k], Kbins[k + 1], N_1d, sig, bin_lo, bin_hi)) - magbin_zbins_bands.append(magbin_bands) - N_zbins_bands.append(N_bands) + z3 = Z3(z_min, z_max, zJ_JH, ug, gr) lh_centroids, d_centroids = get_lh_centroids(dataset, lh_d_mag) @@ -448,9 +572,9 @@ def get_feniks_data( dataset_dim_labels, mags, mags_labels, - zbins, - magbin_zbins_bands, - N_zbins_bands, + z1, + z2, + z3, lc_data, filter_info, frac_cat, diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index 62ce707a..e54833ed 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -19,7 +19,7 @@ FENIKS_AREA_DEG2 = 2828.247933129912 / 3600 FENIKS_Z_MIN = 0.2 -FENIKS_Z_MAX = 3.0 +FENIKS_Z_MAX = 2.5 FENIKS_MAGK_THRESH = 24.3 # col mag SDSS_AREA_DEG2 = 7199 @@ -35,9 +35,9 @@ "dataset_dim_labels", "mags", "mags_labels", - "zbins", - "magbin_zbins_bands", - "N_zbins_bands", + "z1", + "z2", + "z3", "lc_data", "filter_info", "frac_cat", diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index e74b3081..0c9ffff2 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run133 -model_nickname: run133_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run133/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run134 +model_nickname: run134_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run134/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -13,7 +13,7 @@ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kr plot_sdss: False plot_feniks: True -plot_hizels: True +plot_hizels: False plots: num_halos : 3000 diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index 97122cda..6439d171 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -204,15 +204,8 @@ if cfg["plot_feniks"]: feniks_label = "feniks" # + cfg["model_nickname"].split("_")[0] feniks = load_feniks.get_feniks_data(feniks_drn, ran_key, ssp_data) - feniks_zbins = np.array( - [ - [0.2, 0.5], - [0.5, 0.8], - [0.8, 1.2], - [1.2, 1.6], - [1.6, 2.0], - ] - ) + feniks_zbins = feniks.zbins + if cfg["plots"]["plot_app_mag_funcs"]: print("Generating FENIKS app mag funcs plot...") plot_app_mag_funcs( From 4731c388871505ebf7e9915e3ed1497b0656903f Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 6 Jun 2026 18:07:53 -0500 Subject: [PATCH 04/57] fit_N_phot_2d --- .../experimental/data_loaders/load_feniks.py | 121 ++++++++++++------ diffhtwo/experimental/defaults.py | 1 - diffhtwo/experimental/kernels/N_phot.py | 69 ++++++++++ .../experimental/loss_kernels/phot_loss.py | 106 ++++++++++++++- .../optimizers/Np_specphot_opt.py | 48 ++++++- scripts/fit_feniks.py | 2 +- 6 files changed, 304 insertions(+), 43 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 41a3dba5..f9daf52a 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -397,8 +397,6 @@ def get_feniks_data( ] ) - lc_data = [] - # Z1 --> get spaces: 2D (g-r, r-i), 1D (u-g | K) zbin = 0 z_min = zbins[zbin][0] @@ -420,36 +418,46 @@ def get_feniks_data( z_phot_table, ) - lc_data.append(generate_lc_data(*lc_args)) + lc_data = generate_lc_data(*lc_args) z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) Z1 = namedtuple( "Z1", - [ - "z_min", - "z_max", - "gr_ri", - "ug", - ], + ["z_min", "z_max", "lc_data", "gr_ri", "ug"], ) - Gr_ri = namedtuple("Gr_ri", ["N", "sig", "bin_lo", "bin_hi"]) + Gr_ri = namedtuple("Gr_ri", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) N_gr_ri, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri = get_N_2d( hsc_gr[z_sel], hsc_ri[z_sel] ) - gr_ri = Gr_ri(N_gr_ri, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri) + col_idx = [1, 2, 3] + gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) ug = [] Ug_condK = namedtuple( - "Ug_condK", ["K_min", "K_max", "N", "sig", "bin_lo", "bin_hi"] + "Ug_condK", + ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], ) + col_idx = [0, 1] + cond_idx = 7 for k in range(len(Kbins) - 1): K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d, sig, bin_lo, bin_hi = get_N_1d(megacam_hsc_uSg[z_sel][K_sel]) - ug.append(Ug_condK(Kbins[k], Kbins[k + 1], N_1d, sig, bin_lo, bin_hi)) + N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d(megacam_hsc_uSg[z_sel][K_sel]) + ug.append( + Ug_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_ug, + bin_lo_ug, + bin_hi_ug, + N_1d_ug, + ) + ) - z1 = Z1(z_min, z_max, gr_ri, ug) + z1 = Z1(z_min, z_max, lc_data, gr_ri, ug) # Z2 --> get spaces: 2D (r-z, z-J), 1D (u-g | K) zbin = 1 @@ -472,33 +480,42 @@ def get_feniks_data( z_phot_table, ) - lc_data.append(generate_lc_data(*lc_args)) + lc_data = generate_lc_data(*lc_args) z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) Z2 = namedtuple( "Z2", - [ - "z_min", - "z_max", - "rz_zJ", - "ug", - ], + ["z_min", "z_max", "lc_data", "rz_zJ", "ug"], ) - Rz_zJ = namedtuple("Rz_zJ", ["N", "sig", "bin_lo", "bin_hi"]) + Rz_zJ = namedtuple("Rz_zJ", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) N_rz_zJ, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ = get_N_2d( hsc_rz[z_sel], hsc_uds_zJ[z_sel] ) - rz_zJ = Rz_zJ(N_rz_zJ, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ) + col_idx = [2, 4, 5] + rz_zJ = Rz_zJ(col_idx, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ, N_rz_zJ) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) ug = [] + col_idx = [0, 1] + cond_idx = 7 for k in range(len(Kbins) - 1): K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d, sig, bin_lo, bin_hi = get_N_1d(megacam_hsc_uSg[z_sel][K_sel]) - ug.append(Ug_condK(Kbins[k], Kbins[k + 1], N_1d, sig, bin_lo, bin_hi)) + N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d(megacam_hsc_uSg[z_sel][K_sel]) + ug.append( + Ug_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_ug, + bin_lo_ug, + bin_hi_ug, + N_1d_ug, + ) + ) - z2 = Z2(z_min, z_max, rz_zJ, ug) + z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug) # Z3 --> get spaces: 2D (z-J, J-H), 1D (u-g | K), 1D (g-r | K) zbin = 2 @@ -521,37 +538,65 @@ def get_feniks_data( z_phot_table, ) - lc_data.append(generate_lc_data(*lc_args)) + lc_data = generate_lc_data(*lc_args) z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) Z3 = namedtuple( "Z3", - ["z_min", "z_max", "zJ_JH", "ug", "gr"], + ["z_min", "z_max", "lc_data", "zJ_JH", "ug", "gr"], ) - zJ_JH = namedtuple("zJ_JH", ["N", "sig", "bin_lo", "bin_hi"]) + zJ_JH = namedtuple("zJ_JH", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) N_zJ_JH, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH = get_N_2d( hsc_uds_zJ[z_sel], uds_JH[z_sel] ) - zJ_JH = zJ_JH(N_zJ_JH, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH) + col_idx = [4, 5, 6] + zJ_JH = zJ_JH(col_idx, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH, N_zJ_JH) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 4) ug = [] + col_idx = [0, 1] + cond_idx = 7 for k in range(len(Kbins) - 1): K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d, sig, bin_lo, bin_hi = get_N_1d(megacam_hsc_uSg[z_sel][K_sel]) - ug.append(Ug_condK(Kbins[k], Kbins[k + 1], N_1d, sig, bin_lo, bin_hi)) + N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d(megacam_hsc_uSg[z_sel][K_sel]) + ug.append( + Ug_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_ug, + bin_lo_ug, + bin_hi_ug, + N_1d_ug, + ) + ) gr = [] Gr_condK = namedtuple( - "Gr_condK", ["K_min", "K_max", "N", "sig", "bin_lo", "bin_hi"] + "Gr_condK", + ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], ) + col_idx = [1, 2] + cond_idx = 7 for k in range(len(Kbins) - 1): K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d, sig, bin_lo, bin_hi = get_N_1d(hsc_gr[z_sel][K_sel]) - gr.append(Gr_condK(Kbins[k], Kbins[k + 1], N_1d, sig, bin_lo, bin_hi)) + N_1d_gr, sig_gr, bin_lo_gr, bin_hi_gr = get_N_1d(hsc_gr[z_sel][K_sel]) + gr.append( + Gr_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_gr, + bin_lo_gr, + bin_hi_gr, + N_1d_gr, + ) + ) - z3 = Z3(z_min, z_max, zJ_JH, ug, gr) + z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug, gr) lh_centroids, d_centroids = get_lh_centroids(dataset, lh_d_mag) @@ -575,7 +620,6 @@ def get_feniks_data( z1, z2, z3, - lc_data, filter_info, frac_cat, lh_centroids, @@ -595,7 +639,6 @@ def get_feniks_data( "HSC_R", "HSC_I", "HSC_Z", - # "VIDEO_Y", "UDS_J", "UDS_H", "UDS_K", diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index e54833ed..a37f147c 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -38,7 +38,6 @@ "z1", "z2", "z3", - "lc_data", "filter_info", "frac_cat", "lh_centroids", diff --git a/diffhtwo/experimental/kernels/N_phot.py b/diffhtwo/experimental/kernels/N_phot.py index 0bdab7f1..935a63c2 100644 --- a/diffhtwo/experimental/kernels/N_phot.py +++ b/diffhtwo/experimental/kernels/N_phot.py @@ -1,3 +1,4 @@ +from collections import namedtuple from functools import partial import jax.numpy as jnp @@ -8,6 +9,74 @@ from .phot_kern import get_colors_mags, mag_kern +@jjit +def N_colors_mags( + ran_key, + param_collection, + z_data, + mag_thresh, + frac_cat, +): + obs_mags, gal_weight, phot_kern_results = mag_kern( + ran_key, + param_collection, + z_data.lc_data, + mag_thresh, + frac_cat, + ) + fields = z_data._fields[3:] + for f in range(0, len(fields)): + data = getattr(z_data, fields[f]) + + if isinstance(data, list): + new_list = [] + for d in range(0, len(data)): + data_n = data[d] + + obs_mags_cond = obs_mags[:, data_n.cond_idx] + cond = (obs_mags_cond > data_n.K_min) & (obs_mags_cond <= data_n.K_max) + cond_weight = jnp.where(cond, 1.0, 0.0) + + obs_color = ( + obs_mags[:, data_n.col_idx[0]] - obs_mags[:, data_n.col_idx[1]] + ) + obs_color = obs_color.reshape(obs_color.size, 1) + + N_model = diffndhist_lomem.tw_ndhist_weighted( + obs_color, + data_n.sig, + gal_weight * cond_weight, + data_n.bin_lo, + data_n.bin_hi, + ) + + NewTuple = namedtuple( + type(data_n).__name__, [*data_n._fields, "N_model"] + ) + new_list.append(NewTuple(*data_n, N_model)) + z_data = z_data._replace(**{fields[f]: new_list}) + else: + col_idx = data.col_idx + obs_colors = [] + for c in range(0, len(col_idx) - 1): + obs_color = obs_mags[:, col_idx[c]] - obs_mags[:, col_idx[c + 1]] + obs_colors.append(obs_color) + obs_colors = jnp.array(obs_colors).T + N_model = diffndhist_lomem.tw_ndhist_weighted( + obs_colors, + data.sig, + gal_weight, + data.bin_lo, + data.bin_hi, + ) + + NewTuple = namedtuple(type(data).__name__, [*data._fields, "N_model"]) + new = NewTuple(*data, N_model) + z_data = z_data._replace(**{fields[f]: new}) + + return z_data + + @jjit def N_mags_1d( ran_key, diff --git a/diffhtwo/experimental/loss_kernels/phot_loss.py b/diffhtwo/experimental/loss_kernels/phot_loss.py index 17b74c32..165ba9ab 100644 --- a/diffhtwo/experimental/loss_kernels/phot_loss.py +++ b/diffhtwo/experimental/loss_kernels/phot_loss.py @@ -1,10 +1,114 @@ from jax import jit as jjit -from ..kernels.N_phot import N_colors_mags_lh, N_mags_1d +from ..kernels.N_phot import N_colors_mags, N_colors_mags_lh, N_mags_1d from ..param_utils import get_param_collection_from_u_theta from .loss_functions import poisson_loss +@jjit +def get_phot_loss_2d( + ran_key, + param_collection, + z_data, + mag_thresh, + frac_cat, + data_sky_area_degsq, +): + z_data = N_colors_mags( + ran_key, + param_collection, + z_data, + mag_thresh, + frac_cat, + ) + phot_loss_2d = 0.0 + fields = z_data._fields[3:] + for f in range(0, len(fields)): + data = getattr(z_data, fields[f]) + + if isinstance(data, list): + for d in range(0, len(data)): + data_n = data[d] + + N_model = data_n.N_model + N_data = data_n.N_data + + N_model = N_model * ( + data_sky_area_degsq / z_data.lc_data.sky_area_degsq + ) + phot_loss_2d += poisson_loss(N_model, N_data) + + else: + N_model = data.N_model + N_data = data.N_data + + N_model = N_model * (data_sky_area_degsq / z_data.lc_data.sky_area_degsq) + phot_loss_2d += poisson_loss(N_model, N_data) + + return phot_loss_2d + + +@jjit +def _loss_phot_kern_2d( + u_theta, + ran_key, + z_data, + mag_thresh, + frac_cat, + data_sky_area_degsq, +): + param_collection = get_param_collection_from_u_theta(u_theta) + + phot_loss_2d = get_phot_loss_2d( + ran_key, + param_collection, + z_data, + mag_thresh, + frac_cat, + data_sky_area_degsq, + ) + + return phot_loss_2d + + +@jjit +def _loss_phot_kern_2d_multiz( + u_theta, + ran_key, + fitting_data, +): + param_collection = get_param_collection_from_u_theta(u_theta) + + phot_loss_2d = 0.0 + + phot_loss_2d += get_phot_loss_2d( + ran_key, + param_collection, + fitting_data.z1, + fitting_data.filter_info.mag_thresh, + fitting_data.frac_cat, + fitting_data.data_sky_area_degsq, + ) + phot_loss_2d += get_phot_loss_2d( + ran_key, + param_collection, + fitting_data.z2, + fitting_data.filter_info.mag_thresh, + fitting_data.frac_cat, + fitting_data.data_sky_area_degsq, + ) + phot_loss_2d += get_phot_loss_2d( + ran_key, + param_collection, + fitting_data.z3, + fitting_data.filter_info.mag_thresh, + fitting_data.frac_cat, + fitting_data.data_sky_area_degsq, + ) + + return phot_loss_2d + + @jjit def get_phot_loss_1d( ran_key, diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 28cb2993..2ac416dd 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -14,7 +14,11 @@ from jax.example_libraries import optimizers as jax_opt from ..loss_kernels.emline_loss import _loss_emline_kern_multi_line_multi_z -from ..loss_kernels.phot_loss import _loss_phot_kern, _loss_phot_kern_multiband_multiz +from ..loss_kernels.phot_loss import ( + _loss_phot_kern, + _loss_phot_kern_2d_multiz, + _loss_phot_kern_multiband_multiz, +) _L_pk = ( None, @@ -37,6 +41,8 @@ value_and_grad(_loss_phot_kern_multiband_multiz) ) +_loss_and_grad_phot_kern_2d_multiz = jjit(value_and_grad(_loss_phot_kern_2d_multiz)) + @partial(jjit, static_argnames=["n_steps", "step_size"]) def fit_N_multi_z( @@ -80,6 +86,46 @@ def _opt_update(opt_state, i): return loss_hist, u_theta_fit +@partial(jjit, static_argnames=["n_steps", "step_size"]) +def fit_N_phot_2d( + u_theta_init, + trainable, + ran_key, + fitting_data, + n_steps=2, + step_size=1e-2, +): + opt_init, opt_update, get_params = jax_opt.adam(step_size) + opt_state = opt_init(u_theta_init) + + other = ( + ran_key, + fitting_data, + ) + + def _opt_update(opt_state, i): + u_theta = get_params(opt_state) + loss, grads = _loss_and_grad_phot_kern_2d_multiz(u_theta, *other) + # set grads for untrainable params to 0.0 + grads = tuple( + jnp.where(train, grad, 0.0) for grad, train in zip(grads, trainable) + ) + + # clip gradients + # global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) + # tau = 1.0 + # scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + # grads = tuple(g * scale for g in grads) + + opt_state = opt_update(i, grads, opt_state) + return opt_state, loss + + opt_state, loss_hist = lax.scan(_opt_update, opt_state, jnp.arange(n_steps)) + u_theta_fit = get_params(opt_state) + + return loss_hist, u_theta_fit + + @partial(jjit, static_argnames=["n_steps", "step_size"]) def fit_N_phot_1d( u_theta_init, diff --git a/scripts/fit_feniks.py b/scripts/fit_feniks.py index 66f24a2c..d57909a3 100644 --- a/scripts/fit_feniks.py +++ b/scripts/fit_feniks.py @@ -104,7 +104,7 @@ for epoch in range(0, cfg["epoch"]["n_it"]): print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') - loss_hist, u_theta_fit = Np_specphot_opt.fit_N_phot_1d( + loss_hist, u_theta_fit = Np_specphot_opt.fit_N_phot_2d( u_theta_fit, trainable_params, ran_key, From debd9ef06b16592b4790a308be48c877d535631e Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 6 Jun 2026 18:49:27 -0500 Subject: [PATCH 05/57] add K band app mag func to fitting --- .../experimental/data_loaders/load_feniks.py | 30 ++++++++++++++----- diffhtwo/experimental/kernels/N_phot.py | 18 +++++++++++ scripts/config_diagnostics.yaml | 6 ++-- scripts/generate_diagnostic_plots.py | 8 ++++- 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index f9daf52a..0a8c35b2 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -397,7 +397,7 @@ def get_feniks_data( ] ) - # Z1 --> get spaces: 2D (g-r, r-i), 1D (u-g | K) + # Z1 --> get spaces: 2D (g-r, r-i), 1D (u-g | K), 1D (K) zbin = 0 z_min = zbins[zbin][0] z_max = zbins[zbin][1] @@ -423,7 +423,7 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) Z1 = namedtuple( "Z1", - ["z_min", "z_max", "lc_data", "gr_ri", "ug"], + ["z_min", "z_max", "lc_data", "gr_ri", "ug", "k"], ) Gr_ri = namedtuple("Gr_ri", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) @@ -457,7 +457,15 @@ def get_feniks_data( ) ) - z1 = Z1(z_min, z_max, lc_data, gr_ri, ug) + K = namedtuple( + "K", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ) + mag_idx = 7 + N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) + k = K(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) + + z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, k) # Z2 --> get spaces: 2D (r-z, z-J), 1D (u-g | K) zbin = 1 @@ -485,7 +493,7 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) Z2 = namedtuple( "Z2", - ["z_min", "z_max", "lc_data", "rz_zJ", "ug"], + ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "k"], ) Rz_zJ = namedtuple("Rz_zJ", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) @@ -515,7 +523,11 @@ def get_feniks_data( ) ) - z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug) + mag_idx = 7 + N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) + k = K(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) + + z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, k) # Z3 --> get spaces: 2D (z-J, J-H), 1D (u-g | K), 1D (g-r | K) zbin = 2 @@ -543,7 +555,7 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) Z3 = namedtuple( "Z3", - ["z_min", "z_max", "lc_data", "zJ_JH", "ug", "gr"], + ["z_min", "z_max", "lc_data", "zJ_JH", "ug", "gr", "k"], ) zJ_JH = namedtuple("zJ_JH", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) @@ -596,7 +608,11 @@ def get_feniks_data( ) ) - z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug, gr) + mag_idx = 7 + N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) + k = K(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) + + z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug, gr, k) lh_centroids, d_centroids = get_lh_centroids(dataset, lh_d_mag) diff --git a/diffhtwo/experimental/kernels/N_phot.py b/diffhtwo/experimental/kernels/N_phot.py index 935a63c2..66a5389c 100644 --- a/diffhtwo/experimental/kernels/N_phot.py +++ b/diffhtwo/experimental/kernels/N_phot.py @@ -55,6 +55,24 @@ def N_colors_mags( ) new_list.append(NewTuple(*data_n, N_model)) z_data = z_data._replace(**{fields[f]: new_list}) + + elif "mag_idx" in data._fields: + mag_idx = data.mag_idx + obs_mag = obs_mags[:, mag_idx] + obs_mag = obs_mag.reshape(obs_mag.size, 1) + + N_model = diffndhist_lomem.tw_ndhist_weighted( + obs_mag, + data.sig, + gal_weight, + data.bin_lo, + data.bin_hi, + ) + + NewTuple = namedtuple(type(data).__name__, [*data._fields, "N_model"]) + new = NewTuple(*data, N_model) + z_data = z_data._replace(**{fields[f]: new}) + else: col_idx = data.col_idx obs_colors = [] diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 0c9ffff2..3e363cee 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run134 -model_nickname: run134_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run134/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run135 +model_nickname: run135_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run135/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index 6439d171..0cd5c9e6 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -204,7 +204,13 @@ if cfg["plot_feniks"]: feniks_label = "feniks" # + cfg["model_nickname"].split("_")[0] feniks = load_feniks.get_feniks_data(feniks_drn, ran_key, ssp_data) - feniks_zbins = feniks.zbins + feniks_zbins = np.array( + [ + (feniks.z1.z_min, feniks.z1.z_max), + (feniks.z2.z_min, feniks.z2.z_max), + (feniks.z3.z_min, feniks.z3.z_max), + ] + ) if cfg["plots"]["plot_app_mag_funcs"]: print("Generating FENIKS app mag funcs plot...") From 81e3fb27a2e83291401e0641528d64a948a5429b Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 6 Jun 2026 20:14:15 -0500 Subject: [PATCH 06/57] adjust mag weights --- .../experimental/data_loaders/load_feniks.py | 66 ++++++++++++++----- diffhtwo/experimental/kernels/N_phot.py | 31 +++++++-- scripts/config_diagnostics.yaml | 6 +- 3 files changed, 75 insertions(+), 28 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 0a8c35b2..0d7e171f 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -256,17 +256,7 @@ def get_feniks_data( filter_info = FilterInfo(feniks_mag_thresh, feniks_in_lh, tcurves) # get mag thresh cuts - mag_thresh = ( - (megacam_uS < feniks_mag_thresh.MegaCam_uS) - & (hsc_g < feniks_mag_thresh.HSC_G) - & (hsc_r < feniks_mag_thresh.HSC_R) - & (hsc_i < feniks_mag_thresh.HSC_I) - & (hsc_z < feniks_mag_thresh.HSC_Z) - # & (video_Y < feniks_mag_thresh.VIDEO_Y) - & (uds_J < feniks_mag_thresh.UDS_J) - & (uds_H < feniks_mag_thresh.UDS_H) - & (uds_K < feniks_mag_thresh.UDS_K) - ) + mag_thresh = uds_K < feniks_mag_thresh.UDS_K # apply mag_thresh cuts and record n_gals. # This is the starting point from which any further cuts will @@ -397,6 +387,7 @@ def get_feniks_data( ] ) + ############################################################################## # Z1 --> get spaces: 2D (g-r, r-i), 1D (u-g | K), 1D (K) zbin = 0 z_min = zbins[zbin][0] @@ -427,8 +418,13 @@ def get_feniks_data( ) Gr_ri = namedtuple("Gr_ri", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + mag_sel_gr_ri = ( + (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) + & (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) + & (hsc_i[z_sel] < feniks_mag_thresh.HSC_I) + ) N_gr_ri, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri = get_N_2d( - hsc_gr[z_sel], hsc_ri[z_sel] + hsc_gr[z_sel][mag_sel_gr_ri], hsc_ri[z_sel][mag_sel_gr_ri] ) col_idx = [1, 2, 3] gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri) @@ -439,11 +435,16 @@ def get_feniks_data( "Ug_condK", ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], ) + mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( + hsc_g[z_sel] < feniks_mag_thresh.HSC_G + ) col_idx = [0, 1] cond_idx = 7 for k in range(len(Kbins) - 1): K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d(megacam_hsc_uSg[z_sel][K_sel]) + N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d( + megacam_hsc_uSg[z_sel][mag_sel_ug & K_sel] + ) ug.append( Ug_condK( col_idx, @@ -467,6 +468,7 @@ def get_feniks_data( z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, k) + ############################################################################## # Z2 --> get spaces: 2D (r-z, z-J), 1D (u-g | K) zbin = 1 z_min = zbins[zbin][0] @@ -497,19 +499,29 @@ def get_feniks_data( ) Rz_zJ = namedtuple("Rz_zJ", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + mag_sel_rz_zJ = ( + (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) + & (hsc_z[z_sel] < feniks_mag_thresh.HSC_Z) + & (uds_J[z_sel] < feniks_mag_thresh.UDS_J) + ) N_rz_zJ, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ = get_N_2d( - hsc_rz[z_sel], hsc_uds_zJ[z_sel] + hsc_rz[z_sel][mag_sel_rz_zJ], hsc_uds_zJ[z_sel][mag_sel_rz_zJ] ) col_idx = [2, 4, 5] rz_zJ = Rz_zJ(col_idx, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ, N_rz_zJ) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) ug = [] + mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( + hsc_g[z_sel] < feniks_mag_thresh.HSC_G + ) col_idx = [0, 1] cond_idx = 7 for k in range(len(Kbins) - 1): K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d(megacam_hsc_uSg[z_sel][K_sel]) + N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d( + megacam_hsc_uSg[z_sel][mag_sel_ug & K_sel] + ) ug.append( Ug_condK( col_idx, @@ -529,6 +541,7 @@ def get_feniks_data( z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, k) + ############################################################################## # Z3 --> get spaces: 2D (z-J, J-H), 1D (u-g | K), 1D (g-r | K) zbin = 2 z_min = zbins[zbin][0] @@ -559,19 +572,29 @@ def get_feniks_data( ) zJ_JH = namedtuple("zJ_JH", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + mag_sel_zJ_JH = ( + (hsc_z[z_sel] < feniks_mag_thresh.HSC_Z) + & (uds_J[z_sel] < feniks_mag_thresh.UDS_J) + & (uds_H[z_sel] < feniks_mag_thresh.UDS_H) + ) N_zJ_JH, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH = get_N_2d( - hsc_uds_zJ[z_sel], uds_JH[z_sel] + hsc_uds_zJ[z_sel][mag_sel_zJ_JH], uds_JH[z_sel][mag_sel_zJ_JH] ) col_idx = [4, 5, 6] zJ_JH = zJ_JH(col_idx, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH, N_zJ_JH) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 4) ug = [] + mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( + hsc_g[z_sel] < feniks_mag_thresh.HSC_G + ) col_idx = [0, 1] cond_idx = 7 for k in range(len(Kbins) - 1): K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d(megacam_hsc_uSg[z_sel][K_sel]) + N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d( + megacam_hsc_uSg[z_sel][mag_sel_ug & K_sel] + ) ug.append( Ug_condK( col_idx, @@ -590,11 +613,16 @@ def get_feniks_data( "Gr_condK", ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], ) + mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( + hsc_r[z_sel] < feniks_mag_thresh.HSC_R + ) col_idx = [1, 2] cond_idx = 7 for k in range(len(Kbins) - 1): K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_gr, sig_gr, bin_lo_gr, bin_hi_gr = get_N_1d(hsc_gr[z_sel][K_sel]) + N_1d_gr, sig_gr, bin_lo_gr, bin_hi_gr = get_N_1d( + hsc_gr[z_sel][mag_sel_gr & K_sel] + ) gr.append( Gr_condK( col_idx, @@ -614,6 +642,8 @@ def get_feniks_data( z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug, gr, k) + ############################################################################## + lh_centroids, d_centroids = get_lh_centroids(dataset, lh_d_mag) # run initial diffndhist_lomem with fixed dmag diff --git a/diffhtwo/experimental/kernels/N_phot.py b/diffhtwo/experimental/kernels/N_phot.py index 66a5389c..d54ee1df 100644 --- a/diffhtwo/experimental/kernels/N_phot.py +++ b/diffhtwo/experimental/kernels/N_phot.py @@ -25,6 +25,7 @@ def N_colors_mags( frac_cat, ) fields = z_data._fields[3:] + mag_thresh = jnp.array(mag_thresh) for f in range(0, len(fields)): data = getattr(z_data, fields[f]) @@ -32,20 +33,25 @@ def N_colors_mags( new_list = [] for d in range(0, len(data)): data_n = data[d] + col_idx = data_n.col_idx + # get cond weight obs_mags_cond = obs_mags[:, data_n.cond_idx] cond = (obs_mags_cond > data_n.K_min) & (obs_mags_cond <= data_n.K_max) - cond_weight = jnp.where(cond, 1.0, 0.0) + weight = jnp.where(cond, gal_weight, 0.0) - obs_color = ( - obs_mags[:, data_n.col_idx[0]] - obs_mags[:, data_n.col_idx[1]] - ) + # get mag_sel weight + for c in range(0, len(col_idx)): + mag_sel = obs_mags[:, col_idx[c]] < mag_thresh[col_idx[c]] + weight *= jnp.where(mag_sel, 1.0, 0.0) + + obs_color = obs_mags[:, col_idx[0]] - obs_mags[:, col_idx[1]] obs_color = obs_color.reshape(obs_color.size, 1) N_model = diffndhist_lomem.tw_ndhist_weighted( obs_color, data_n.sig, - gal_weight * cond_weight, + weight, data_n.bin_lo, data_n.bin_hi, ) @@ -61,10 +67,14 @@ def N_colors_mags( obs_mag = obs_mags[:, mag_idx] obs_mag = obs_mag.reshape(obs_mag.size, 1) + # get mag_sel weight + mag_sel = obs_mags[:, mag_idx] < mag_thresh[mag_idx] + weight = jnp.where(mag_sel, gal_weight, 0.0) + N_model = diffndhist_lomem.tw_ndhist_weighted( obs_mag, data.sig, - gal_weight, + weight, data.bin_lo, data.bin_hi, ) @@ -80,10 +90,17 @@ def N_colors_mags( obs_color = obs_mags[:, col_idx[c]] - obs_mags[:, col_idx[c + 1]] obs_colors.append(obs_color) obs_colors = jnp.array(obs_colors).T + + # get mag_sel weight + weight = gal_weight.copy() + for c in range(0, len(col_idx)): + mag_sel = obs_mags[:, col_idx[c]] < mag_thresh[col_idx[c]] + weight *= jnp.where(mag_sel, 1.0, 0.0) + N_model = diffndhist_lomem.tw_ndhist_weighted( obs_colors, data.sig, - gal_weight, + weight, data.bin_lo, data.bin_hi, ) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 3e363cee..20fe5689 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run135 -model_nickname: run135_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run135/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run136 +model_nickname: run136_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run136/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks From 7543846c3c822dcf1fb2adcf8cfea93705cbd502 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 6 Jun 2026 20:59:50 -0500 Subject: [PATCH 07/57] mag_thresh --- diffhtwo/experimental/data_loaders/load_feniks.py | 11 ++++++++++- scripts/config_diagnostics.yaml | 6 +++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 0d7e171f..2cc6ebd4 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -256,7 +256,16 @@ def get_feniks_data( filter_info = FilterInfo(feniks_mag_thresh, feniks_in_lh, tcurves) # get mag thresh cuts - mag_thresh = uds_K < feniks_mag_thresh.UDS_K + mag_thresh = ( + (megacam_uS < feniks_mag_thresh.MegaCam_uS) + & (hsc_g < feniks_mag_thresh.HSC_G) + & (hsc_r < feniks_mag_thresh.HSC_R) + & (hsc_i < feniks_mag_thresh.HSC_I) + & (hsc_z < feniks_mag_thresh.HSC_Z) + & (uds_J < feniks_mag_thresh.UDS_J) + & (uds_H < feniks_mag_thresh.UDS_H) + & (uds_K < feniks_mag_thresh.UDS_K) + ) # apply mag_thresh cuts and record n_gals. # This is the starting point from which any further cuts will diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 20fe5689..6930b3f8 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run136 -model_nickname: run136_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run136/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run137 +model_nickname: run137_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run137/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks From 9b121d40759c1ba9bee0c36e0daccfd6987b2879 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 6 Jun 2026 21:21:16 -0500 Subject: [PATCH 08/57] add u-band app mag func to loss --- .../experimental/data_loaders/load_feniks.py | 36 +++++++++++++++---- scripts/config_diagnostics.yaml | 6 ++-- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 2cc6ebd4..60f54fea 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -423,7 +423,7 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) Z1 = namedtuple( "Z1", - ["z_min", "z_max", "lc_data", "gr_ri", "ug", "k"], + ["z_min", "z_max", "lc_data", "gr_ri", "ug", "k", "u"], ) Gr_ri = namedtuple("Gr_ri", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) @@ -475,7 +475,15 @@ def get_feniks_data( N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) k = K(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) - z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, k) + U = namedtuple( + "U", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ) + mag_idx = 0 + N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(megacam_uS[z_sel]) + u = U(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) + + z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, k, u) ############################################################################## # Z2 --> get spaces: 2D (r-z, z-J), 1D (u-g | K) @@ -504,7 +512,7 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) Z2 = namedtuple( "Z2", - ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "k"], + ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "k", "u"], ) Rz_zJ = namedtuple("Rz_zJ", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) @@ -548,7 +556,15 @@ def get_feniks_data( N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) k = K(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) - z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, k) + U = namedtuple( + "U", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ) + mag_idx = 0 + N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(megacam_uS[z_sel]) + u = U(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) + + z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, k, u) ############################################################################## # Z3 --> get spaces: 2D (z-J, J-H), 1D (u-g | K), 1D (g-r | K) @@ -577,7 +593,7 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) Z3 = namedtuple( "Z3", - ["z_min", "z_max", "lc_data", "zJ_JH", "ug", "gr", "k"], + ["z_min", "z_max", "lc_data", "zJ_JH", "ug", "gr", "k", "u"], ) zJ_JH = namedtuple("zJ_JH", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) @@ -649,7 +665,15 @@ def get_feniks_data( N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) k = K(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) - z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug, gr, k) + U = namedtuple( + "U", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ) + mag_idx = 0 + N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(megacam_uS[z_sel]) + u = U(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) + + z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug, gr, k, u) ############################################################################## diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 6930b3f8..7430ef5e 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run137 -model_nickname: run137_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run137/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run138 +model_nickname: run138_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run138/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks From 23d6b51adeaccb79d57b615438c6b483fd6cf5f6 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 7 Jun 2026 13:27:18 -0500 Subject: [PATCH 09/57] fit 2D colors spaces in coarse z-bins and 1D app mag funcs in fine z-bins --- .../experimental/data_loaders/load_feniks.py | 198 ++++++++++++------ diffhtwo/experimental/defaults.py | 5 +- .../experimental/diagnostics/plot_phot.py | 12 +- diffhtwo/experimental/kernels/N_phot.py | 52 ++--- .../experimental/loss_kernels/phot_loss.py | 120 +++++------ scripts/config_diagnostics.yaml | 6 +- scripts/config_feniks.yaml | 3 +- scripts/fit_feniks.py | 6 +- 8 files changed, 242 insertions(+), 160 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 60f54fea..9e871831 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -201,7 +201,8 @@ def get_feniks_data( lh_d_mag=0.6, phot=PHOT, zout=ZOUT, - num_halos=250, + num_halos_coarse_zbins=150, + num_halos_fine_zbins=250, lgmp_min=10.0, lgmp_max=15.0, lc_sky_area_degsq=100, @@ -334,7 +335,7 @@ def get_feniks_data( hsc_ri = hsc_r - hsc_i hsc_iz = hsc_i - hsc_z hsc_uds_zJ = hsc_z - uds_J - # video_uds_YJ = video_Y - uds_J + hsc_uds_rK = hsc_r - uds_K uds_JH = uds_J - uds_H uds_HK = uds_H - uds_K @@ -386,8 +387,8 @@ def get_feniks_data( # mags = mags[z_mask] # zout = zout[z_mask] - # prepare 1D app mag functions in z-bins for fitting - + ############################################################################## + # prepare 2D and 1D color spaces in coarse z-bins for fitting zbins = np.array( [ [0.2, 0.7], @@ -397,7 +398,16 @@ def get_feniks_data( ) ############################################################################## - # Z1 --> get spaces: 2D (g-r, r-i), 1D (u-g | K), 1D (K) + # Z1 spaces: + # 2D (g-r, r-i) + # 2D (u-g, r-K) + # 1D (u-g | K) + + colors = [] + Z1 = namedtuple( + "Z1", + ["z_min", "z_max", "lc_data", "gr_ri", "ug_rK", "ug"], + ) zbin = 0 z_min = zbins[zbin][0] z_max = zbins[zbin][1] @@ -407,7 +417,7 @@ def get_feniks_data( ) lc_args = ( ran_key, - num_halos, + num_halos_coarse_zbins, z_min, z_max, lgmp_min, @@ -421,11 +431,8 @@ def get_feniks_data( lc_data = generate_lc_data(*lc_args) z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) - Z1 = namedtuple( - "Z1", - ["z_min", "z_max", "lc_data", "gr_ri", "ug", "k", "u"], - ) + # 2D (g - r, r - i) Gr_ri = namedtuple("Gr_ri", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) mag_sel_gr_ri = ( (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) @@ -438,7 +445,22 @@ def get_feniks_data( col_idx = [1, 2, 3] gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri) + # 2D (u - g, r - K) + Ug_rK = namedtuple("Ug_rK", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + mag_sel_ugr = ( + (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) + & (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) + & (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) + ) + N_ug_rK, sig_ug_rK, bin_lo_ug_rK, bin_hi_ug_rK = get_N_2d( + megacam_hsc_uSg[z_sel][mag_sel_ugr], hsc_uds_rK[z_sel][mag_sel_ugr] + ) + col_idx = [0, 1, 7] + ug_rK = Ug_rK(col_idx, sig_ug_rK, bin_lo_ug_rK, bin_hi_ug_rK, N_ug_rK) + + # 1D (u - g | K) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) + ug = [] Ug_condK = namedtuple( "Ug_condK", @@ -467,26 +489,18 @@ def get_feniks_data( ) ) - K = namedtuple( - "K", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], - ) - mag_idx = 7 - N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) - k = K(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) - - U = namedtuple( - "U", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], - ) - mag_idx = 0 - N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(megacam_uS[z_sel]) - u = U(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) - - z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, k, u) + z1 = Z1(z_min, z_max, lc_data, gr_ri, ug_rK, ug) + colors.append(z1) ############################################################################## - # Z2 --> get spaces: 2D (r-z, z-J), 1D (u-g | K) + # Z2 spaces: + # 2D (r - z, z - J) + # 1D (u - g | K) + + Z2 = namedtuple( + "Z2", + ["z_min", "z_max", "lc_data", "rz_zJ", "ug"], + ) zbin = 1 z_min = zbins[zbin][0] z_max = zbins[zbin][1] @@ -496,7 +510,7 @@ def get_feniks_data( ) lc_args = ( ran_key, - num_halos, + num_halos_coarse_zbins, z_min, z_max, lgmp_min, @@ -510,11 +524,8 @@ def get_feniks_data( lc_data = generate_lc_data(*lc_args) z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) - Z2 = namedtuple( - "Z2", - ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "k", "u"], - ) + # 2D (r - z, z - J) Rz_zJ = namedtuple("Rz_zJ", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) mag_sel_rz_zJ = ( (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) @@ -527,7 +538,9 @@ def get_feniks_data( col_idx = [2, 4, 5] rz_zJ = Rz_zJ(col_idx, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ, N_rz_zJ) + # 1D (u - g | K) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) + ug = [] mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( hsc_g[z_sel] < feniks_mag_thresh.HSC_G @@ -552,22 +565,20 @@ def get_feniks_data( ) ) - mag_idx = 7 - N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) - k = K(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) - - U = namedtuple( - "U", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], - ) - mag_idx = 0 - N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(megacam_uS[z_sel]) - u = U(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) - - z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, k, u) + z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug) + colors.append(z2) ############################################################################## - # Z3 --> get spaces: 2D (z-J, J-H), 1D (u-g | K), 1D (g-r | K) + # Z3 spaces: + # 2D (z - J, J - H) + # 2D (u - g, g - r) + # 1D (u - g | K) + # 1D (g - r | K) + + Z3 = namedtuple( + "Z3", + ["z_min", "z_max", "lc_data", "zJ_JH", "ug_gr", "ug", "gr"], + ) zbin = 2 z_min = zbins[zbin][0] z_max = zbins[zbin][1] @@ -577,7 +588,7 @@ def get_feniks_data( ) lc_args = ( ran_key, - num_halos, + num_halos_coarse_zbins, z_min, z_max, lgmp_min, @@ -591,11 +602,8 @@ def get_feniks_data( lc_data = generate_lc_data(*lc_args) z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) - Z3 = namedtuple( - "Z3", - ["z_min", "z_max", "lc_data", "zJ_JH", "ug", "gr", "k", "u"], - ) + # 2D (z - J, J - H) zJ_JH = namedtuple("zJ_JH", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) mag_sel_zJ_JH = ( (hsc_z[z_sel] < feniks_mag_thresh.HSC_Z) @@ -608,7 +616,22 @@ def get_feniks_data( col_idx = [4, 5, 6] zJ_JH = zJ_JH(col_idx, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH, N_zJ_JH) + # 2D (u - g, g - r) + Ug_gr = namedtuple("Ug_gr", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + mag_sel_ugr = ( + (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) + & (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) + & (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) + ) + N_ug_gr, sig_ug_gr, bin_lo_ug_gr, bin_hi_ug_gr = get_N_2d( + megacam_hsc_uSg[z_sel][mag_sel_ugr], hsc_gr[z_sel][mag_sel_ugr] + ) + col_idx = [0, 1, 2] + ug_gr = Ug_gr(col_idx, sig_ug_gr, bin_lo_ug_gr, bin_hi_ug_gr, N_ug_gr) + + # 1D (u - g | K) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 4) + ug = [] mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( hsc_g[z_sel] < feniks_mag_thresh.HSC_G @@ -633,6 +656,7 @@ def get_feniks_data( ) ) + # 1D (g - r | K) gr = [] Gr_condK = namedtuple( "Gr_condK", @@ -661,19 +685,70 @@ def get_feniks_data( ) ) - mag_idx = 7 - N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) - k = K(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) + z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr) + colors.append(z3) + ############################################################################## + # prepare 1D app mag funcs in finer z-bins for fitting + fine_zbins = np.array( + [ + [0.2, 0.5], + [0.5, 0.7], + [0.7, 1.0], + [1.0, 1.5], + [1.5, 2.0], + [2.0, 2.5], + ] + ) + ############################################################################## + AppMagFuncs = namedtuple( + "AppMagFuncs", + ["z_min", "z_max", "lc_data", "k", "u"], + ) + K = namedtuple( + "K", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ) U = namedtuple( "U", ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], ) - mag_idx = 0 - N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(megacam_uS[z_sel]) - u = U(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) + app_mag_funcs = [] + for zbin in range(0, len(fine_zbins)): + z_min = fine_zbins[zbin][0] + z_max = fine_zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos_fine_zbins, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data = generate_lc_data(*lc_args) + + z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) + + # 1D (K) + mag_idx = 7 + N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) + k = K(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) + + # 1D (u) + mag_idx = 0 + N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(megacam_uS[z_sel]) + u = U(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) - z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug, gr, k, u) + app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, k, u)) ############################################################################## @@ -696,9 +771,8 @@ def get_feniks_data( dataset_dim_labels, mags, mags_labels, - z1, - z2, - z3, + colors, + app_mag_funcs, filter_info, frac_cat, lh_centroids, diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index a37f147c..4b78e6b2 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -35,9 +35,8 @@ "dataset_dim_labels", "mags", "mags_labels", - "z1", - "z2", - "z3", + "colors", + "app_mag_funcs", "filter_info", "frac_cat", "lh_centroids", diff --git a/diffhtwo/experimental/diagnostics/plot_phot.py b/diffhtwo/experimental/diagnostics/plot_phot.py index 0baed1fa..496ed08d 100644 --- a/diffhtwo/experimental/diagnostics/plot_phot.py +++ b/diffhtwo/experimental/diagnostics/plot_phot.py @@ -713,22 +713,24 @@ def plot_app_mag_funcs( ) ax[row, col].set_xticks(np.arange(15, 30, 2)) + ax[row, col].minorticks_on() ax[row, col].tick_params( which="major", - length=3, - width=1.5, direction="in", top=True, right=True, - labelsize=labelsize, + length=6, + width=1, + labelsize=10, ) ax[row, col].tick_params( which="minor", - length=1.5, - width=1.5, direction="in", top=True, right=True, + length=3, + width=0.8, + labelsize=10, ) ax[row, col].set_ylim(-6.9, -2.5) diff --git a/diffhtwo/experimental/kernels/N_phot.py b/diffhtwo/experimental/kernels/N_phot.py index d54ee1df..bc0b05bf 100644 --- a/diffhtwo/experimental/kernels/N_phot.py +++ b/diffhtwo/experimental/kernels/N_phot.py @@ -27,17 +27,19 @@ def N_colors_mags( fields = z_data._fields[3:] mag_thresh = jnp.array(mag_thresh) for f in range(0, len(fields)): - data = getattr(z_data, fields[f]) + space = getattr(z_data, fields[f]) - if isinstance(data, list): + if isinstance(space, list): new_list = [] - for d in range(0, len(data)): - data_n = data[d] - col_idx = data_n.col_idx + for s in range(0, len(space)): + space_n = space[s] + col_idx = space_n.col_idx # get cond weight - obs_mags_cond = obs_mags[:, data_n.cond_idx] - cond = (obs_mags_cond > data_n.K_min) & (obs_mags_cond <= data_n.K_max) + obs_mags_cond = obs_mags[:, space_n.cond_idx] + cond = (obs_mags_cond > space_n.K_min) & ( + obs_mags_cond <= space_n.K_max + ) weight = jnp.where(cond, gal_weight, 0.0) # get mag_sel weight @@ -50,20 +52,20 @@ def N_colors_mags( N_model = diffndhist_lomem.tw_ndhist_weighted( obs_color, - data_n.sig, + space_n.sig, weight, - data_n.bin_lo, - data_n.bin_hi, + space_n.bin_lo, + space_n.bin_hi, ) NewTuple = namedtuple( - type(data_n).__name__, [*data_n._fields, "N_model"] + type(space_n).__name__, [*space_n._fields, "N_model"] ) - new_list.append(NewTuple(*data_n, N_model)) + new_list.append(NewTuple(*space_n, N_model)) z_data = z_data._replace(**{fields[f]: new_list}) - elif "mag_idx" in data._fields: - mag_idx = data.mag_idx + elif "mag_idx" in space._fields: + mag_idx = space.mag_idx obs_mag = obs_mags[:, mag_idx] obs_mag = obs_mag.reshape(obs_mag.size, 1) @@ -73,18 +75,18 @@ def N_colors_mags( N_model = diffndhist_lomem.tw_ndhist_weighted( obs_mag, - data.sig, + space.sig, weight, - data.bin_lo, - data.bin_hi, + space.bin_lo, + space.bin_hi, ) - NewTuple = namedtuple(type(data).__name__, [*data._fields, "N_model"]) - new = NewTuple(*data, N_model) + NewTuple = namedtuple(type(space).__name__, [*space._fields, "N_model"]) + new = NewTuple(*space, N_model) z_data = z_data._replace(**{fields[f]: new}) else: - col_idx = data.col_idx + col_idx = space.col_idx obs_colors = [] for c in range(0, len(col_idx) - 1): obs_color = obs_mags[:, col_idx[c]] - obs_mags[:, col_idx[c + 1]] @@ -99,14 +101,14 @@ def N_colors_mags( N_model = diffndhist_lomem.tw_ndhist_weighted( obs_colors, - data.sig, + space.sig, weight, - data.bin_lo, - data.bin_hi, + space.bin_lo, + space.bin_hi, ) - NewTuple = namedtuple(type(data).__name__, [*data._fields, "N_model"]) - new = NewTuple(*data, N_model) + NewTuple = namedtuple(type(space).__name__, [*space._fields, "N_model"]) + new = NewTuple(*space, N_model) z_data = z_data._replace(**{fields[f]: new}) return z_data diff --git a/diffhtwo/experimental/loss_kernels/phot_loss.py b/diffhtwo/experimental/loss_kernels/phot_loss.py index 165ba9ab..8ead0d54 100644 --- a/diffhtwo/experimental/loss_kernels/phot_loss.py +++ b/diffhtwo/experimental/loss_kernels/phot_loss.py @@ -1,4 +1,5 @@ from jax import jit as jjit +from jax.debug import print from ..kernels.N_phot import N_colors_mags, N_colors_mags_lh, N_mags_1d from ..param_utils import get_param_collection_from_u_theta @@ -6,67 +7,49 @@ @jjit -def get_phot_loss_2d( +def get_phot_loss_2d_multiz( ran_key, param_collection, - z_data, + data, mag_thresh, frac_cat, data_sky_area_degsq, ): - z_data = N_colors_mags( - ran_key, - param_collection, - z_data, - mag_thresh, - frac_cat, - ) phot_loss_2d = 0.0 - fields = z_data._fields[3:] - for f in range(0, len(fields)): - data = getattr(z_data, fields[f]) + for z in range(0, len(data)): + z_data = data[z] - if isinstance(data, list): - for d in range(0, len(data)): - data_n = data[d] - - N_model = data_n.N_model - N_data = data_n.N_data - - N_model = N_model * ( - data_sky_area_degsq / z_data.lc_data.sky_area_degsq - ) - phot_loss_2d += poisson_loss(N_model, N_data) + z_data_model = N_colors_mags( + ran_key, + param_collection, + z_data, + mag_thresh, + frac_cat, + ) + fields = z_data_model._fields[3:] + for f in range(0, len(fields)): + space = getattr(z_data_model, fields[f]) - else: - N_model = data.N_model - N_data = data.N_data + if isinstance(space, list): + for s in range(0, len(space)): + space_n = space[s] - N_model = N_model * (data_sky_area_degsq / z_data.lc_data.sky_area_degsq) - phot_loss_2d += poisson_loss(N_model, N_data) + N_model = space_n.N_model + N_data = space_n.N_data - return phot_loss_2d + N_model = N_model * ( + data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq + ) + phot_loss_2d += poisson_loss(N_model, N_data) + else: + N_model = space.N_model + N_data = space.N_data -@jjit -def _loss_phot_kern_2d( - u_theta, - ran_key, - z_data, - mag_thresh, - frac_cat, - data_sky_area_degsq, -): - param_collection = get_param_collection_from_u_theta(u_theta) - - phot_loss_2d = get_phot_loss_2d( - ran_key, - param_collection, - z_data, - mag_thresh, - frac_cat, - data_sky_area_degsq, - ) + N_model = N_model * ( + data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq + ) + phot_loss_2d += poisson_loss(N_model, N_data) return phot_loss_2d @@ -81,26 +64,20 @@ def _loss_phot_kern_2d_multiz( phot_loss_2d = 0.0 - phot_loss_2d += get_phot_loss_2d( + # get color loss + phot_loss_2d += get_phot_loss_2d_multiz( ran_key, param_collection, - fitting_data.z1, + fitting_data.colors, fitting_data.filter_info.mag_thresh, fitting_data.frac_cat, fitting_data.data_sky_area_degsq, ) - phot_loss_2d += get_phot_loss_2d( + # get app mag func loss + phot_loss_2d += get_phot_loss_2d_multiz( ran_key, param_collection, - fitting_data.z2, - fitting_data.filter_info.mag_thresh, - fitting_data.frac_cat, - fitting_data.data_sky_area_degsq, - ) - phot_loss_2d += get_phot_loss_2d( - ran_key, - param_collection, - fitting_data.z3, + fitting_data.app_mag_funcs, fitting_data.filter_info.mag_thresh, fitting_data.frac_cat, fitting_data.data_sky_area_degsq, @@ -109,6 +86,29 @@ def _loss_phot_kern_2d_multiz( return phot_loss_2d +# @jjit +# def _loss_phot_kern_2d( +# u_theta, +# ran_key, +# z_data, +# mag_thresh, +# frac_cat, +# data_sky_area_degsq, +# ): +# param_collection = get_param_collection_from_u_theta(u_theta) + +# phot_loss_2d = get_phot_loss_2d( +# ran_key, +# param_collection, +# z_data, +# mag_thresh, +# frac_cat, +# data_sky_area_degsq, +# ) + +# return phot_loss_2d + + @jjit def get_phot_loss_1d( ran_key, diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 7430ef5e..42fba88d 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run138 -model_nickname: run138_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run138/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run140 +model_nickname: run140_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run140/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks diff --git a/scripts/config_feniks.yaml b/scripts/config_feniks.yaml index d80b5af7..68151f04 100644 --- a/scripts/config_feniks.yaml +++ b/scripts/config_feniks.yaml @@ -7,7 +7,8 @@ fit_runid: "runtest" fit_type: "diffstarpop+spspop+merging" feniks: - num_halos: 100 + num_halos_coarse_zbins: 100 + num_halos_fine_zbins: 200 epoch: n_it: 1 diff --git a/scripts/fit_feniks.py b/scripts/fit_feniks.py index d57909a3..f8b0750e 100644 --- a/scripts/fit_feniks.py +++ b/scripts/fit_feniks.py @@ -51,7 +51,11 @@ # load feniks data ran_key = jran.key(0) feniks = load_feniks.get_feniks_data( - feniks_drn, ran_key, ssp_data, num_halos=cfg["feniks"]["num_halos"] + feniks_drn, + ran_key, + ssp_data, + num_halos_coarse_zbins=cfg["feniks"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["feniks"]["num_halos_fine_zbins"], ) remove = {"dataset_dim_labels", "mags_labels"} FeniksFitting = namedtuple("Feniks", [f for f in feniks._fields if f not in remove]) From 2b6c451258f78a15b54b4513f93516059e9177e0 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 7 Jun 2026 15:05:52 -0500 Subject: [PATCH 10/57] include r-band app mag func in fitting --- .../experimental/data_loaders/load_feniks.py | 36 ++++++++++++------- .../experimental/diagnostics/plot_phot.py | 16 ++++----- scripts/config_diagnostics.yaml | 6 ++-- scripts/generate_diagnostic_plots.py | 18 ++++++---- 4 files changed, 45 insertions(+), 31 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 9e871831..8a073c06 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -703,16 +703,21 @@ def get_feniks_data( ############################################################################## AppMagFuncs = namedtuple( "AppMagFuncs", - ["z_min", "z_max", "lc_data", "k", "u"], - ) - K = namedtuple( - "K", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ["z_min", "z_max", "lc_data", "u", "r", "k"], ) U = namedtuple( "U", ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], ) + R = namedtuple( + "R", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ) + K = namedtuple( + "K", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ) + app_mag_funcs = [] for zbin in range(0, len(fine_zbins)): z_min = fine_zbins[zbin][0] @@ -738,17 +743,22 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) + # 1D (u) + mag_idx_u = 0 + N_1d_u, sig_u, bin_lo_u, bin_hi_u = get_N_1d(megacam_uS[z_sel]) + u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u) + + # 1D (r) + mag_idx_r = 2 + N_1d_r, sig_r, bin_lo_r, bin_hi_r = get_N_1d(hsc_r[z_sel]) + r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r) + # 1D (K) - mag_idx = 7 + mag_idx_k = 7 N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) - k = K(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) - - # 1D (u) - mag_idx = 0 - N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(megacam_uS[z_sel]) - u = U(mag_idx, sig_k, bin_lo_k, bin_hi_k, N_1d_k) + k = K(mag_idx_k, sig_k, bin_lo_k, bin_hi_k, N_1d_k) - app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, k, u)) + app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, r, k)) ############################################################################## diff --git a/diffhtwo/experimental/diagnostics/plot_phot.py b/diffhtwo/experimental/diagnostics/plot_phot.py index 496ed08d..b19432b3 100644 --- a/diffhtwo/experimental/diagnostics/plot_phot.py +++ b/diffhtwo/experimental/diagnostics/plot_phot.py @@ -613,10 +613,10 @@ def plot_app_mag_funcs( alpha = 0.75 s = 10 - fig, ax = plt.subplots(2, 4, figsize=(fig_width, fig_height)) - fig.subplots_adjust( - left=0.05, hspace=0.3, top=0.875, right=0.99, bottom=0.1, wspace=0.1 + fig, ax = plt.subplots( + 2, 4, figsize=(fig_width, fig_height), constrained_layout=True ) + fig.get_layout_engine().set(rect=[0, 0, 1, 0.95]) handles = [ mlines.Line2D([], [], color=c, linewidth=6, solid_capstyle="butt", label=label) @@ -632,7 +632,7 @@ def plot_app_mag_funcs( handleheight=0.5, columnspacing=0.8, handletextpad=0.1, - bbox_to_anchor=(0.5, 0.92), + bbox_to_anchor=(0.5, 0.99), fontsize=7, ) @@ -721,7 +721,7 @@ def plot_app_mag_funcs( right=True, length=6, width=1, - labelsize=10, + labelsize=labelsize, ) ax[row, col].tick_params( which="minor", @@ -730,7 +730,7 @@ def plot_app_mag_funcs( right=True, length=3, width=0.8, - labelsize=10, + labelsize=labelsize, ) ax[row, col].set_ylim(-6.9, -2.5) @@ -750,8 +750,8 @@ def plot_app_mag_funcs( ax[1, 0].set_ylabel("log$_{10}$ (n [Mpc$^{-3}$])", fontsize=fontsize) fig.savefig( savedir + "/" + data_label + "_app_mag_funcs.png", - bbox_inches="tight", - dpi=200, + # bbox_extra_artists=(leg,), + dpi=300, ) if plt_show: plt.show() diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 42fba88d..99e2ed46 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run140 -model_nickname: run140_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run140/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run142 +model_nickname: run142_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run142/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index 0cd5c9e6..35c6ff1e 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -204,13 +204,6 @@ if cfg["plot_feniks"]: feniks_label = "feniks" # + cfg["model_nickname"].split("_")[0] feniks = load_feniks.get_feniks_data(feniks_drn, ran_key, ssp_data) - feniks_zbins = np.array( - [ - (feniks.z1.z_min, feniks.z1.z_max), - (feniks.z2.z_min, feniks.z2.z_max), - (feniks.z3.z_min, feniks.z3.z_max), - ] - ) if cfg["plots"]["plot_app_mag_funcs"]: print("Generating FENIKS app mag funcs plot...") @@ -272,6 +265,17 @@ plt_show=False, ) + feniks_zbins = np.array( + [ + [0.2, 0.5], + [0.5, 0.7], + [0.7, 1.0], + [1.0, 1.5], + [1.5, 2.0], + [2.0, 2.5], + ] + ) + for zbin in range(0, len(feniks_zbins)): z_min = feniks_zbins[zbin][0] z_max = feniks_zbins[zbin][1] From 887ab7d4a0ab567e8a4334a7c4e612404cb3c49f Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 7 Jun 2026 15:45:13 -0500 Subject: [PATCH 11/57] =?UTF-8?q?(r=20=E2=88=92=20i=20|=20K),=20(r=20?= =?UTF-8?q?=E2=88=92=20z=20|=20K),=20(J=20=E2=88=92=20H=20|=20K)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit extract residual quenching scatter at fixed stellar mass with these three conditional colors for each of the three z-bins, respectively --- .../experimental/data_loaders/load_feniks.py | 107 ++++++++++++++++-- scripts/config_diagnostics.yaml | 6 +- 2 files changed, 101 insertions(+), 12 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 8a073c06..e7f20445 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -399,14 +399,15 @@ def get_feniks_data( ############################################################################## # Z1 spaces: - # 2D (g-r, r-i) - # 2D (u-g, r-K) - # 1D (u-g | K) + # 2D (g - r, r - i) + # 2D (u - g, r - K) + # 1D (u - g | K) + # 1D (r − i | K): residual quenching scatter at fixed stellar mass colors = [] Z1 = namedtuple( "Z1", - ["z_min", "z_max", "lc_data", "gr_ri", "ug_rK", "ug"], + ["z_min", "z_max", "lc_data", "gr_ri", "ug_rK", "ug", "ri"], ) zbin = 0 z_min = zbins[zbin][0] @@ -489,17 +490,47 @@ def get_feniks_data( ) ) - z1 = Z1(z_min, z_max, lc_data, gr_ri, ug_rK, ug) + # 1D (r − i | K) + ri = [] + Ri_condK = namedtuple( + "Ri_condK", + ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + ) + mag_sel_ri = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( + hsc_i[z_sel] < feniks_mag_thresh.HSC_I + ) + col_idx = [2, 3] + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_ri, sig_ri, bin_lo_ri, bin_hi_ri = get_N_1d( + hsc_ri[z_sel][mag_sel_ri & K_sel] + ) + ri.append( + Ri_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_ri, + bin_lo_ri, + bin_hi_ri, + N_1d_ri, + ) + ) + + z1 = Z1(z_min, z_max, lc_data, gr_ri, ug_rK, ug, ri) colors.append(z1) ############################################################################## # Z2 spaces: # 2D (r - z, z - J) # 1D (u - g | K) + # 1D (r − z | K): residual quenching scatter at fixed stellar mass Z2 = namedtuple( "Z2", - ["z_min", "z_max", "lc_data", "rz_zJ", "ug"], + ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "rz"], ) zbin = 1 z_min = zbins[zbin][0] @@ -564,8 +595,36 @@ def get_feniks_data( N_1d_ug, ) ) + # 1D (r - z | K) + rz = [] + Rz_condK = namedtuple( + "Rz_condK", + ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + ) + mag_sel_rz = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( + hsc_z[z_sel] < feniks_mag_thresh.HSC_Z + ) + col_idx = [2, 4] + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_rz, sig_rz, bin_lo_rz, bin_hi_rz = get_N_1d( + hsc_rz[z_sel][mag_sel_rz & K_sel] + ) + rz.append( + Rz_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_rz, + bin_lo_rz, + bin_hi_rz, + N_1d_rz, + ) + ) - z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug) + z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, rz) colors.append(z2) ############################################################################## @@ -574,10 +633,11 @@ def get_feniks_data( # 2D (u - g, g - r) # 1D (u - g | K) # 1D (g - r | K) + # 1D (J − H | K): residual quenching scatter at fixed stellar mass Z3 = namedtuple( "Z3", - ["z_min", "z_max", "lc_data", "zJ_JH", "ug_gr", "ug", "gr"], + ["z_min", "z_max", "lc_data", "zJ_JH", "ug_gr", "ug", "gr", "jh"], ) zbin = 2 z_min = zbins[zbin][0] @@ -685,7 +745,36 @@ def get_feniks_data( ) ) - z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr) + # 1D (J − H | K) + jh = [] + JH_condK = namedtuple( + "JH_condK", + ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + ) + mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( + uds_H[z_sel] < feniks_mag_thresh.UDS_H + ) + col_idx = [5, 6] + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_jh, sig_jh, bin_lo_jh, bin_hi_jh = get_N_1d( + uds_JH[z_sel][mag_sel_jh & K_sel] + ) + jh.append( + JH_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_jh, + bin_lo_jh, + bin_hi_jh, + N_1d_jh, + ) + ) + + z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr, jh) colors.append(z3) ############################################################################## diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 99e2ed46..fe24edf7 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run142 -model_nickname: run142_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run142/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run143 +model_nickname: run143_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run143/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks From 969412a950e545ab588c47e18c4a15405699aa56 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 7 Jun 2026 17:56:18 -0500 Subject: [PATCH 12/57] =?UTF-8?q?1D=20(i=20=E2=88=92=20z=20|=20K)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../experimental/data_loaders/load_feniks.py | 34 +++++++++++++++++-- scripts/config_diagnostics.yaml | 6 ++-- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index e7f20445..439111c9 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -403,11 +403,12 @@ def get_feniks_data( # 2D (u - g, r - K) # 1D (u - g | K) # 1D (r − i | K): residual quenching scatter at fixed stellar mass + # 1D (i - z | K): completely unconstrained so including it here colors = [] Z1 = namedtuple( "Z1", - ["z_min", "z_max", "lc_data", "gr_ri", "ug_rK", "ug", "ri"], + ["z_min", "z_max", "lc_data", "gr_ri", "ug_rK", "ug", "ri", "iz"], ) zbin = 0 z_min = zbins[zbin][0] @@ -519,7 +520,36 @@ def get_feniks_data( ) ) - z1 = Z1(z_min, z_max, lc_data, gr_ri, ug_rK, ug, ri) + # 1D (i − z | K) + iz = [] + Iz_condK = namedtuple( + "Iz_condK", + ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + ) + mag_sel_iz = (hsc_i[z_sel] < feniks_mag_thresh.HSC_I) & ( + hsc_z[z_sel] < feniks_mag_thresh.HSC_Z + ) + col_idx = [3, 4] + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_iz, sig_iz, bin_lo_iz, bin_hi_iz = get_N_1d( + hsc_iz[z_sel][mag_sel_iz & K_sel] + ) + iz.append( + Iz_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_iz, + bin_lo_iz, + bin_hi_iz, + N_1d_iz, + ) + ) + + z1 = Z1(z_min, z_max, lc_data, gr_ri, ug_rK, ug, ri, iz) colors.append(z1) ############################################################################## diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index fe24edf7..d0ff5af9 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run143 -model_nickname: run143_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run143/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run145 +model_nickname: run145_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run145/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks From 30440fcfe8e6f468c44010ecd18b297218a4dea3 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 7 Jun 2026 19:48:19 -0500 Subject: [PATCH 13/57] fit hizels+feniks --- .../experimental/diagnostics/plot_burstpop.py | 2 +- .../experimental/loss_kernels/phot_loss.py | 1 - .../optimizers/Np_specphot_opt.py | 64 ++----------------- scripts/config_diagnostics.yaml | 6 +- scripts/config_diffsky.yaml | 3 +- scripts/fit_diffsky.py | 57 ++++------------- 6 files changed, 24 insertions(+), 109 deletions(-) diff --git a/diffhtwo/experimental/diagnostics/plot_burstpop.py b/diffhtwo/experimental/diagnostics/plot_burstpop.py index 795249c4..f3267ea8 100644 --- a/diffhtwo/experimental/diagnostics/plot_burstpop.py +++ b/diffhtwo/experimental/diagnostics/plot_burstpop.py @@ -143,7 +143,7 @@ def plot_lgfburst_mh_z( ] ) fig, ax = plt.subplots(1, 2, figsize=(10, 4), width_ratios=[1, 1.2]) - vmin, vmax = -6, -2.5 + vmin, vmax = -6, -1.5 """Plot fburst w/ halo mass and redshift""" ax[0].hexbin( diff --git a/diffhtwo/experimental/loss_kernels/phot_loss.py b/diffhtwo/experimental/loss_kernels/phot_loss.py index 8ead0d54..aee02125 100644 --- a/diffhtwo/experimental/loss_kernels/phot_loss.py +++ b/diffhtwo/experimental/loss_kernels/phot_loss.py @@ -1,5 +1,4 @@ from jax import jit as jjit -from jax.debug import print from ..kernels.N_phot import N_colors_mags, N_colors_mags_lh, N_mags_1d from ..param_utils import get_param_collection_from_u_theta diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 2ac416dd..3f282ecf 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -187,7 +187,7 @@ def fit_feniks_hizels( def _opt_update(opt_state, i): u_theta = get_params(opt_state) - loss_phot, grad_phot = _loss_and_grad_phot_kern_multiband_multiz( + loss_phot, grad_phot = _loss_and_grad_phot_kern_2d_multiz( u_theta, ran_key, feniks_fitting_data, @@ -209,10 +209,10 @@ def _opt_update(opt_state, i): ) # clip gradients - global_norm = pytree_norm(grads) - tau = 1.0 - scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) - grads = tuple(g * scale for g in grads) + # global_norm = pytree_norm(grads) + # tau = 1.0 + # scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + # grads = tuple(g * scale for g in grads) opt_state = opt_update(i, grads, opt_state) return opt_state, (loss, loss_phot, loss_emline) @@ -224,60 +224,6 @@ def _opt_update(opt_state, i): return loss_hist, loss_phot_hist, loss_emline_hist, u_theta_fit -# @partial(jjit, static_argnames=["n_steps", "step_size"]) -# def fit_feniks_hizels( -# u_theta_init, -# trainable, -# ran_key, -# feniks_meta_data, -# feniks_fitting_data, -# hizels_fitting_data, -# n_steps=2, -# step_size=1e-2, -# ): -# opt_init, opt_update, get_params = jax_opt.adam(step_size) -# opt_state = opt_init(u_theta_init) - -# def _opt_update(opt_state, i): -# u_theta = get_params(opt_state) -# loss_phot, grad_phot = _loss_and_grad_phot_kern_multi_z( -# u_theta, -# ran_key, -# feniks_meta_data, -# feniks_fitting_data, -# ) -# loss_emline, grad_emline = _loss_and_grad_emline_kern_multi_line_multi_z( -# u_theta, -# ran_key, -# hizels_fitting_data, -# ) -# w_phot = 10.0 -# w_emline = 1.0 -# loss = w_phot * loss_phot + w_emline * loss_emline -# grads = tuple( -# w_phot * gp + w_emline * ge for gp, ge in zip(grad_phot, grad_emline) -# ) -# # set grads for untrainable params to 0.0 -# grads = tuple( -# jnp.where(train, grad, 0.0) for grad, train in zip(grads, trainable) -# ) - -# # clip gradients -# global_norm = pytree_norm(grads) -# tau = 1.0 -# scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) -# grads = tuple(g * scale for g in grads) - -# opt_state = opt_update(i, grads, opt_state) -# return opt_state, (loss, loss_phot, loss_emline) - -# opt_state, (loss_hist, loss_phot_hist, loss_emline_hist) = lax.scan( -# _opt_update, opt_state, jnp.arange(n_steps) -# ) -# u_theta_fit = get_params(opt_state) -# return loss_hist, loss_phot_hist, loss_emline_hist, u_theta_fit - - @jjit def _loss_sdss_feniks_hizels( u_theta, diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index d0ff5af9..93fb61fc 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run145 -model_nickname: run145_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run145/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run147 +model_nickname: run147_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run147/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks diff --git a/scripts/config_diffsky.yaml b/scripts/config_diffsky.yaml index d367eca2..9b90f0df 100644 --- a/scripts/config_diffsky.yaml +++ b/scripts/config_diffsky.yaml @@ -13,7 +13,8 @@ sdss: feniks: lh_d_mag: 0.4 N_centroids: 100 - num_halos: 100 + num_halos_coarse_zbins: 100 + num_halos_fine_zbins: 100 hizels: num_halos: 100 diff --git a/scripts/fit_diffsky.py b/scripts/fit_diffsky.py index 2e7563fd..779e2dbd 100644 --- a/scripts/fit_diffsky.py +++ b/scripts/fit_diffsky.py @@ -1,6 +1,7 @@ import argparse import os import time +from collections import namedtuple from datetime import datetime from pathlib import Path @@ -22,8 +23,6 @@ from diffhtwo.experimental import param_utils as pu from diffhtwo.experimental.data_loaders import load_feniks, load_hizels -from diffhtwo.experimental.defaults import FENIKS_Z_MIN -from diffhtwo.experimental.latin_hypercube import lh_utils as lhu from diffhtwo.experimental.optimizers import Np_specphot_opt DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ @@ -51,14 +50,19 @@ ssp_data = lemi.get_subset_emline_data(ssp_data, ["Ba_alpha_6563"]) halpha_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) - # load sdss data - # ran_key = jran.key(0) - # SDSS = load_sdss.get_sdss_data(sdss_drn, ran_key, ssp_data) - # load feniks data ran_key = jran.key(0) - FENIKS = load_feniks.get_feniks_data( - feniks_drn, ran_key, ssp_data, lh_d_mag=cfg["feniks"]["lh_d_mag"] + feniks = load_feniks.get_feniks_data( + feniks_drn, + ran_key, + ssp_data, + num_halos_coarse_zbins=cfg["feniks"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["feniks"]["num_halos_fine_zbins"], + ) + remove = {"dataset_dim_labels", "mags_labels"} + FeniksFitting = namedtuple("Feniks", [f for f in feniks._fields if f not in remove]) + feniks_fitting_data = FeniksFitting( + **{f: getattr(feniks, f) for f in FeniksFitting._fields} ) # load hizels data @@ -66,7 +70,7 @@ hizels_drn, ran_key, ssp_data, - FENIKS.filter_info.tcurves, + feniks.filter_info.tcurves, halpha_wave_aa, num_halos=cfg["hizels"]["num_halos"], ) @@ -111,45 +115,11 @@ os.system(f"cp {args.config} {fit_diagnostics_save_drn}") - # SDSS - # sdss_z_min = [SDSS_Z_MIN, 0.08, 0.14] - # sdss_z_max = [0.08, 0.14, SDSS_Z_MAX] - - # FENIKS - feniks_z_min = [FENIKS_Z_MIN, 1] - feniks_z_max = [1, 2] - initial_pts = [] start = time.time() for epoch in range(0, cfg["epoch"]["n_it"]): print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') - # SDSS - # sdss = load_sdss.refresh_lh_centroids(SDSS) - # sdss_meta_data, sdss_fitting_data = lhu.get_zbins_lh_lc( - # ran_key, - # SDSS, - # sdss_z_min, - # sdss_z_max, - # ssp_data, - # cfg["sdss"]["N_centroids"], - # lh_N_z_savedir=fit_diagnostics_save_drn + "/lh_N_z", - # num_halos=cfg["sdss"]["num_halos"], - # ) - - # FENIKS - FENIKS = load_feniks.refresh_lh_centroids(FENIKS, cfg["feniks"]["lh_d_mag"]) - feniks_meta_data, feniks_fitting_data = lhu.get_zbins_lh_lc( - ran_key, - FENIKS, - feniks_z_min, - feniks_z_max, - ssp_data, - cfg["feniks"]["N_centroids"], - lh_N_z_savedir=fit_diagnostics_save_drn + "/lh_N_z", - num_halos=cfg["feniks"]["num_halos"], - ) - ( loss_hist, loss_phot_hist, @@ -159,7 +129,6 @@ u_theta_fit, trainable_params, ran_key, - feniks_meta_data, feniks_fitting_data, hizels_fitting_data, n_steps=cfg["epoch"]["n_steps"], From a2971b02c98aeb809fcfe7f1c0339c096e54bb0b Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 7 Jun 2026 21:09:02 -0500 Subject: [PATCH 14/57] =?UTF-8?q?add=20in=201D=20(J=20=E2=88=92=20H=20|=20?= =?UTF-8?q?K)=20in=20z1=20and=20z2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../experimental/data_loaders/load_feniks.py | 68 +++++++++++++++++-- .../experimental/diagnostics/plot_phot.py | 2 +- scripts/config_diagnostics.yaml | 6 +- 3 files changed, 68 insertions(+), 8 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 439111c9..d760ba78 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -404,11 +404,12 @@ def get_feniks_data( # 1D (u - g | K) # 1D (r − i | K): residual quenching scatter at fixed stellar mass # 1D (i - z | K): completely unconstrained so including it here + # 1D (J − H | K) colors = [] Z1 = namedtuple( "Z1", - ["z_min", "z_max", "lc_data", "gr_ri", "ug_rK", "ug", "ri", "iz"], + ["z_min", "z_max", "lc_data", "gr_ri", "ug_rK", "ug", "ri", "iz", "jh"], ) zbin = 0 z_min = zbins[zbin][0] @@ -549,7 +550,36 @@ def get_feniks_data( ) ) - z1 = Z1(z_min, z_max, lc_data, gr_ri, ug_rK, ug, ri, iz) + # 1D (J − H | K) + jh = [] + JH_condK = namedtuple( + "JH_condK", + ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + ) + mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( + uds_H[z_sel] < feniks_mag_thresh.UDS_H + ) + col_idx = [5, 6] + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_jh, sig_jh, bin_lo_jh, bin_hi_jh = get_N_1d( + uds_JH[z_sel][mag_sel_jh & K_sel] + ) + jh.append( + JH_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_jh, + bin_lo_jh, + bin_hi_jh, + N_1d_jh, + ) + ) + + z1 = Z1(z_min, z_max, lc_data, gr_ri, ug_rK, ug, ri, iz, jh) colors.append(z1) ############################################################################## @@ -557,10 +587,11 @@ def get_feniks_data( # 2D (r - z, z - J) # 1D (u - g | K) # 1D (r − z | K): residual quenching scatter at fixed stellar mass + # 1D (J − H | K) Z2 = namedtuple( "Z2", - ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "rz"], + ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "rz", "jh"], ) zbin = 1 z_min = zbins[zbin][0] @@ -654,7 +685,36 @@ def get_feniks_data( ) ) - z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, rz) + # 1D (J − H | K) + jh = [] + JH_condK = namedtuple( + "JH_condK", + ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + ) + mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( + uds_H[z_sel] < feniks_mag_thresh.UDS_H + ) + col_idx = [5, 6] + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_jh, sig_jh, bin_lo_jh, bin_hi_jh = get_N_1d( + uds_JH[z_sel][mag_sel_jh & K_sel] + ) + jh.append( + JH_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_jh, + bin_lo_jh, + bin_hi_jh, + N_1d_jh, + ) + ) + + z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, rz, jh) colors.append(z2) ############################################################################## diff --git a/diffhtwo/experimental/diagnostics/plot_phot.py b/diffhtwo/experimental/diagnostics/plot_phot.py index b19432b3..daccb9bd 100644 --- a/diffhtwo/experimental/diagnostics/plot_phot.py +++ b/diffhtwo/experimental/diagnostics/plot_phot.py @@ -633,7 +633,7 @@ def plot_app_mag_funcs( columnspacing=0.8, handletextpad=0.1, bbox_to_anchor=(0.5, 0.99), - fontsize=7, + fontsize=10, ) xlim = [] diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 93fb61fc..03e7e182 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run147 -model_nickname: run147_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run147/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run148 +model_nickname: run148_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run148/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks From 19c00727ba64a228823133a67ae4487de32f6da9 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 7 Jun 2026 22:23:26 -0500 Subject: [PATCH 15/57] normalize phot and emline loss based on number of bins --- .../experimental/data_loaders/load_feniks.py | 22 +++++++++++++++++++ .../experimental/data_loaders/load_hizels.py | 20 ++++++++++++++++- diffhtwo/experimental/defaults.py | 1 + .../optimizers/Np_specphot_opt.py | 4 ++-- scripts/config_diagnostics.yaml | 8 +++---- 5 files changed, 48 insertions(+), 7 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index d760ba78..156a47db 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -387,6 +387,8 @@ def get_feniks_data( # mags = mags[z_mask] # zout = zout[z_mask] + nbins = 0 + ############################################################################## # prepare 2D and 1D color spaces in coarse z-bins for fitting zbins = np.array( @@ -447,6 +449,7 @@ def get_feniks_data( ) col_idx = [1, 2, 3] gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri) + nbins += bin_lo_gr_ri.size # 2D (u - g, r - K) Ug_rK = namedtuple("Ug_rK", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) @@ -460,6 +463,7 @@ def get_feniks_data( ) col_idx = [0, 1, 7] ug_rK = Ug_rK(col_idx, sig_ug_rK, bin_lo_ug_rK, bin_hi_ug_rK, N_ug_rK) + nbins += bin_lo_ug_rK.size # 1D (u - g | K) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) @@ -491,6 +495,7 @@ def get_feniks_data( N_1d_ug, ) ) + nbins += bin_lo_ug.size # 1D (r − i | K) ri = [] @@ -520,6 +525,7 @@ def get_feniks_data( N_1d_ri, ) ) + nbins += bin_lo_ri.size # 1D (i − z | K) iz = [] @@ -549,6 +555,7 @@ def get_feniks_data( N_1d_iz, ) ) + nbins += bin_lo_iz.size # 1D (J − H | K) jh = [] @@ -578,6 +585,7 @@ def get_feniks_data( N_1d_jh, ) ) + nbins += bin_lo_jh.size z1 = Z1(z_min, z_max, lc_data, gr_ri, ug_rK, ug, ri, iz, jh) colors.append(z1) @@ -629,6 +637,7 @@ def get_feniks_data( ) col_idx = [2, 4, 5] rz_zJ = Rz_zJ(col_idx, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ, N_rz_zJ) + nbins += bin_lo_rz_zJ.size # 1D (u - g | K) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) @@ -656,6 +665,8 @@ def get_feniks_data( N_1d_ug, ) ) + nbins += bin_lo_ug.size + # 1D (r - z | K) rz = [] Rz_condK = namedtuple( @@ -684,6 +695,7 @@ def get_feniks_data( N_1d_rz, ) ) + nbins += bin_lo_rz.size # 1D (J − H | K) jh = [] @@ -713,6 +725,7 @@ def get_feniks_data( N_1d_jh, ) ) + nbins += bin_lo_jh.size z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, rz, jh) colors.append(z2) @@ -765,6 +778,7 @@ def get_feniks_data( ) col_idx = [4, 5, 6] zJ_JH = zJ_JH(col_idx, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH, N_zJ_JH) + nbins += bin_lo_zJ_JH.size # 2D (u - g, g - r) Ug_gr = namedtuple("Ug_gr", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) @@ -778,6 +792,7 @@ def get_feniks_data( ) col_idx = [0, 1, 2] ug_gr = Ug_gr(col_idx, sig_ug_gr, bin_lo_ug_gr, bin_hi_ug_gr, N_ug_gr) + nbins += bin_lo_ug_gr.size # 1D (u - g | K) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 4) @@ -805,6 +820,7 @@ def get_feniks_data( N_1d_ug, ) ) + nbins += bin_lo_ug.size # 1D (g - r | K) gr = [] @@ -834,6 +850,7 @@ def get_feniks_data( N_1d_gr, ) ) + nbins += bin_lo_gr.size # 1D (J − H | K) jh = [] @@ -863,6 +880,7 @@ def get_feniks_data( N_1d_jh, ) ) + nbins += bin_lo_jh.size z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr, jh) colors.append(z3) @@ -926,16 +944,19 @@ def get_feniks_data( mag_idx_u = 0 N_1d_u, sig_u, bin_lo_u, bin_hi_u = get_N_1d(megacam_uS[z_sel]) u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u) + nbins += bin_lo_u.size # 1D (r) mag_idx_r = 2 N_1d_r, sig_r, bin_lo_r, bin_hi_r = get_N_1d(hsc_r[z_sel]) r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r) + nbins += bin_lo_r.size # 1D (K) mag_idx_k = 7 N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) k = K(mag_idx_k, sig_k, bin_lo_k, bin_hi_k, N_1d_k) + nbins += bin_lo_k.size app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, r, k)) @@ -962,6 +983,7 @@ def get_feniks_data( mags_labels, colors, app_mag_funcs, + nbins, filter_info, frac_cat, lh_centroids, diff --git a/diffhtwo/experimental/data_loaders/load_hizels.py b/diffhtwo/experimental/data_loaders/load_hizels.py index de3c7bfc..949ec609 100644 --- a/diffhtwo/experimental/data_loaders/load_hizels.py +++ b/diffhtwo/experimental/data_loaders/load_hizels.py @@ -18,6 +18,7 @@ "z", "dz", "lc_data", + "nbins", ], ) DELTA_L_HALPHA = -0.4 # uncorrect HiZELS h-alpha L for dust (A_halpha = 1 mag) @@ -42,6 +43,7 @@ def get_hizels_data( hizels_halpha_lg_phi_data, hizels_halpha_z_data, hizels_halpha_delta_z_data, + hizels_halpha_nbins, ) = get_hizels_halpha(drn) line_wave_aa = [halpha_wave_aa] @@ -82,7 +84,15 @@ def get_hizels_data( lc_data.append(line_lc_data) return Hizels( - line_wave_aa, lg_Lbin_edges, N_data, vol_Mpc3_data, lg_phi_data, z, dz, lc_data + line_wave_aa, + lg_Lbin_edges, + N_data, + vol_Mpc3_data, + lg_phi_data, + z, + dz, + lc_data, + hizels_halpha_nbins, ) @@ -213,6 +223,13 @@ def get_hizels_halpha(drn): ) ) + hizels_halpha_nbins = ( + (lg_halpha_Lbin_edges_z0p4.size - 1) + + (lg_halpha_Lbin_edges_z0p84.size - 1) + + (lg_halpha_Lbin_edges_z1p47.size - 1) + + (lg_halpha_Lbin_edges_z2p23.size - 1) + ) + hizels_lg_halpha_Lbin_edges_data = [ lg_halpha_Lbin_edges_z0p4, lg_halpha_Lbin_edges_z0p84, @@ -262,4 +279,5 @@ def get_hizels_halpha(drn): hizels_halpha_lg_phi_data, hizels_halpha_z_data, hizels_halpha_delta_z_data, + hizels_halpha_nbins, ) diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index 4b78e6b2..961288e9 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -37,6 +37,7 @@ "mags_labels", "colors", "app_mag_funcs", + "nbins", "filter_info", "frac_cat", "lh_centroids", diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 3f282ecf..8bb89332 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -197,8 +197,8 @@ def _opt_update(opt_state, i): ran_key, hizels_fitting_data, ) - w_phot = 1.0 - w_emline = 1.0 + w_phot = 1.0 / feniks_fitting_data.nbins + w_emline = 1.0 / hizels_fitting_data.nbins loss = w_phot * loss_phot + w_emline * loss_emline grads = tuple( w_phot * gp + w_emline * ge for gp, ge in zip(grad_phot, grad_emline) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 03e7e182..4ae8f10e 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run148 -model_nickname: run148_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run148/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run149 +model_nickname: run149_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run149/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -13,7 +13,7 @@ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kr plot_sdss: False plot_feniks: True -plot_hizels: False +plot_hizels: True plots: num_halos : 3000 From 640f077f346ea2453653fad41c2a245d0873884c Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 7 Jun 2026 23:03:22 -0500 Subject: [PATCH 16/57] grad clipping on --- diffhtwo/experimental/diagnostics/plot_phot.py | 2 +- diffhtwo/experimental/optimizers/Np_specphot_opt.py | 8 ++++---- scripts/config_diagnostics.yaml | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/diffhtwo/experimental/diagnostics/plot_phot.py b/diffhtwo/experimental/diagnostics/plot_phot.py index daccb9bd..ffedc410 100644 --- a/diffhtwo/experimental/diagnostics/plot_phot.py +++ b/diffhtwo/experimental/diagnostics/plot_phot.py @@ -632,7 +632,7 @@ def plot_app_mag_funcs( handleheight=0.5, columnspacing=0.8, handletextpad=0.1, - bbox_to_anchor=(0.5, 0.99), + bbox_to_anchor=(0.5, 1.02), fontsize=10, ) diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 8bb89332..2dc665d2 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -209,10 +209,10 @@ def _opt_update(opt_state, i): ) # clip gradients - # global_norm = pytree_norm(grads) - # tau = 1.0 - # scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) - # grads = tuple(g * scale for g in grads) + global_norm = pytree_norm(grads) + tau = 1.0 + scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + grads = tuple(g * scale for g in grads) opt_state = opt_update(i, grads, opt_state) return opt_state, (loss, loss_phot, loss_emline) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 4ae8f10e..427b92c7 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run149 -model_nickname: run149_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run149/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run151 +model_nickname: run151_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run151/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks From dd87477a0a70ff5ea86e9a75767ad32c63db4832 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 7 Jun 2026 23:18:03 -0500 Subject: [PATCH 17/57] plot weighted loss_phot and loss_emline --- diffhtwo/experimental/optimizers/Np_specphot_opt.py | 6 +++++- scripts/config_diagnostics.yaml | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 2dc665d2..4ac45256 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -199,7 +199,11 @@ def _opt_update(opt_state, i): ) w_phot = 1.0 / feniks_fitting_data.nbins w_emline = 1.0 / hizels_fitting_data.nbins - loss = w_phot * loss_phot + w_emline * loss_emline + + loss_phot = w_phot * loss_phot + loss_emline = w_emline * loss_emline + loss = loss_phot + loss_emline + grads = tuple( w_phot * gp + w_emline * ge for gp, ge in zip(grad_phot, grad_emline) ) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 427b92c7..88daca79 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run151 -model_nickname: run151_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run151/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run152 +model_nickname: run152_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run152/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks From c50520c6c0d33663e8729dd3d510fa3b44fae7bf Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Mon, 8 Jun 2026 09:47:19 -0500 Subject: [PATCH 18/57] no grad clipping --- .../experimental/data_loaders/load_feniks.py | 47 ++++++++++--------- .../experimental/data_loaders/load_hizels.py | 21 +++++++-- diffhtwo/experimental/defaults.py | 3 +- .../optimizers/Np_specphot_opt.py | 12 ++--- scripts/config_diagnostics.yaml | 6 +-- 5 files changed, 51 insertions(+), 38 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 156a47db..7e49ebe6 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -283,7 +283,7 @@ def get_feniks_data( uds_H = uds_H[mag_thresh] uds_K = uds_K[mag_thresh] - N_obj_pre_cuts = len(zout) + n_gals_pre_cuts = len(zout) # remove mags with bad data in any of the bands clean = ( @@ -310,8 +310,8 @@ def get_feniks_data( uds_H = uds_H[clean] uds_K = uds_K[clean] - N_obj_post_cuts = len(zout) - frac_cat = N_obj_post_cuts / N_obj_pre_cuts + n_gals_post_cuts = len(zout) + frac_cat = n_gals_post_cuts / n_gals_pre_cuts mags = np.vstack( ( @@ -387,7 +387,7 @@ def get_feniks_data( # mags = mags[z_mask] # zout = zout[z_mask] - nbins = 0 + n_bins = 0 ############################################################################## # prepare 2D and 1D color spaces in coarse z-bins for fitting @@ -449,7 +449,7 @@ def get_feniks_data( ) col_idx = [1, 2, 3] gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri) - nbins += bin_lo_gr_ri.size + n_bins += bin_lo_gr_ri.size # 2D (u - g, r - K) Ug_rK = namedtuple("Ug_rK", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) @@ -463,7 +463,7 @@ def get_feniks_data( ) col_idx = [0, 1, 7] ug_rK = Ug_rK(col_idx, sig_ug_rK, bin_lo_ug_rK, bin_hi_ug_rK, N_ug_rK) - nbins += bin_lo_ug_rK.size + n_bins += bin_lo_ug_rK.size # 1D (u - g | K) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) @@ -495,7 +495,7 @@ def get_feniks_data( N_1d_ug, ) ) - nbins += bin_lo_ug.size + n_bins += bin_lo_ug.size # 1D (r − i | K) ri = [] @@ -525,7 +525,7 @@ def get_feniks_data( N_1d_ri, ) ) - nbins += bin_lo_ri.size + n_bins += bin_lo_ri.size # 1D (i − z | K) iz = [] @@ -555,7 +555,7 @@ def get_feniks_data( N_1d_iz, ) ) - nbins += bin_lo_iz.size + n_bins += bin_lo_iz.size # 1D (J − H | K) jh = [] @@ -585,7 +585,7 @@ def get_feniks_data( N_1d_jh, ) ) - nbins += bin_lo_jh.size + n_bins += bin_lo_jh.size z1 = Z1(z_min, z_max, lc_data, gr_ri, ug_rK, ug, ri, iz, jh) colors.append(z1) @@ -637,7 +637,7 @@ def get_feniks_data( ) col_idx = [2, 4, 5] rz_zJ = Rz_zJ(col_idx, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ, N_rz_zJ) - nbins += bin_lo_rz_zJ.size + n_bins += bin_lo_rz_zJ.size # 1D (u - g | K) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) @@ -665,7 +665,7 @@ def get_feniks_data( N_1d_ug, ) ) - nbins += bin_lo_ug.size + n_bins += bin_lo_ug.size # 1D (r - z | K) rz = [] @@ -695,7 +695,7 @@ def get_feniks_data( N_1d_rz, ) ) - nbins += bin_lo_rz.size + n_bins += bin_lo_rz.size # 1D (J − H | K) jh = [] @@ -725,7 +725,7 @@ def get_feniks_data( N_1d_jh, ) ) - nbins += bin_lo_jh.size + n_bins += bin_lo_jh.size z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, rz, jh) colors.append(z2) @@ -778,7 +778,7 @@ def get_feniks_data( ) col_idx = [4, 5, 6] zJ_JH = zJ_JH(col_idx, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH, N_zJ_JH) - nbins += bin_lo_zJ_JH.size + n_bins += bin_lo_zJ_JH.size # 2D (u - g, g - r) Ug_gr = namedtuple("Ug_gr", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) @@ -792,7 +792,7 @@ def get_feniks_data( ) col_idx = [0, 1, 2] ug_gr = Ug_gr(col_idx, sig_ug_gr, bin_lo_ug_gr, bin_hi_ug_gr, N_ug_gr) - nbins += bin_lo_ug_gr.size + n_bins += bin_lo_ug_gr.size # 1D (u - g | K) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 4) @@ -820,7 +820,7 @@ def get_feniks_data( N_1d_ug, ) ) - nbins += bin_lo_ug.size + n_bins += bin_lo_ug.size # 1D (g - r | K) gr = [] @@ -850,7 +850,7 @@ def get_feniks_data( N_1d_gr, ) ) - nbins += bin_lo_gr.size + n_bins += bin_lo_gr.size # 1D (J − H | K) jh = [] @@ -880,7 +880,7 @@ def get_feniks_data( N_1d_jh, ) ) - nbins += bin_lo_jh.size + n_bins += bin_lo_jh.size z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr, jh) colors.append(z3) @@ -944,19 +944,19 @@ def get_feniks_data( mag_idx_u = 0 N_1d_u, sig_u, bin_lo_u, bin_hi_u = get_N_1d(megacam_uS[z_sel]) u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u) - nbins += bin_lo_u.size + n_bins += bin_lo_u.size # 1D (r) mag_idx_r = 2 N_1d_r, sig_r, bin_lo_r, bin_hi_r = get_N_1d(hsc_r[z_sel]) r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r) - nbins += bin_lo_r.size + n_bins += bin_lo_r.size # 1D (K) mag_idx_k = 7 N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) k = K(mag_idx_k, sig_k, bin_lo_k, bin_hi_k, N_1d_k) - nbins += bin_lo_k.size + n_bins += bin_lo_k.size app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, r, k)) @@ -983,7 +983,8 @@ def get_feniks_data( mags_labels, colors, app_mag_funcs, - nbins, + n_bins, + n_gals_post_cuts, filter_info, frac_cat, lh_centroids, diff --git a/diffhtwo/experimental/data_loaders/load_hizels.py b/diffhtwo/experimental/data_loaders/load_hizels.py index 949ec609..caa624d8 100644 --- a/diffhtwo/experimental/data_loaders/load_hizels.py +++ b/diffhtwo/experimental/data_loaders/load_hizels.py @@ -18,7 +18,8 @@ "z", "dz", "lc_data", - "nbins", + "n_bins", + "n_gals", ], ) DELTA_L_HALPHA = -0.4 # uncorrect HiZELS h-alpha L for dust (A_halpha = 1 mag) @@ -43,7 +44,8 @@ def get_hizels_data( hizels_halpha_lg_phi_data, hizels_halpha_z_data, hizels_halpha_delta_z_data, - hizels_halpha_nbins, + hizels_halpha_n_bins, + hizels_halpha_n_gals, ) = get_hizels_halpha(drn) line_wave_aa = [halpha_wave_aa] @@ -92,7 +94,8 @@ def get_hizels_data( z, dz, lc_data, - hizels_halpha_nbins, + hizels_halpha_n_bins, + hizels_halpha_n_gals, ) @@ -223,7 +226,7 @@ def get_hizels_halpha(drn): ) ) - hizels_halpha_nbins = ( + hizels_halpha_n_bins = ( (lg_halpha_Lbin_edges_z0p4.size - 1) + (lg_halpha_Lbin_edges_z0p84.size - 1) + (lg_halpha_Lbin_edges_z1p47.size - 1) @@ -244,6 +247,13 @@ def get_hizels_halpha(drn): halpha_N_data_z2p23, ] + hizels_halpha_n_gals = ( + (halpha_N_data_z0p4.sum()) + + (halpha_N_data_z0p84.sum()) + + (halpha_N_data_z1p47.sum()) + + (halpha_N_data_z2p23.sum()) + ) + hizels_halpha_vol_Mpc3 = [ halpha_vol_Mpc3_z0p4, halpha_vol_Mpc3_z0p84, @@ -279,5 +289,6 @@ def get_hizels_halpha(drn): hizels_halpha_lg_phi_data, hizels_halpha_z_data, hizels_halpha_delta_z_data, - hizels_halpha_nbins, + hizels_halpha_n_bins, + hizels_halpha_n_gals, ) diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index 961288e9..497ec798 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -37,7 +37,8 @@ "mags_labels", "colors", "app_mag_funcs", - "nbins", + "n_bins", + "n_gals", "filter_info", "frac_cat", "lh_centroids", diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 4ac45256..14570620 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -197,8 +197,8 @@ def _opt_update(opt_state, i): ran_key, hizels_fitting_data, ) - w_phot = 1.0 / feniks_fitting_data.nbins - w_emline = 1.0 / hizels_fitting_data.nbins + w_phot = 1.0 / 5 + w_emline = 1.0 loss_phot = w_phot * loss_phot loss_emline = w_emline * loss_emline @@ -213,10 +213,10 @@ def _opt_update(opt_state, i): ) # clip gradients - global_norm = pytree_norm(grads) - tau = 1.0 - scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) - grads = tuple(g * scale for g in grads) + # global_norm = pytree_norm(grads) + # tau = 1.0 + # scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + # grads = tuple(g * scale for g in grads) opt_state = opt_update(i, grads, opt_state) return opt_state, (loss, loss_phot, loss_emline) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 88daca79..a0781c42 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run152 -model_nickname: run152_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run152/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run153 +model_nickname: run153_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run153/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks From 9c7c3d16f20eec63135e94c37a5fedd5d3371f16 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Mon, 8 Jun 2026 11:17:30 -0500 Subject: [PATCH 19/57] load_sdss --- .../experimental/data_loaders/load_sdss.py | 23 +++++++++++++++---- .../optimizers/Np_specphot_opt.py | 8 +++---- scripts/config_diagnostics.yaml | 12 +++++----- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 85db9efc..7f04b439 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -11,17 +11,14 @@ SDSS_MAGR_THRESH, SDSS_Z_MAX, SDSS_Z_MIN, - Dataset, FilterInfo, ) from ..latin_hypercube import latin_hypercube as lh -Sdss = namedtuple("Sdss", Dataset._fields) - LH_N_CENTROIDS = 20_000 LH_SIG = 3.5 LH_D_MAG = 0.1 -LH_D_Z = 0.01 +LH_D_Z = 0.05 def apply_ra_dec_cut(sdss, ra_min=120, ra_max=240, dec_min=0, dec_max=60): @@ -200,3 +197,21 @@ def get_sdss_data( "sdss_z", ], ) + +Sdss = namedtuple( + "Sdss", + [ + "dataset", + "dataset_dim_labels", + "mags", + "mags_labels", + "filter_info", + "frac_cat", + "lh_centroids", + "d_centroids", + "N_data", + "lh_dmag", + "lh_dz", + "data_sky_area_degsq", + ], +) diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 14570620..27d51cca 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -72,10 +72,10 @@ def _opt_update(opt_state, i): ) # clip gradients - global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) - tau = 1.0 - scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) - grads = tuple(g * scale for g in grads) + # global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) + # tau = 1.0 + # scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + # grads = tuple(g * scale for g in grads) opt_state = opt_update(i, grads, opt_state) return opt_state, loss diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index a0781c42..df0f8ec9 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run153 -model_nickname: run153_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run153/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run150 +model_nickname: run150_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run150/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,9 +11,9 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: False -plot_feniks: True -plot_hizels: True +plot_sdss: True +plot_feniks: False +plot_hizels: False plots: num_halos : 3000 From a293ba2934e86b830f4d85ec32cb4d5ba08d964a Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Mon, 8 Jun 2026 23:58:43 -0500 Subject: [PATCH 20/57] load sdss 2D --- diffhtwo/experimental/data_loaders/N_utils.py | 59 ++++ .../experimental/data_loaders/load_feniks.py | 59 +--- .../experimental/data_loaders/load_sdss.py | 209 +++++++++++-- diffhtwo/experimental/defaults.py | 2 - .../experimental/diagnostics/plot_phot.py | 291 +++++++++++++++--- scripts/config_diagnostics.yaml | 12 +- scripts/config_sdss.yaml | 7 +- scripts/fit_feniks.py | 32 +- scripts/fit_sdss.py | 32 +- scripts/generate_diagnostic_plots.py | 23 ++ 10 files changed, 552 insertions(+), 174 deletions(-) create mode 100644 diffhtwo/experimental/data_loaders/N_utils.py diff --git a/diffhtwo/experimental/data_loaders/N_utils.py b/diffhtwo/experimental/data_loaders/N_utils.py new file mode 100644 index 00000000..e6dd4241 --- /dev/null +++ b/diffhtwo/experimental/data_loaders/N_utils.py @@ -0,0 +1,59 @@ +import jax.numpy as jnp +import numpy as np +from diffsky import diffndhist_lomem + + +def get_N_1d(dim1, dim1_bin_edges=None, dmag=0.2, sig_scale=0.5): + dataset = dim1.reshape(dim1.size, 1) + if dim1_bin_edges is None: + dim1_bin_edges = np.arange(dim1.min(), dim1.max(), dmag) + + bin_lo = dim1_bin_edges[:-1].reshape(dim1_bin_edges[:-1].size, 1) + bin_hi = dim1_bin_edges[1:].reshape(dim1_bin_edges[1:].size, 1) + + sig = jnp.zeros_like(bin_lo) + (dmag * sig_scale) + + N_1d = diffndhist_lomem.tw_ndhist( + dataset, + sig, + bin_lo, + bin_hi, + ) + + return ( + N_1d, + sig, + bin_lo, + bin_hi, + ) + + +def get_N_2d(dim1, dim2, sig_scale=0.5): + dataset = np.vstack((dim1, dim2)).T + + dim1_bin_edges = np.linspace(dim1.min(), dim1.max(), 11) + dim2_bin_edges = np.linspace(dim2.min(), dim2.max(), 11) + + dim1_lo = dim1_bin_edges[:-1] + dim2_lo = dim2_bin_edges[:-1] + bin_lo = np.meshgrid(dim1_lo, dim2_lo, indexing="ij") + bin_lo = np.array(bin_lo).T.reshape(-1, 2) + + dim1_hi = dim1_bin_edges[1:] + dim2_hi = dim2_bin_edges[1:] + bin_hi = np.meshgrid(dim1_hi, dim2_hi, indexing="ij") + bin_hi = np.array(bin_hi).T.reshape(-1, 2) + + sig1 = np.diff(dim1_bin_edges) * sig_scale + sig2 = np.diff(dim2_bin_edges) * sig_scale + sig = np.meshgrid(sig1, sig2, indexing="ij") + sig = np.array(sig).T.reshape(-1, 2) + + N_2d = diffndhist_lomem.tw_ndhist( + dataset, + sig, + bin_lo, + bin_hi, + ) + + return N_2d, sig, bin_lo, bin_hi diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 7e49ebe6..06fb9e7b 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -20,6 +20,7 @@ from ..latin_hypercube import latin_hypercube as lh from ..lightcone_generators import generate_lc_data from ..utils import load_feniks_tcurve +from .N_utils import get_N_1d, get_N_2d BASE_PATH = Path(__file__).resolve().parent.parent FENIKS_FILTERS_PATH = BASE_PATH / "data" / "feniks_filters" @@ -77,62 +78,6 @@ def get_mag_ab(phot_table, col_name, ZP=25): return mag_ab -def get_N_1d(dim1, dim1_bin_edges=None, dmag=0.2, sig_scale=0.5): - dataset = dim1.reshape(dim1.size, 1) - if dim1_bin_edges is None: - dim1_bin_edges = np.arange(dim1.min(), dim1.max(), dmag) - - bin_lo = dim1_bin_edges[:-1].reshape(dim1_bin_edges[:-1].size, 1) - bin_hi = dim1_bin_edges[1:].reshape(dim1_bin_edges[1:].size, 1) - - sig = jnp.zeros_like(bin_lo) + (dmag * sig_scale) - - N_1d = diffndhist_lomem.tw_ndhist( - dataset, - sig, - bin_lo, - bin_hi, - ) - - return ( - N_1d, - sig, - bin_lo, - bin_hi, - ) - - -def get_N_2d(dim1, dim2, sig_scale=0.5): - dataset = np.vstack((dim1, dim2)).T - - dim1_bin_edges = np.linspace(dim1.min(), dim1.max(), 11) - dim2_bin_edges = np.linspace(dim2.min(), dim2.max(), 11) - - dim1_lo = dim1_bin_edges[:-1] - dim2_lo = dim2_bin_edges[:-1] - bin_lo = np.meshgrid(dim1_lo, dim2_lo, indexing="ij") - bin_lo = np.array(bin_lo).T.reshape(-1, 2) - - dim1_hi = dim1_bin_edges[1:] - dim2_hi = dim2_bin_edges[1:] - bin_hi = np.meshgrid(dim1_hi, dim2_hi, indexing="ij") - bin_hi = np.array(bin_hi).T.reshape(-1, 2) - - sig1 = np.diff(dim1_bin_edges) * sig_scale - sig2 = np.diff(dim2_bin_edges) * sig_scale - sig = np.meshgrid(sig1, sig2, indexing="ij") - sig = np.array(sig).T.reshape(-1, 2) - - N_2d = diffndhist_lomem.tw_ndhist( - dataset, - sig, - bin_lo, - bin_hi, - ) - - return N_2d, sig, bin_lo, bin_hi - - def refresh_lh_centroids(DATASET, lh_d_mag): lh_centroids, d_centroids = get_lh_centroids(DATASET.dataset, lh_d_mag) @@ -983,8 +928,6 @@ def get_feniks_data( mags_labels, colors, app_mag_funcs, - n_bins, - n_gals_post_cuts, filter_info, frac_cat, lh_centroids, diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 7f04b439..6b4b50b1 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -11,9 +11,14 @@ SDSS_MAGR_THRESH, SDSS_Z_MAX, SDSS_Z_MIN, + Dataset, FilterInfo, ) from ..latin_hypercube import latin_hypercube as lh +from ..lightcone_generators import generate_lc_data +from .N_utils import get_N_1d, get_N_2d + +Sdss = namedtuple("Sdss", Dataset._fields) LH_N_CENTROIDS = 20_000 LH_SIG = 3.5 @@ -104,15 +109,21 @@ def get_sdss_data( drn, ran_key, ssp_data, + num_halos_coarse_zbins=150, + num_halos_fine_zbins=250, + lgmp_min=10.0, + lgmp_max=15.0, + lc_sky_area_degsq=100, + n_z_phot_table=30, ): sdss, frac_cat = load_sdss_cuts_applied(drn) sdss_mag_thresh = SdssFilters( - sdss_u=None, - sdss_g=None, + sdss_u=30.0, + sdss_g=30.0, sdss_r=SDSS_MAGR_THRESH, - sdss_i=None, - sdss_z=None, + sdss_i=30.0, + sdss_z=30.0, ) sdss_in_lh = SdssFilters( sdss_u=True, @@ -140,6 +151,7 @@ def get_sdss_data( # derive colors from mags sdss_ug = sdss_u - sdss_g sdss_gr = sdss_g - sdss_r + sdss_ur = sdss_u - sdss_r sdss_ri = sdss_r - sdss_i sdss_iz = sdss_i - sdss_z @@ -158,6 +170,175 @@ def get_sdss_data( ] mag_labels = [r"$u$", r"$g$", r"$r$", r"$i$", r"$z$"] + ############################################################################## + # prepare 2D and 1D color spaces in coarse z-bins for fitting + zbins = np.array( + [ + [0.02, 0.1], + [0.1, 0.2], + ] + ) + ############################################################################## + Colors = namedtuple( + "Colors", + [ + "z_min", + "z_max", + "lc_data", + "ur_ri", + "gr_ri", + "ur", + ], + ) + # 2D (u - r, r - i) + Ur_ri = namedtuple("Ur_ri", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + + # 2D (g - r, r - i) + Gr_ri = namedtuple("Gr_ri", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + + # 1D (u - r | r) + Ur_condr = namedtuple( + "Ur_condr", + ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + ) + + colors = [] + for zbin in range(0, len(zbins)): + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos_coarse_zbins, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data = generate_lc_data(*lc_args) + + z_sel = (sdss_redshift > z_min) & (sdss_redshift <= z_max) + + # 2D (u - r, r - i) + N_ur_ri, sig_ur_ri, bin_lo_ur_ri, bin_hi_ur_ri = get_N_2d( + sdss_ur[z_sel], sdss_ri[z_sel] + ) + col_idx = [0, 2, 3] + ur_ri = Ur_ri(col_idx, sig_ur_ri, bin_lo_ur_ri, bin_hi_ur_ri, N_ur_ri) + + # 2D (g - r, r - i) + N_gr_ri, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri = get_N_2d( + sdss_gr[z_sel], sdss_ri[z_sel] + ) + col_idx = [1, 2, 3] + gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri) + + # 1D (u - r | r) + rbins = np.arange(sdss_r[z_sel].min(), sdss_r[z_sel].max(), 2) + + col_idx = [0, 2] + cond_idx = 2 + ur = [] + for r in range(len(rbins) - 1): + r_sel = (sdss_r[z_sel] > rbins[r]) & (sdss_r[z_sel] <= rbins[r + 1]) + N_1d_ur, sig_ur, bin_lo_ur, bin_hi_ur = get_N_1d(sdss_ur[z_sel][r_sel]) + ur.append( + Ur_condr( + col_idx, + cond_idx, + rbins[r], + rbins[r + 1], + sig_ur, + bin_lo_ur, + bin_hi_ur, + N_1d_ur, + ) + ) + + colors.append( + Colors( + z_min, + z_max, + lc_data, + ur_ri, + gr_ri, + ur, + ) + ) + + ############################################################################## + ############################################################################## + # prepare 1D app mag funcs in finer z-bins for fitting + fine_zbins = np.array( + [ + [0.02, 0.06], + [0.06, 0.1], + [0.1, 0.14], + [0.14, 0.18], + [0.18, 0.2], + ] + ) + ############################################################################## + AppMagFuncs = namedtuple( + "AppMagFuncs", + ["z_min", "z_max", "lc_data", "u", "r"], + ) + U = namedtuple( + "U", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ) + R = namedtuple( + "R", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ) + + app_mag_funcs = [] + for zbin in range(0, len(fine_zbins)): + z_min = fine_zbins[zbin][0] + z_max = fine_zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos_fine_zbins, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data = generate_lc_data(*lc_args) + + z_sel = (sdss_redshift > z_min) & (sdss_redshift <= z_max) + + # 1D (u) + mag_idx_u = 0 + N_1d_u, sig_u, bin_lo_u, bin_hi_u = get_N_1d(sdss_u[z_sel]) + u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u) + + # 1D (r) + mag_idx_r = 2 + N_1d_r, sig_r, bin_lo_r, bin_hi_r = get_N_1d(sdss_r[z_sel]) + r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r) + + app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, r)) + + ############################################################################## + lh_centroids, d_centroids = get_lh_centroids(dataset) # run initial diffndhist_lomem with fixed dmag @@ -176,6 +357,8 @@ def get_sdss_data( dataset_dim_labels, mags, mag_labels, + colors, + app_mag_funcs, filter_info, frac_cat, lh_centroids, @@ -197,21 +380,3 @@ def get_sdss_data( "sdss_z", ], ) - -Sdss = namedtuple( - "Sdss", - [ - "dataset", - "dataset_dim_labels", - "mags", - "mags_labels", - "filter_info", - "frac_cat", - "lh_centroids", - "d_centroids", - "N_data", - "lh_dmag", - "lh_dz", - "data_sky_area_degsq", - ], -) diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index 497ec798..4b78e6b2 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -37,8 +37,6 @@ "mags_labels", "colors", "app_mag_funcs", - "n_bins", - "n_gals", "filter_info", "frac_cat", "lh_centroids", diff --git a/diffhtwo/experimental/diagnostics/plot_phot.py b/diffhtwo/experimental/diagnostics/plot_phot.py index ffedc410..7dd5f259 100644 --- a/diffhtwo/experimental/diagnostics/plot_phot.py +++ b/diffhtwo/experimental/diagnostics/plot_phot.py @@ -568,11 +568,201 @@ def plot_n_mags( plt.close() +# def plot_app_mag_funcs( +# dataset, +# data_label, +# param_collection, +# ran_key, +# ssp_data, +# savedir, +# lgmp_min=10.0, +# lgmp_max=15.0, +# num_halos=5000, +# lc_sky_area_degsq=1000, +# n_z_phot_table=30, +# cosmo_params=DEFAULT_COSMOLOGY, +# fb=FB, +# plt_show=True, +# ): +# dataset_mags = dataset.mags +# data_sky_area_degsq = dataset.data_sky_area_degsq + +# feniks_zbins = np.array( +# [ +# [0.2, 0.5], +# [0.5, 0.8], +# [0.8, 1.2], +# [1.2, 1.6], +# [1.6, 2.0], +# ] +# ) +# labels_z = [" z = " + str(np.round(np.median(z), 2)) for z in feniks_zbins] + +# colors_z = [ +# "#001219", # deep navy +# "#0a7a80", # teal +# "#80cca8", # mint +# "#c8b44a", # warm gold +# "#c87820", # amber +# ] +# fig_width = 7.1 +# fig_height = 5 + +# fontsize = 10 +# labelsize = 10 +# alpha = 0.75 +# s = 10 + +# fig, ax = plt.subplots( +# 2, 4, figsize=(fig_width, fig_height), constrained_layout=True +# ) +# fig.get_layout_engine().set(rect=[0, 0, 1, 0.95]) + +# handles = [ +# mlines.Line2D([], [], color=c, linewidth=6, solid_capstyle="butt", label=label) +# for c, label in zip(colors_z, labels_z) +# ] + +# fig.legend( +# handles=handles, +# loc="upper center", +# ncol=8, +# frameon=False, +# handlelength=3, +# handleheight=0.5, +# columnspacing=0.8, +# handletextpad=0.1, +# bbox_to_anchor=(0.5, 1.02), +# fontsize=10, +# ) + +# xlim = [] +# for zbin in range(0, len(feniks_zbins)): +# z_min = feniks_zbins[zbin][0] +# z_max = feniks_zbins[zbin][1] + +# z_min, z_max = np.round(z_min, 2), np.round(z_max, 2) +# z_mask = (dataset_mags[:, -1] > z_min) & (dataset_mags[:, -1] < z_max) +# dataset_mags_z = dataset_mags[z_mask] +# data_vol_mpc3 = zbin_volume(data_sky_area_degsq, zlow=z_min, zhigh=z_max).value + +# z_phot_table = 10 ** jnp.linspace( +# np.log10(z_min), np.log10(z_max), n_z_phot_table +# ) +# lc_data = generate_lc_data( +# ran_key, +# num_halos, +# z_min, +# z_max, +# lgmp_min, +# lgmp_max, +# lc_sky_area_degsq, +# ssp_data, +# dataset.filter_info.tcurves, +# z_phot_table, +# ) +# obs_mags, weights, phot_kern_results = mag_kern( +# ran_key, +# param_collection, +# lc_data, +# dataset.filter_info.mag_thresh, +# dataset.frac_cat, +# ) + +# n_bands = obs_mags.shape[1] + +# row = 0 +# col = 0 +# dmag = 0.5 +# for i in range(0, n_bands): +# bins = np.arange( +# dataset_mags_z[:, i].min(), +# dataset_mags_z[:, i].max() + dmag, +# dmag, +# ) +# if zbin == 0: +# xlim.append([bins.min() - 0.5, bins.max() + 1]) + +# bin_centers = (bins[1:] + bins[:-1]) / 2 +# ax[row, col].set_xlim(bins[0], bins[-1] + 0.2) +# # ax[0, i].set_xticks([]) + +# n_data, bin_edges = np.histogram( +# dataset_mags_z[:, i], +# weights=np.ones_like(dataset_mags_z[:, i]) * (1 / data_vol_mpc3), +# bins=bins, +# ) +# with warnings.catch_warnings(): +# warnings.filterwarnings("ignore", category=RuntimeWarning) +# ax[row, col].scatter( +# bin_centers, np.log10(n_data), c=colors_z[zbin], alpha=alpha, s=s +# ) + +# ( +# n_diffsky, +# _, +# ) = np.histogram( +# obs_mags[:, i], +# weights=weights * (1 / lc_data.lc_tot_vol_mpc3), +# bins=bins, +# ) +# with warnings.catch_warnings(): +# warnings.filterwarnings("ignore", category=RuntimeWarning) +# ax[row, col].plot( +# bin_centers, np.log10(n_diffsky), c=colors_z[zbin], alpha=alpha +# ) + +# ax[row, col].set_xticks(np.arange(15, 30, 2)) +# ax[row, col].minorticks_on() +# ax[row, col].tick_params( +# which="major", +# direction="in", +# top=True, +# right=True, +# length=6, +# width=1, +# labelsize=labelsize, +# ) +# ax[row, col].tick_params( +# which="minor", +# direction="in", +# top=True, +# right=True, +# length=3, +# width=0.8, +# labelsize=labelsize, +# ) + +# ax[row, col].set_ylim(-6.9, -2.5) +# ax[row, col].set_xlim(xlim[i]) +# ax[row, col].set_xlabel(dataset.mags_labels[i]) + +# if col != 0: +# ax[row, col].set_yticklabels([]) + +# if col == 3: +# row += 1 +# col = 0 +# else: +# col += 1 + +# ax[0, 0].set_ylabel("log$_{10}$ (n [Mpc$^{-3}$])", fontsize=fontsize) +# ax[1, 0].set_ylabel("log$_{10}$ (n [Mpc$^{-3}$])", fontsize=fontsize) +# fig.savefig( +# savedir + "/" + data_label + "_app_mag_funcs.png", +# dpi=300, +# ) +# if plt_show: +# plt.show() +# plt.close() + + def plot_app_mag_funcs( dataset, data_label, param_collection, ran_key, + zbins, ssp_data, savedir, lgmp_min=10.0, @@ -580,6 +770,7 @@ def plot_app_mag_funcs( num_halos=5000, lc_sky_area_degsq=1000, n_z_phot_table=30, + dmag=0.5, cosmo_params=DEFAULT_COSMOLOGY, fb=FB, plt_show=True, @@ -587,59 +778,64 @@ def plot_app_mag_funcs( dataset_mags = dataset.mags data_sky_area_degsq = dataset.data_sky_area_degsq - feniks_zbins = np.array( - [ - [0.2, 0.5], - [0.5, 0.8], - [0.8, 1.2], - [1.2, 1.6], - [1.6, 2.0], - ] - ) - labels_z = [" z = " + str(np.round(np.median(z), 2)) for z in feniks_zbins] + zbins = np.array(zbins) + labels_z = [" z = " + str(np.round(np.median(z), 2)) for z in zbins] colors_z = [ - "#001219", # deep navy - "#0a7a80", # teal - "#80cca8", # mint - "#c8b44a", # warm gold - "#c87820", # amber + "#001219", + "#0a7a80", + "#80cca8", + "#c8b44a", + "#c87820", ] - fig_width = 7.1 - fig_height = 5 + + n_bands = dataset_mags.shape[1] - 1 + if n_bands <= 5: + nrows, ncols = 1, n_bands + else: + ncols = 4 + nrows = int(np.ceil(n_bands / ncols)) + + fig_width = 7.1 * ncols / 4 + fig_height = 5 * nrows / 2 fontsize = 10 labelsize = 10 alpha = 0.75 s = 10 - fig, ax = plt.subplots( - 2, 4, figsize=(fig_width, fig_height), constrained_layout=True + fig, axes = plt.subplots( + nrows, ncols, figsize=(fig_width, fig_height), constrained_layout=True ) - fig.get_layout_engine().set(rect=[0, 0, 1, 0.95]) + if nrows == 1: + axes = axes[np.newaxis, :] + if ncols == 1: + axes = axes[:, np.newaxis] handles = [ mlines.Line2D([], [], color=c, linewidth=6, solid_capstyle="butt", label=label) for c, label in zip(colors_z, labels_z) ] + fig.get_layout_engine().set(rect=[0, 0, 1, 0.92]) + fig.legend( handles=handles, loc="upper center", - ncol=8, + ncol=len(zbins), frameon=False, handlelength=3, handleheight=0.5, columnspacing=0.8, handletextpad=0.1, - bbox_to_anchor=(0.5, 1.02), + bbox_to_anchor=(0.5, 1.0), fontsize=10, ) xlim = [] - for zbin in range(0, len(feniks_zbins)): - z_min = feniks_zbins[zbin][0] - z_max = feniks_zbins[zbin][1] + for zbin in range(len(zbins)): + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] z_min, z_max = np.round(z_min, 2), np.round(z_max, 2) z_mask = (dataset_mags[:, -1] > z_min) & (dataset_mags[:, -1] < z_max) @@ -669,12 +865,9 @@ def plot_app_mag_funcs( dataset.frac_cat, ) - n_bands = obs_mags.shape[1] - row = 0 col = 0 - dmag = 0.5 - for i in range(0, n_bands): + for i in range(n_bands): bins = np.arange( dataset_mags_z[:, i].min(), dataset_mags_z[:, i].max() + dmag, @@ -684,8 +877,7 @@ def plot_app_mag_funcs( xlim.append([bins.min() - 0.5, bins.max() + 1]) bin_centers = (bins[1:] + bins[:-1]) / 2 - ax[row, col].set_xlim(bins[0], bins[-1] + 0.2) - # ax[0, i].set_xticks([]) + axes[row, col].set_xlim(bins[0], bins[-1] + 0.2) n_data, bin_edges = np.histogram( dataset_mags_z[:, i], @@ -694,27 +886,24 @@ def plot_app_mag_funcs( ) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) - ax[row, col].scatter( + axes[row, col].scatter( bin_centers, np.log10(n_data), c=colors_z[zbin], alpha=alpha, s=s ) - ( - n_diffsky, - _, - ) = np.histogram( + n_diffsky, _ = np.histogram( obs_mags[:, i], weights=weights * (1 / lc_data.lc_tot_vol_mpc3), bins=bins, ) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) - ax[row, col].plot( + axes[row, col].plot( bin_centers, np.log10(n_diffsky), c=colors_z[zbin], alpha=alpha ) - ax[row, col].set_xticks(np.arange(15, 30, 2)) - ax[row, col].minorticks_on() - ax[row, col].tick_params( + axes[row, col].set_xticks(np.arange(15, 30, 2)) + axes[row, col].minorticks_on() + axes[row, col].tick_params( which="major", direction="in", top=True, @@ -723,7 +912,7 @@ def plot_app_mag_funcs( width=1, labelsize=labelsize, ) - ax[row, col].tick_params( + axes[row, col].tick_params( which="minor", direction="in", top=True, @@ -733,24 +922,28 @@ def plot_app_mag_funcs( labelsize=labelsize, ) - ax[row, col].set_ylim(-6.9, -2.5) - ax[row, col].set_xlim(xlim[i]) - ax[row, col].set_xlabel(dataset.mags_labels[i]) + axes[row, col].set_ylim(-6.9, -2.5) + axes[row, col].set_xlim(xlim[i]) + axes[row, col].set_xlabel(dataset.mags_labels[i]) if col != 0: - ax[row, col].set_yticklabels([]) + axes[row, col].set_yticklabels([]) - if col == 3: + if col == ncols - 1: row += 1 col = 0 else: col += 1 - ax[0, 0].set_ylabel("log$_{10}$ (n [Mpc$^{-3}$])", fontsize=fontsize) - ax[1, 0].set_ylabel("log$_{10}$ (n [Mpc$^{-3}$])", fontsize=fontsize) + for idx in range(n_bands, nrows * ncols): + r, c = divmod(idx, ncols) + axes[r, c].set_visible(False) + + for r in range(nrows): + axes[r, 0].set_ylabel("log$_{10}$ (n [Mpc$^{-3}$])", fontsize=fontsize) + fig.savefig( savedir + "/" + data_label + "_app_mag_funcs.png", - # bbox_extra_artists=(leg,), dpi=300, ) if plt_show: diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index df0f8ec9..e67952f3 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run150 -model_nickname: run150_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run150/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run154 +model_nickname: run154_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run154/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,9 +11,9 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: True -plot_feniks: False -plot_hizels: False +plot_sdss: False +plot_feniks: True +plot_hizels: True plots: num_halos : 3000 diff --git a/scripts/config_sdss.yaml b/scripts/config_sdss.yaml index 645461fc..a44cb376 100644 --- a/scripts/config_sdss.yaml +++ b/scripts/config_sdss.yaml @@ -6,12 +6,15 @@ start_fit_type: "all" fit_runid: "runtest" fit_type: "all" +sdss: + num_halos_coarse_zbins: 100 + num_halos_fine_zbins: 200 + epoch: n_it: 1 - n_steps: 25 + n_steps: 2 step_size: 0.1 N_centroids: 2000 - num_halos: 300 defaults: diffstarpop: True diff --git a/scripts/fit_feniks.py b/scripts/fit_feniks.py index f8b0750e..9f096089 100644 --- a/scripts/fit_feniks.py +++ b/scripts/fit_feniks.py @@ -48,21 +48,6 @@ emline_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) emline_wave_table = jnp.array([emline_wave_aa]) - # load feniks data - ran_key = jran.key(0) - feniks = load_feniks.get_feniks_data( - feniks_drn, - ran_key, - ssp_data, - num_halos_coarse_zbins=cfg["feniks"]["num_halos_coarse_zbins"], - num_halos_fine_zbins=cfg["feniks"]["num_halos_fine_zbins"], - ) - remove = {"dataset_dim_labels", "mags_labels"} - FeniksFitting = namedtuple("Feniks", [f for f in feniks._fields if f not in remove]) - feniks_fitting_data = FeniksFitting( - **{f: getattr(feniks, f) for f in FeniksFitting._fields} - ) - # start fit dirs fit_start_drn = cfg["base_path"] + "/fits/" + cfg["start_runid"] + "/" param_collection_fit = lc_mock.load_diffsky_param_collection_merging( @@ -105,9 +90,25 @@ initial_pts = [] start = time.time() + ran_key = jran.key(0) for epoch in range(0, cfg["epoch"]["n_it"]): print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') + feniks = load_feniks.get_feniks_data( + feniks_drn, + ran_key, + ssp_data, + num_halos_coarse_zbins=cfg["feniks"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["feniks"]["num_halos_fine_zbins"], + ) + remove = {"dataset_dim_labels", "mags_labels"} + FeniksFitting = namedtuple( + "Feniks", [f for f in feniks._fields if f not in remove] + ) + feniks_fitting_data = FeniksFitting( + **{f: getattr(feniks, f) for f in FeniksFitting._fields} + ) + loss_hist, u_theta_fit = Np_specphot_opt.fit_N_phot_2d( u_theta_fit, trainable_params, @@ -116,6 +117,7 @@ n_steps=cfg["epoch"]["n_steps"], step_size=cfg["epoch"]["step_size"], ) + jax.clear_caches() param_collection_fit = pu.get_param_collection_from_u_theta(u_theta_fit) diff --git a/scripts/fit_sdss.py b/scripts/fit_sdss.py index f87db9ea..7dd28dc2 100644 --- a/scripts/fit_sdss.py +++ b/scripts/fit_sdss.py @@ -1,6 +1,7 @@ import argparse import os import time +from collections import namedtuple from datetime import datetime import jax @@ -21,8 +22,6 @@ from diffhtwo.experimental import param_utils as pu from diffhtwo.experimental.data_loaders import load_sdss -from diffhtwo.experimental.defaults import SDSS_Z_MAX, SDSS_Z_MIN -from diffhtwo.experimental.latin_hypercube import lh_utils as lhu from diffhtwo.experimental.optimizers import Np_specphot_opt DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ @@ -49,10 +48,6 @@ emline_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) emline_wave_table = jnp.array([emline_wave_aa]) - # load sdss data - ran_key = jran.key(0) - SDSS = load_sdss.get_sdss_data(sdss_drn, ran_key, ssp_data) - # start fit dirs fit_start_drn = cfg["base_path"] + "/fits/" + cfg["start_runid"] + "/" param_collection_fit = lc_mock.load_diffsky_param_collection_merging( @@ -93,32 +88,29 @@ os.system(f"cp {args.config} {fit_diagnostics_save_drn}") - sdss_z_min = [SDSS_Z_MIN, 0.08, 0.14] - sdss_z_max = [0.08, 0.14, SDSS_Z_MAX] - + ran_key = jran.key(0) initial_pts = [] start = time.time() for epoch in range(0, cfg["epoch"]["n_it"]): print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') - sdss = load_sdss.refresh_lh_centroids(SDSS) - # SDSS - sdss_meta_data, sdss_fitting_data = lhu.get_zbins_lh_lc( + sdss = load_sdss.get_sdss_data( + sdss_drn, ran_key, - SDSS, - sdss_z_min, - sdss_z_max, ssp_data, - cfg["epoch"]["N_centroids"], - lh_N_z_savedir=fit_diagnostics_save_drn + "/lh_N_z", - num_halos=cfg["epoch"]["num_halos"], + num_halos_coarse_zbins=cfg["sdss"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["sdss"]["num_halos_fine_zbins"], + ) + remove = {"dataset_dim_labels", "mags_labels"} + SdssFitting = namedtuple("Sdss", [s for s in sdss._fields if s not in remove]) + sdss_fitting_data = SdssFitting( + **{s: getattr(sdss, s) for s in SdssFitting._fields} ) - loss_hist, u_theta_fit = Np_specphot_opt.fit_N_multi_z( + loss_hist, u_theta_fit = Np_specphot_opt.fit_N_phot_2d( u_theta_fit, trainable_params, ran_key, - sdss_meta_data, sdss_fitting_data, n_steps=cfg["epoch"]["n_steps"], step_size=cfg["epoch"]["step_size"], diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index 35c6ff1e..3c72f77a 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -206,12 +206,22 @@ feniks = load_feniks.get_feniks_data(feniks_drn, ran_key, ssp_data) if cfg["plots"]["plot_app_mag_funcs"]: + feniks_zbins = np.array( + [ + [0.2, 0.5], + [0.5, 0.7], + [0.7, 1.0], + [1.0, 1.5], + [1.5, 2.0], + ] + ) print("Generating FENIKS app mag funcs plot...") plot_app_mag_funcs( feniks, feniks_label, param_collection_fit, ran_key, + feniks_zbins, ssp_data, fit_diagnostics_save_drn, num_halos=num_halos, @@ -441,6 +451,19 @@ [0.18, 0.2], ] ) + if cfg["plots"]["plot_app_mag_funcs"]: + print("Generating SDSS app mag funcs plot...") + plot_app_mag_funcs( + sdss, + sdss_label, + param_collection_fit, + ran_key, + sdss_zbins, + ssp_data, + fit_diagnostics_save_drn, + num_halos=num_halos, + plt_show=False, + ) if cfg["plots"]["plot_exsitu_frac"]: print("Generating SDSS ex-situ frac plot...") From 9659943e225d308c3a29622fb2a6dc810f885e61 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Tue, 9 Jun 2026 11:37:21 -0500 Subject: [PATCH 21/57] fit_sdss_feniks --- .../optimizers/Np_specphot_opt.py | 51 +++++ scripts/config_diagnostics.yaml | 12 +- scripts/config_sdss_feniks.py | 31 +++ scripts/config_sdss_feniks.yaml | 15 +- scripts/fit_sdss_feniks.py | 118 +++++------ scripts/fits_sdss_feniks.py | 192 ++++++++++++++++++ 6 files changed, 350 insertions(+), 69 deletions(-) create mode 100644 scripts/config_sdss_feniks.py create mode 100644 scripts/fits_sdss_feniks.py diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 27d51cca..21997606 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -172,6 +172,57 @@ def pytree_norm(grads): return jnp.sqrt(sum(jnp.sum(g**2) for g in leaves)) +@partial(jjit, static_argnames=["n_steps", "step_size"]) +def fit_sdss_feniks( + u_theta_init, + trainable, + ran_key, + sdss_fitting_data, + feniks_fitting_data, + n_steps=2, + step_size=1e-2, + w_sdss=1.0, + w_feniks=1.0, +): + opt_init, opt_update, get_params = jax_opt.adam(step_size) + opt_state = opt_init(u_theta_init) + + def _opt_update(opt_state, i): + u_theta = get_params(opt_state) + loss_sdss, grad_sdss = _loss_and_grad_phot_kern_2d_multiz( + u_theta, + ran_key, + sdss_fitting_data, + ) + + loss_feniks, grad_feniks = _loss_and_grad_phot_kern_2d_multiz( + u_theta, + ran_key, + feniks_fitting_data, + ) + + loss_sdss = w_sdss * loss_sdss + loss_feniks = w_feniks * loss_feniks + loss = loss_sdss + loss_feniks + + grads = tuple( + w_sdss * gs + w_feniks * gf for gs, gf in zip(grad_sdss, grad_feniks) + ) + # set grads for untrainable params to 0.0 + grads = tuple( + jnp.where(train, grad, 0.0) for grad, train in zip(grads, trainable) + ) + + opt_state = opt_update(i, grads, opt_state) + return opt_state, (loss, loss_sdss, loss_feniks) + + opt_state, (loss_hist, loss_sdss_hist, loss_feniks_hist) = lax.scan( + _opt_update, opt_state, jnp.arange(n_steps) + ) + u_theta_fit = get_params(opt_state) + return loss_hist, loss_sdss_hist, loss_feniks_hist, u_theta_fit + + @partial(jjit, static_argnames=["n_steps", "step_size"]) def fit_feniks_hizels( u_theta_init, diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index e67952f3..aed285d2 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run154 -model_nickname: run154_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run154/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run156 +model_nickname: run156_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run156/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,9 +11,9 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: False -plot_feniks: True -plot_hizels: True +plot_sdss: True +plot_feniks: False +plot_hizels: False plots: num_halos : 3000 diff --git a/scripts/config_sdss_feniks.py b/scripts/config_sdss_feniks.py new file mode 100644 index 00000000..9b90f0df --- /dev/null +++ b/scripts/config_sdss_feniks.py @@ -0,0 +1,31 @@ +base_path: "/Users/kumail/diffdir" + +start_runid: "run90" +start_fit_type: "all" + +fit_runid: "runtest" +fit_type: "all" + +sdss: + N_centroids: 100 + num_halos: 100 + +feniks: + lh_d_mag: 0.4 + N_centroids: 100 + num_halos_coarse_zbins: 100 + num_halos_fine_zbins: 100 + +hizels: + num_halos: 100 + +epoch: + n_it: 1 + n_steps: 2 + step_size: 0.1 + +defaults: + diffstarpop: True + spspop: True + ssperr: True + merging: True diff --git a/scripts/config_sdss_feniks.yaml b/scripts/config_sdss_feniks.yaml index ab7f2b1d..64f3a54d 100644 --- a/scripts/config_sdss_feniks.yaml +++ b/scripts/config_sdss_feniks.yaml @@ -6,18 +6,21 @@ start_fit_type: "all" fit_runid: "runtest" fit_type: "all" -feniks: - lh_d_mag: 0.6 - N_centroids: 2000 - sdss: - N_centroids: 2000 + N_centroids: 100 + num_halos_coarse_zbins: 100 + num_halos_fine_zbins: 100 + +feniks: + lh_d_mag: 0.4 + N_centroids: 100 + num_halos_coarse_zbins: 100 + num_halos_fine_zbins: 100 epoch: n_it: 1 n_steps: 2 step_size: 0.1 - num_halos: 200 defaults: diffstarpop: True diff --git a/scripts/fit_sdss_feniks.py b/scripts/fit_sdss_feniks.py index a1353769..6693802f 100644 --- a/scripts/fit_sdss_feniks.py +++ b/scripts/fit_sdss_feniks.py @@ -1,8 +1,11 @@ import argparse import os import time +from collections import namedtuple from datetime import datetime +from pathlib import Path +import jax import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np @@ -20,12 +23,6 @@ from diffhtwo.experimental import param_utils as pu from diffhtwo.experimental.data_loaders import load_feniks, load_sdss -from diffhtwo.experimental.defaults import ( - FENIKS_Z_MIN, - SDSS_Z_MAX, - SDSS_Z_MIN, -) -from diffhtwo.experimental.latin_hypercube import lh_utils as lhu from diffhtwo.experimental.optimizers import Np_specphot_opt DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ @@ -42,7 +39,6 @@ sdss_drn = cfg["base_path"] + "/sdss" feniks_drn = cfg["base_path"] + "/feniks" - ssp_filename = ( cfg["base_path"] + "/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5" @@ -51,15 +47,8 @@ # get ssp data ssp_data = load_ssp_templates(fn=ssp_filename) ssp_data = lemi.get_subset_emline_data(ssp_data, ["Ba_alpha_6563"]) - emline_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) - emline_wave_table = jnp.array([emline_wave_aa]) + halpha_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) - # load data - ran_key = jran.key(0) - SDSS = load_sdss.get_sdss_data(sdss_drn, ran_key, ssp_data) - FENIKS = load_feniks.get_feniks_data( - feniks_drn, ran_key, ssp_data, lh_d_mag=cfg["feniks"]["lh_d_mag"] - ) # start fit dirs fit_start_drn = cfg["base_path"] + "/fits/" + cfg["start_runid"] + "/" param_collection_fit = lc_mock.load_diffsky_param_collection_merging( @@ -96,98 +85,113 @@ + cfg["fit_type"] ) os.makedirs(fit_diagnostics_save_drn + "/loss", exist_ok=True) - os.makedirs(fit_diagnostics_save_drn + "/sdss_lh_N_z", exist_ok=True) - os.makedirs(fit_diagnostics_save_drn + "/feniks_lh_N_z", exist_ok=True) + os.makedirs(fit_diagnostics_save_drn + "/lh_N_z", exist_ok=True) os.system(f"cp {args.config} {fit_diagnostics_save_drn}") - sdss_z_min = [SDSS_Z_MIN, 0.1] - sdss_z_max = [0.1, SDSS_Z_MAX] - - feniks_z_min = [FENIKS_Z_MIN, 1] - feniks_z_max = [1, 2] - initial_pts = [] start = time.time() + ran_key = jran.key(0) for epoch in range(0, cfg["epoch"]["n_it"]): print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') - sdss = load_sdss.refresh_lh_centroids(SDSS) - # SDSS - sdss_meta_data, sdss_fitting_data = lhu.get_zbins_lh_lc( + # load sdss data + sdss = load_sdss.get_sdss_data( + sdss_drn, ran_key, - SDSS, - sdss_z_min, - sdss_z_max, ssp_data, - cfg["sdss"]["N_centroids"], - lh_N_z_savedir=fit_diagnostics_save_drn + "/sdss_lh_N_z", - num_halos=cfg["epoch"]["num_halos"], + num_halos_coarse_zbins=cfg["sdss"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["sdss"]["num_halos_fine_zbins"], + ) + remove = {"dataset_dim_labels", "mags_labels"} + SdssFitting = namedtuple("Sdss", [s for s in sdss._fields if s not in remove]) + sdss_fitting_data = SdssFitting( + **{s: getattr(sdss, s) for s in SdssFitting._fields} ) - # FENIKS - feniks_meta_data, feniks_fitting_data = lhu.get_zbins_lh_lc( + # load feniks data + feniks = load_feniks.get_feniks_data( + feniks_drn, ran_key, - FENIKS, - feniks_z_min, - feniks_z_max, ssp_data, - cfg["feniks"]["N_centroids"], - lh_N_z_savedir=fit_diagnostics_save_drn + "/feniks_lh_N_z", - num_halos=cfg["epoch"]["num_halos"], + num_halos_coarse_zbins=cfg["feniks"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["feniks"]["num_halos_fine_zbins"], + ) + remove = {"dataset_dim_labels", "mags_labels"} + FeniksFitting = namedtuple( + "Feniks", [f for f in feniks._fields if f not in remove] + ) + feniks_fitting_data = FeniksFitting( + **{f: getattr(feniks, f) for f in FeniksFitting._fields} ) - loss_hist, u_theta_fit = Np_specphot_opt.fit_sdss_feniks_hizels( + ( + loss_hist, + loss_sdss_hist, + loss_feniks_hist, + u_theta_fit, + ) = Np_specphot_opt.fit_sdss_feniks( u_theta_fit, trainable_params, ran_key, - sdss_meta_data, sdss_fitting_data, - feniks_meta_data, feniks_fitting_data, - # hizels, - # line_wave_table, n_steps=cfg["epoch"]["n_steps"], step_size=cfg["epoch"]["step_size"], ) + jax.clear_caches() + param_collection_fit = pu.get_param_collection_from_u_theta(u_theta_fit) lc_mock.write_diffsky_param_collection_merging( fit_save_drn, cfg["fit_runid"] + "_" + cfg["fit_type"], param_collection_fit, ) - if epoch == 0: STEPS = np.arange(1, cfg["epoch"]["n_steps"] + 1, 1) - LOSS_HIST = loss_hist - + LOSS_SDSS_HIST = loss_sdss_hist + LOSS_FENIKS_HIST = loss_feniks_hist initial_pts.append((STEPS[0], LOSS_HIST[0])) else: steps = np.arange(STEPS[-1] + 1, STEPS[-1] + cfg["epoch"]["n_steps"] + 1, 1) initial_pts.append((steps[0], loss_hist[0])) - STEPS = np.concatenate((STEPS, steps)) LOSS_HIST = np.concatenate((LOSS_HIST, loss_hist)) - + LOSS_SDSS_HIST = np.concatenate((LOSS_SDSS_HIST, loss_sdss_hist)) + LOSS_FENIKS_HIST = np.concatenate((LOSS_FENIKS_HIST, loss_feniks_hist)) end = time.time() elapsed = end - start print( f'Gradient descent took {elapsed/60:.3f} minutes for {cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]} steps.' ) print(f'speed: {elapsed/(cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]):.3f} s/it') - # gradient descent figure fig_loss, ax_loss = plt.subplots(1) - start_step = [s[0] for s in initial_pts] start_loss = [s[1] for s in initial_pts] - ax_loss.scatter(start_step, start_loss, s=50, c="deepskyblue") - - ax_loss.plot(STEPS, LOSS_HIST, c="deepskyblue") - ax_loss.set_ylabel("Poisson Loss") + ax_loss.scatter(start_step, start_loss, s=50, c="k") + ax_loss.plot(STEPS, LOSS_HIST, c="k", label="total") + ax_loss.plot( + STEPS, + LOSS_SDSS_HIST, + c="#0a7a80", + linestyle="--", + alpha=0.7, + label="sdss", + ) + ax_loss.plot( + STEPS, + LOSS_FENIKS_HIST, + c="#c87820", + linestyle="--", + alpha=0.7, + label="feniks", + ) + ax_loss.legend() + ax_loss.set_ylabel("Poisson Negative Log-Likelihood") ax_loss.set_xlabel("steps") ts = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") - plt.savefig(fit_diagnostics_save_drn + "/loss/sdss_feniks_loss_" + ts + ".png") + plt.savefig(fit_diagnostics_save_drn + "/loss/loss_" + ts + ".png") plt.close() diff --git a/scripts/fits_sdss_feniks.py b/scripts/fits_sdss_feniks.py new file mode 100644 index 00000000..779e2dbd --- /dev/null +++ b/scripts/fits_sdss_feniks.py @@ -0,0 +1,192 @@ +import argparse +import os +import time +from collections import namedtuple +from datetime import datetime +from pathlib import Path + +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import yaml +from diffsky.data_loaders.hacc_utils import lc_mock +from diffsky.merging.merging_model import DEFAULT_MERGE_PARAMS +from diffsky.param_utils.spspop_param_utils import DEFAULT_SPSPOP_PARAMS +from diffsky.ssp_err_model.defaults import ZERO_SSPERR_PARAMS +from diffstar.diffstarpop.kernels.params.params_diffstarpopfits_mgash import ( + DiffstarPop_Params_Diffstarpopfits_mgash, +) +from dsps import load_ssp_templates +from dsps.data_loaders import load_emline_info as lemi +from jax import random as jran + +from diffhtwo.experimental import param_utils as pu +from diffhtwo.experimental.data_loaders import load_feniks, load_hizels +from diffhtwo.experimental.optimizers import Np_specphot_opt + +DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ + "galacticus_in_plus_ex_situ" +] + +if __name__ == "__main__": + p = argparse.ArgumentParser() + p.add_argument("--config", default="config_diffsky.yaml") + args = p.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + # sdss_drn = cfg["base_path"] + "/sdss" + feniks_drn = cfg["base_path"] + "/feniks" + hizels_drn = Path(cfg["base_path"] + "/hizels") + ssp_filename = ( + cfg["base_path"] + + "/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5" + ) + + # get ssp data + ssp_data = load_ssp_templates(fn=ssp_filename) + ssp_data = lemi.get_subset_emline_data(ssp_data, ["Ba_alpha_6563"]) + halpha_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) + + # load feniks data + ran_key = jran.key(0) + feniks = load_feniks.get_feniks_data( + feniks_drn, + ran_key, + ssp_data, + num_halos_coarse_zbins=cfg["feniks"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["feniks"]["num_halos_fine_zbins"], + ) + remove = {"dataset_dim_labels", "mags_labels"} + FeniksFitting = namedtuple("Feniks", [f for f in feniks._fields if f not in remove]) + feniks_fitting_data = FeniksFitting( + **{f: getattr(feniks, f) for f in FeniksFitting._fields} + ) + + # load hizels data + hizels_fitting_data = load_hizels.get_hizels_data( + hizels_drn, + ran_key, + ssp_data, + feniks.filter_info.tcurves, + halpha_wave_aa, + num_halos=cfg["hizels"]["num_halos"], + ) + + # start fit dirs + fit_start_drn = cfg["base_path"] + "/fits/" + cfg["start_runid"] + "/" + param_collection_fit = lc_mock.load_diffsky_param_collection_merging( + fit_start_drn, + cfg["start_runid"] + "_" + cfg["start_fit_type"], + ) + if cfg["defaults"]["diffstarpop"]: + param_collection_fit = param_collection_fit._replace( + diffstarpop_params=DIFFSTARPOP_GALACTICUS_exsitu + ) + if cfg["defaults"]["spspop"]: + param_collection_fit = param_collection_fit._replace( + spspop_params=DEFAULT_SPSPOP_PARAMS + ) + if cfg["defaults"]["ssperr"]: + param_collection_fit = param_collection_fit._replace( + ssperr_params=ZERO_SSPERR_PARAMS + ) + if cfg["defaults"]["merging"]: + param_collection_fit = param_collection_fit._replace( + merging_params=DEFAULT_MERGE_PARAMS + ) + + u_theta_fit = pu.get_u_theta_from_param_collection(param_collection_fit) + + # fit dirs + trainable_params = pu.get_trainable_params(fit_type=cfg["fit_type"]) + fit_save_drn = cfg["base_path"] + "/fits/" + cfg["fit_runid"] + "/" + fit_diagnostics_save_drn = ( + cfg["base_path"] + + "/fits/" + + cfg["fit_runid"] + + "/diagnostic_plots/" + + cfg["fit_type"] + ) + os.makedirs(fit_diagnostics_save_drn + "/loss", exist_ok=True) + os.makedirs(fit_diagnostics_save_drn + "/lh_N_z", exist_ok=True) + + os.system(f"cp {args.config} {fit_diagnostics_save_drn}") + + initial_pts = [] + start = time.time() + for epoch in range(0, cfg["epoch"]["n_it"]): + print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') + + ( + loss_hist, + loss_phot_hist, + loss_emline_hist, + u_theta_fit, + ) = Np_specphot_opt.fit_feniks_hizels( + u_theta_fit, + trainable_params, + ran_key, + feniks_fitting_data, + hizels_fitting_data, + n_steps=cfg["epoch"]["n_steps"], + step_size=cfg["epoch"]["step_size"], + ) + + jax.clear_caches() + + param_collection_fit = pu.get_param_collection_from_u_theta(u_theta_fit) + lc_mock.write_diffsky_param_collection_merging( + fit_save_drn, + cfg["fit_runid"] + "_" + cfg["fit_type"], + param_collection_fit, + ) + if epoch == 0: + STEPS = np.arange(1, cfg["epoch"]["n_steps"] + 1, 1) + LOSS_HIST = loss_hist + LOSS_PHOT_HIST = loss_phot_hist + LOSS_EMLINE_HIST = loss_emline_hist + initial_pts.append((STEPS[0], LOSS_HIST[0])) + else: + steps = np.arange(STEPS[-1] + 1, STEPS[-1] + cfg["epoch"]["n_steps"] + 1, 1) + initial_pts.append((steps[0], loss_hist[0])) + STEPS = np.concatenate((STEPS, steps)) + LOSS_HIST = np.concatenate((LOSS_HIST, loss_hist)) + LOSS_PHOT_HIST = np.concatenate((LOSS_PHOT_HIST, loss_phot_hist)) + LOSS_EMLINE_HIST = np.concatenate((LOSS_EMLINE_HIST, loss_emline_hist)) + end = time.time() + elapsed = end - start + print( + f'Gradient descent took {elapsed/60:.3f} minutes for {cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]} steps.' + ) + print(f'speed: {elapsed/(cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]):.3f} s/it') + # gradient descent figure + fig_loss, ax_loss = plt.subplots(1) + start_step = [s[0] for s in initial_pts] + start_loss = [s[1] for s in initial_pts] + ax_loss.scatter(start_step, start_loss, s=50, c="k") + ax_loss.plot(STEPS, LOSS_HIST, c="k", label="total loss") + ax_loss.plot( + STEPS, + LOSS_PHOT_HIST, + c="#0a7a80", + linestyle="--", + alpha=0.7, + label="phot loss", + ) + ax_loss.plot( + STEPS, + LOSS_EMLINE_HIST, + c="#c87820", + linestyle="--", + alpha=0.7, + label="emline loss", + ) + ax_loss.legend() + ax_loss.set_ylabel("Poisson Negative Log-Likelihood") + ax_loss.set_xlabel("steps") + ts = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + plt.savefig(fit_diagnostics_save_drn + "/loss/loss_" + ts + ".png") + plt.close() From a124c011b35c145b9dda1989780b50a77194cae7 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Tue, 9 Jun 2026 12:12:10 -0500 Subject: [PATCH 22/57] apply completeness cuts in each band in sdss --- diffhtwo/experimental/data_loaders/load_sdss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 6b4b50b1..180e7c73 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -119,11 +119,11 @@ def get_sdss_data( sdss, frac_cat = load_sdss_cuts_applied(drn) sdss_mag_thresh = SdssFilters( - sdss_u=30.0, - sdss_g=30.0, + sdss_u=19.7, + sdss_g=18.0, sdss_r=SDSS_MAGR_THRESH, - sdss_i=30.0, - sdss_z=30.0, + sdss_i=17.0, + sdss_z=17.0, ) sdss_in_lh = SdssFilters( sdss_u=True, From c0b3441e6721733ecf5c72dd3733a69fd32e9370 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Tue, 9 Jun 2026 12:24:32 -0500 Subject: [PATCH 23/57] mag_thresh_mask sdss --- diffhtwo/experimental/data_loaders/load_sdss.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 180e7c73..bb2a9f37 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -35,13 +35,19 @@ def apply_ra_dec_cut(sdss, ra_min=120, ra_max=240, dec_min=0, dec_max=60): ] -def load_sdss_cuts_applied(drn): +def load_sdss_cuts_applied(drn, sdss_mag_thresh): sdss = sdl.load_sdss_cat(drn) sdss = apply_ra_dec_cut(sdss) # implement r <= 17.6 - mag_thresh_mask = sdss["modelMag_r"] <= SDSS_MAGR_THRESH + mag_thresh_mask = ( + (sdss["modelMag_u"] < sdss_mag_thresh.sdss_u) + & (sdss["modelMag_g"] < sdss_mag_thresh.sdss_g) + & (sdss["modelMag_r"] < sdss_mag_thresh.sdss_r) + & (sdss["modelMag_i"] < sdss_mag_thresh.sdss_i) + & (sdss["modelMag_z"] < sdss_mag_thresh.sdss_z) + ) sdss = sdss[mag_thresh_mask] N_obj_pre_outlier_cut = len(sdss) @@ -116,8 +122,6 @@ def get_sdss_data( lc_sky_area_degsq=100, n_z_phot_table=30, ): - sdss, frac_cat = load_sdss_cuts_applied(drn) - sdss_mag_thresh = SdssFilters( sdss_u=19.7, sdss_g=18.0, @@ -125,6 +129,8 @@ def get_sdss_data( sdss_i=17.0, sdss_z=17.0, ) + sdss, frac_cat = load_sdss_cuts_applied(drn, sdss_mag_thresh) + sdss_in_lh = SdssFilters( sdss_u=True, sdss_g=False, From 5c43e5cc05ac65b750fdb9c0f4a7feba43bbda5a Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Tue, 9 Jun 2026 20:33:59 -0500 Subject: [PATCH 24/57] fix sdss z-bins and add r-i|r in fitting --- .../experimental/data_loaders/load_sdss.py | 51 ++-- .../experimental/diagnostics/plot_phot.py | 224 +++--------------- scripts/config_diagnostics.yaml | 34 +-- scripts/generate_diagnostic_plots.py | 11 +- 4 files changed, 83 insertions(+), 237 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index bb2a9f37..82b85526 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -187,14 +187,7 @@ def get_sdss_data( ############################################################################## Colors = namedtuple( "Colors", - [ - "z_min", - "z_max", - "lc_data", - "ur_ri", - "gr_ri", - "ur", - ], + ["z_min", "z_max", "lc_data", "ur_ri", "gr_ri", "ur", "ri"], ) # 2D (u - r, r - i) Ur_ri = namedtuple("Ur_ri", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) @@ -205,7 +198,13 @@ def get_sdss_data( # 1D (u - r | r) Ur_condr = namedtuple( "Ur_condr", - ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + ["col_idx", "cond_idx", "r_min", "r_max", "sig", "bin_lo", "bin_hi", "N_data"], + ) + + # 1D (r - i | r) + Ri_condr = namedtuple( + "Ri_condr", + ["col_idx", "cond_idx", "r_min", "r_max", "sig", "bin_lo", "bin_hi", "N_data"], ) colors = [] @@ -269,16 +268,27 @@ def get_sdss_data( ) ) - colors.append( - Colors( - z_min, - z_max, - lc_data, - ur_ri, - gr_ri, - ur, + # 1D (r - i | r) + col_idx = [2, 3] + cond_idx = 2 + ri = [] + for r in range(len(rbins) - 1): + r_sel = (sdss_r[z_sel] > rbins[r]) & (sdss_r[z_sel] <= rbins[r + 1]) + N_1d_ri, sig_ri, bin_lo_ri, bin_hi_ri = get_N_1d(sdss_ri[z_sel][r_sel]) + ri.append( + Ri_condr( + col_idx, + cond_idx, + rbins[r], + rbins[r + 1], + sig_ri, + bin_lo_ri, + bin_hi_ri, + N_1d_ri, + ) ) - ) + + colors.append(Colors(z_min, z_max, lc_data, ur_ri, gr_ri, ur, ri)) ############################################################################## ############################################################################## @@ -287,9 +297,8 @@ def get_sdss_data( [ [0.02, 0.06], [0.06, 0.1], - [0.1, 0.14], - [0.14, 0.18], - [0.18, 0.2], + [0.1, 0.15], + [0.15, 0.20], ] ) ############################################################################## diff --git a/diffhtwo/experimental/diagnostics/plot_phot.py b/diffhtwo/experimental/diagnostics/plot_phot.py index 7dd5f259..acec0b29 100644 --- a/diffhtwo/experimental/diagnostics/plot_phot.py +++ b/diffhtwo/experimental/diagnostics/plot_phot.py @@ -4,6 +4,7 @@ import numpy as np from diffstar.defaults import FB from dsps.cosmology.defaults import DEFAULT_COSMOLOGY +from scipy.ndimage import gaussian_filter from ..kernels.phot_kern import get_colors_mags, mag_kern from ..lc_utils import zbin_volume @@ -568,195 +569,6 @@ def plot_n_mags( plt.close() -# def plot_app_mag_funcs( -# dataset, -# data_label, -# param_collection, -# ran_key, -# ssp_data, -# savedir, -# lgmp_min=10.0, -# lgmp_max=15.0, -# num_halos=5000, -# lc_sky_area_degsq=1000, -# n_z_phot_table=30, -# cosmo_params=DEFAULT_COSMOLOGY, -# fb=FB, -# plt_show=True, -# ): -# dataset_mags = dataset.mags -# data_sky_area_degsq = dataset.data_sky_area_degsq - -# feniks_zbins = np.array( -# [ -# [0.2, 0.5], -# [0.5, 0.8], -# [0.8, 1.2], -# [1.2, 1.6], -# [1.6, 2.0], -# ] -# ) -# labels_z = [" z = " + str(np.round(np.median(z), 2)) for z in feniks_zbins] - -# colors_z = [ -# "#001219", # deep navy -# "#0a7a80", # teal -# "#80cca8", # mint -# "#c8b44a", # warm gold -# "#c87820", # amber -# ] -# fig_width = 7.1 -# fig_height = 5 - -# fontsize = 10 -# labelsize = 10 -# alpha = 0.75 -# s = 10 - -# fig, ax = plt.subplots( -# 2, 4, figsize=(fig_width, fig_height), constrained_layout=True -# ) -# fig.get_layout_engine().set(rect=[0, 0, 1, 0.95]) - -# handles = [ -# mlines.Line2D([], [], color=c, linewidth=6, solid_capstyle="butt", label=label) -# for c, label in zip(colors_z, labels_z) -# ] - -# fig.legend( -# handles=handles, -# loc="upper center", -# ncol=8, -# frameon=False, -# handlelength=3, -# handleheight=0.5, -# columnspacing=0.8, -# handletextpad=0.1, -# bbox_to_anchor=(0.5, 1.02), -# fontsize=10, -# ) - -# xlim = [] -# for zbin in range(0, len(feniks_zbins)): -# z_min = feniks_zbins[zbin][0] -# z_max = feniks_zbins[zbin][1] - -# z_min, z_max = np.round(z_min, 2), np.round(z_max, 2) -# z_mask = (dataset_mags[:, -1] > z_min) & (dataset_mags[:, -1] < z_max) -# dataset_mags_z = dataset_mags[z_mask] -# data_vol_mpc3 = zbin_volume(data_sky_area_degsq, zlow=z_min, zhigh=z_max).value - -# z_phot_table = 10 ** jnp.linspace( -# np.log10(z_min), np.log10(z_max), n_z_phot_table -# ) -# lc_data = generate_lc_data( -# ran_key, -# num_halos, -# z_min, -# z_max, -# lgmp_min, -# lgmp_max, -# lc_sky_area_degsq, -# ssp_data, -# dataset.filter_info.tcurves, -# z_phot_table, -# ) -# obs_mags, weights, phot_kern_results = mag_kern( -# ran_key, -# param_collection, -# lc_data, -# dataset.filter_info.mag_thresh, -# dataset.frac_cat, -# ) - -# n_bands = obs_mags.shape[1] - -# row = 0 -# col = 0 -# dmag = 0.5 -# for i in range(0, n_bands): -# bins = np.arange( -# dataset_mags_z[:, i].min(), -# dataset_mags_z[:, i].max() + dmag, -# dmag, -# ) -# if zbin == 0: -# xlim.append([bins.min() - 0.5, bins.max() + 1]) - -# bin_centers = (bins[1:] + bins[:-1]) / 2 -# ax[row, col].set_xlim(bins[0], bins[-1] + 0.2) -# # ax[0, i].set_xticks([]) - -# n_data, bin_edges = np.histogram( -# dataset_mags_z[:, i], -# weights=np.ones_like(dataset_mags_z[:, i]) * (1 / data_vol_mpc3), -# bins=bins, -# ) -# with warnings.catch_warnings(): -# warnings.filterwarnings("ignore", category=RuntimeWarning) -# ax[row, col].scatter( -# bin_centers, np.log10(n_data), c=colors_z[zbin], alpha=alpha, s=s -# ) - -# ( -# n_diffsky, -# _, -# ) = np.histogram( -# obs_mags[:, i], -# weights=weights * (1 / lc_data.lc_tot_vol_mpc3), -# bins=bins, -# ) -# with warnings.catch_warnings(): -# warnings.filterwarnings("ignore", category=RuntimeWarning) -# ax[row, col].plot( -# bin_centers, np.log10(n_diffsky), c=colors_z[zbin], alpha=alpha -# ) - -# ax[row, col].set_xticks(np.arange(15, 30, 2)) -# ax[row, col].minorticks_on() -# ax[row, col].tick_params( -# which="major", -# direction="in", -# top=True, -# right=True, -# length=6, -# width=1, -# labelsize=labelsize, -# ) -# ax[row, col].tick_params( -# which="minor", -# direction="in", -# top=True, -# right=True, -# length=3, -# width=0.8, -# labelsize=labelsize, -# ) - -# ax[row, col].set_ylim(-6.9, -2.5) -# ax[row, col].set_xlim(xlim[i]) -# ax[row, col].set_xlabel(dataset.mags_labels[i]) - -# if col != 0: -# ax[row, col].set_yticklabels([]) - -# if col == 3: -# row += 1 -# col = 0 -# else: -# col += 1 - -# ax[0, 0].set_ylabel("log$_{10}$ (n [Mpc$^{-3}$])", fontsize=fontsize) -# ax[1, 0].set_ylabel("log$_{10}$ (n [Mpc$^{-3}$])", fontsize=fontsize) -# fig.savefig( -# savedir + "/" + data_label + "_app_mag_funcs.png", -# dpi=300, -# ) -# if plt_show: -# plt.show() -# plt.close() - - def plot_app_mag_funcs( dataset, data_label, @@ -781,13 +593,22 @@ def plot_app_mag_funcs( zbins = np.array(zbins) labels_z = [" z = " + str(np.round(np.median(z), 2)) for z in zbins] - colors_z = [ - "#001219", - "#0a7a80", - "#80cca8", - "#c8b44a", - "#c87820", - ] + if len(labels_z) == 4: + colors_z = [ + "#001219", + "#0a7a80", + "#80cca8", + "#c87820", + ] + + else: + colors_z = [ + "#001219", + "#0a7a80", + "#80cca8", + "#c8b44a", + "#c87820", + ] n_bands = dataset_mags.shape[1] - 1 if n_bands <= 5: @@ -949,3 +770,14 @@ def plot_app_mag_funcs( if plt_show: plt.show() plt.close() + + +def plot_density( + x, y, ax, bins=80, sigma=1.5, cmap="plasma", n_levels=8, **contourf_kw +): + H, xe, ye = np.histogram2d(x, y, bins=bins) + H = gaussian_filter(H.T, sigma=sigma) + xc = 0.5 * (xe[:-1] + xe[1:]) + yc = 0.5 * (ye[:-1] + ye[1:]) + levels = np.logspace(np.log10(H[H > 0].min()), np.log10(H.max()), n_levels) + ax.contourf(xc, yc, H, levels=levels, cmap=cmap, **contourf_kw) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index aed285d2..9d7c41d1 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run156 -model_nickname: run156_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run156/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run158 +model_nickname: run158_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run158/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -17,19 +17,19 @@ plot_hizels: False plots: num_halos : 3000 - plot_insitu_sm: True + plot_insitu_sm: False plot_app_mag_funcs: True - plot_color_pdfs: True - plot_colors_mags: True - plot_mags: True - plot_ssperr: True - plot_massive_cen_colors: True - plot_merging_sat_colors: True + plot_color_pdfs: False + plot_colors_mags: False + plot_mags: False + plot_ssperr: False + plot_massive_cen_colors: False + plot_merging_sat_colors: False plot_satquench: False - plot_satquench_model: True - plot_insitu_smhm: True - plot_uvj: True - plot_exsitu_frac: True - plot_avpop: True - plot_burstpop: True - plot_fburst_mh_z: True \ No newline at end of file + plot_satquench_model: False + plot_insitu_smhm: False + plot_uvj: False + plot_exsitu_frac: False + plot_avpop: False + plot_burstpop: False + plot_fburst_mh_z: False \ No newline at end of file diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index 3c72f77a..9037767e 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -321,6 +321,7 @@ z_max, ssp_data, fit_diagnostics_save_drn, + num_halos=num_halos, ) if cfg["plots"]["plot_colors_mags"]: @@ -336,6 +337,7 @@ z_max, ssp_data, fit_diagnostics_save_drn, + num_halos=num_halos, ) if cfg["plots"]["plot_mags"]: @@ -348,6 +350,7 @@ z_max, ssp_data, fit_diagnostics_save_drn, + num_halos=num_halos, ) if cfg["plots"]["plot_ssperr"]: @@ -446,9 +449,8 @@ [ [0.02, 0.06], [0.06, 0.1], - [0.1, 0.14], - [0.14, 0.18], - [0.18, 0.2], + [0.1, 0.15], + [0.15, 0.2], ] ) if cfg["plots"]["plot_app_mag_funcs"]: @@ -535,6 +537,7 @@ z_max, ssp_data, fit_diagnostics_save_drn, + num_halos=num_halos, ) if cfg["plots"]["plot_colors_mags"]: @@ -550,6 +553,7 @@ z_max, ssp_data, fit_diagnostics_save_drn, + num_halos=num_halos, ) if cfg["plots"]["plot_mags"]: @@ -562,6 +566,7 @@ z_max, ssp_data, fit_diagnostics_save_drn, + num_halos=num_halos, ) if cfg["plots"]["plot_ssperr"]: From a522a7bb714ecd04936d01bb9535e78aa6560644 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Tue, 9 Jun 2026 21:20:48 -0500 Subject: [PATCH 25/57] cond_min, cond_max --- .../experimental/data_loaders/load_feniks.py | 88 +++++++++++++++++-- .../experimental/data_loaders/load_sdss.py | 22 ++++- diffhtwo/experimental/kernels/N_phot.py | 4 +- 3 files changed, 102 insertions(+), 12 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 06fb9e7b..7fd51a08 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -416,7 +416,16 @@ def get_feniks_data( ug = [] Ug_condK = namedtuple( "Ug_condK", - ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( hsc_g[z_sel] < feniks_mag_thresh.HSC_G @@ -446,7 +455,16 @@ def get_feniks_data( ri = [] Ri_condK = namedtuple( "Ri_condK", - ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) mag_sel_ri = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( hsc_i[z_sel] < feniks_mag_thresh.HSC_I @@ -476,7 +494,16 @@ def get_feniks_data( iz = [] Iz_condK = namedtuple( "Iz_condK", - ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) mag_sel_iz = (hsc_i[z_sel] < feniks_mag_thresh.HSC_I) & ( hsc_z[z_sel] < feniks_mag_thresh.HSC_Z @@ -506,7 +533,16 @@ def get_feniks_data( jh = [] JH_condK = namedtuple( "JH_condK", - ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( uds_H[z_sel] < feniks_mag_thresh.UDS_H @@ -616,7 +652,16 @@ def get_feniks_data( rz = [] Rz_condK = namedtuple( "Rz_condK", - ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) mag_sel_rz = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( hsc_z[z_sel] < feniks_mag_thresh.HSC_Z @@ -646,7 +691,16 @@ def get_feniks_data( jh = [] JH_condK = namedtuple( "JH_condK", - ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( uds_H[z_sel] < feniks_mag_thresh.UDS_H @@ -771,7 +825,16 @@ def get_feniks_data( gr = [] Gr_condK = namedtuple( "Gr_condK", - ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( hsc_r[z_sel] < feniks_mag_thresh.HSC_R @@ -801,7 +864,16 @@ def get_feniks_data( jh = [] JH_condK = namedtuple( "JH_condK", - ["col_idx", "cond_idx", "K_min", "K_max", "sig", "bin_lo", "bin_hi", "N_data"], + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( uds_H[z_sel] < feniks_mag_thresh.UDS_H diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 82b85526..acace57a 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -198,13 +198,31 @@ def get_sdss_data( # 1D (u - r | r) Ur_condr = namedtuple( "Ur_condr", - ["col_idx", "cond_idx", "r_min", "r_max", "sig", "bin_lo", "bin_hi", "N_data"], + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) # 1D (r - i | r) Ri_condr = namedtuple( "Ri_condr", - ["col_idx", "cond_idx", "r_min", "r_max", "sig", "bin_lo", "bin_hi", "N_data"], + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) colors = [] diff --git a/diffhtwo/experimental/kernels/N_phot.py b/diffhtwo/experimental/kernels/N_phot.py index bc0b05bf..1c0ce554 100644 --- a/diffhtwo/experimental/kernels/N_phot.py +++ b/diffhtwo/experimental/kernels/N_phot.py @@ -37,8 +37,8 @@ def N_colors_mags( # get cond weight obs_mags_cond = obs_mags[:, space_n.cond_idx] - cond = (obs_mags_cond > space_n.K_min) & ( - obs_mags_cond <= space_n.K_max + cond = (obs_mags_cond > space_n.cond_min) & ( + obs_mags_cond <= space_n.cond_max ) weight = jnp.where(cond, gal_weight, 0.0) From 74f1aec7f6eaa7ed5ec0c4419849d21e619813a0 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Tue, 9 Jun 2026 22:48:44 -0500 Subject: [PATCH 26/57] plot_contours --- .../experimental/data_loaders/load_sdss.py | 31 +++- .../experimental/diagnostics/plot_contour.py | 164 ++++++++++++++++++ .../experimental/diagnostics/plot_phot.py | 12 -- scripts/config_diagnostics.yaml | 11 +- scripts/generate_diagnostic_plots.py | 42 ++++- 5 files changed, 239 insertions(+), 21 deletions(-) create mode 100644 diffhtwo/experimental/diagnostics/plot_contour.py diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index acace57a..a41e80c8 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -322,16 +322,28 @@ def get_sdss_data( ############################################################################## AppMagFuncs = namedtuple( "AppMagFuncs", - ["z_min", "z_max", "lc_data", "u", "r"], + ["z_min", "z_max", "lc_data", "u", "g", "r", "i", "z"], ) U = namedtuple( "U", ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], ) + G = namedtuple( + "G", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ) R = namedtuple( "R", ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], ) + I = namedtuple( + "I", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ) + Z = namedtuple( + "Z", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ) app_mag_funcs = [] for zbin in range(0, len(fine_zbins)): @@ -363,12 +375,27 @@ def get_sdss_data( N_1d_u, sig_u, bin_lo_u, bin_hi_u = get_N_1d(sdss_u[z_sel]) u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u) + # 1D (g) + mag_idx_g = 1 + N_1d_g, sig_g, bin_lo_g, bin_hi_g = get_N_1d(sdss_g[z_sel]) + g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g) + # 1D (r) mag_idx_r = 2 N_1d_r, sig_r, bin_lo_r, bin_hi_r = get_N_1d(sdss_r[z_sel]) r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r) - app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, r)) + # 1D (i) + mag_idx_i = 3 + N_1d_i, sig_i, bin_lo_i, bin_hi_i = get_N_1d(sdss_i[z_sel]) + i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i) + + # 1D (z) + mag_idx_z = 4 + N_1d_z, sig_z, bin_lo_z, bin_hi_z = get_N_1d(sdss_z[z_sel]) + z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z) + + app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, g, r, i, z)) ############################################################################## diff --git a/diffhtwo/experimental/diagnostics/plot_contour.py b/diffhtwo/experimental/diagnostics/plot_contour.py new file mode 100644 index 00000000..2507b435 --- /dev/null +++ b/diffhtwo/experimental/diagnostics/plot_contour.py @@ -0,0 +1,164 @@ +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import LinearSegmentedColormap +from scipy.ndimage import gaussian_filter + +from ..kernels.N_phot import N_colors_mags + +plt.rc("font", family="serif", serif=["Times New Roman"]) + +# Pantone: Dress Blues → Classic Blue → Aqua Sky → Minty Green → Illuminating +density_cmap = LinearSegmentedColormap.from_list( + "pantone_density", + [ + "#1B2A4A", # Dress Blues — empty/low + "#0F4C81", # Classic Blue + "#00A591", # Arcadia + "#84BD00", # Greenery + "#FEDF00", # Illuminating — peak density + ], +) +dusk = LinearSegmentedColormap.from_list( + "dusk", + [ + "#1B1F3B", # Evening Blue + "#7B4F9E", # Amethyst Orchid + "#E8A598", # Peach Pink + "#F5E6C8", # Almond Milk + ], +) + + +def plot_density( + bin_lo, bin_hi, N, ax, xlabel, ylabel, cmap, N_model=None, sigma=0.55, n_levels=8 +): + x_edges = np.unique(np.append(bin_lo[:, 0], bin_hi[-1, 0])) + y_edges = np.unique(np.append(bin_lo[:, 1], bin_hi[-1, 1])) + xc = 0.5 * (x_edges[:-1] + x_edges[1:]) + yc = 0.5 * (y_edges[:-1] + y_edges[1:]) + Z = np.log10( + gaussian_filter( + (N / N.sum()).reshape(len(y_edges) - 1, len(x_edges) - 1).astype(float), + sigma=sigma, + ).clip(min=np.finfo(float).tiny) + ) + levels = np.linspace(Z.min(), Z.max(), n_levels) + qm = ax.contourf(xc, yc, Z, levels=levels, cmap=cmap, alpha=0.5) + ax.get_figure().colorbar(qm, ax=ax, label=r"$\log_{10}(N / N_{\rm tot})$") + if N_model is not None: + Z_model = np.log10( + gaussian_filter( + (N_model / N_model.sum()) + .reshape(len(y_edges) - 1, len(x_edges) - 1) + .astype(float), + sigma=sigma, + ).clip(min=np.finfo(float).tiny) + ) + ax.contour( + xc, + yc, + Z_model, + levels=levels, + cmap=cmap, + linewidths=1.5, + alpha=0.9, + linestyles="dashed", + ) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + + +def plot_density_raw(bin_lo, bin_hi, N, ax, xlabel, ylabel, cmap, N_model=None): + x_edges = np.unique(np.append(bin_lo[:, 0], bin_hi[-1, 0])) + y_edges = np.unique(np.append(bin_lo[:, 1], bin_hi[-1, 1])) + xc = 0.5 * (x_edges[:-1] + x_edges[1:]) + yc = 0.5 * (y_edges[:-1] + y_edges[1:]) + Z = np.log10( + (N / N.sum()) + .reshape(len(y_edges) - 1, len(x_edges) - 1) + .astype(float) + .clip(min=np.finfo(float).tiny) + ) + qm = ax.pcolormesh(x_edges, y_edges, Z, cmap=cmap) + ax.get_figure().colorbar(qm, ax=ax, label=r"$\log_{10}(N / N_{\rm tot})$") + if N_model is not None: + Z_model = np.log10( + (N_model / N_model.sum()) + .reshape(len(y_edges) - 1, len(x_edges) - 1) + .astype(float) + .clip(min=np.finfo(float).tiny) + ) + levels = np.linspace(Z.min(), Z.max(), 8) + ax.contour(xc, yc, Z_model, levels=levels, cmap=cmap, linewidths=0.8, alpha=0.9) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + + +def plot_color_contours( + ran_key, + param_collection, + data, + mag_thresh, + frac_cat, + data_label, + savedir, +): + for z in range(0, len(data)): + z_data = data[z] + + z_data_model = N_colors_mags( + ran_key, + param_collection, + z_data, + mag_thresh, + frac_cat, + ) + fields = z_data_model._fields[3:] + z_min = z_data_model.z_min + z_max = z_data_model.z_max + + for f in range(0, len(fields)): + space = getattr(z_data_model, fields[f]) + + if isinstance(space, list): + pass + + else: + fig, ax = plt.subplots(constrained_layout=True) + fig.suptitle(str(z_min) + " < z < " + str(z_max)) + name = type(space).__name__ + xlabel, ylabel = parse_color_labels(name) + plot_density( + space.bin_lo, + space.bin_hi, + space.N_data, + ax, + xlabel, + ylabel, + dusk, + N_model=space.N_model, + ) + fig.savefig( + savedir + + "/" + + data_label + + "_" + + name + + "_" + + str(z_min) + + "-" + + str(z_max) + + ".png", + dpi=300, + ) + plt.close() + + +def parse_color_labels(name): + # "Ur_ri" → ["u-r", "r-i"] + x_str, y_str = name.lower().split("_") + + def to_label(s): + return f"${s[0]} - {s[1]}$" + + return to_label(x_str), to_label(y_str) diff --git a/diffhtwo/experimental/diagnostics/plot_phot.py b/diffhtwo/experimental/diagnostics/plot_phot.py index acec0b29..3d372e07 100644 --- a/diffhtwo/experimental/diagnostics/plot_phot.py +++ b/diffhtwo/experimental/diagnostics/plot_phot.py @@ -4,7 +4,6 @@ import numpy as np from diffstar.defaults import FB from dsps.cosmology.defaults import DEFAULT_COSMOLOGY -from scipy.ndimage import gaussian_filter from ..kernels.phot_kern import get_colors_mags, mag_kern from ..lc_utils import zbin_volume @@ -770,14 +769,3 @@ def plot_app_mag_funcs( if plt_show: plt.show() plt.close() - - -def plot_density( - x, y, ax, bins=80, sigma=1.5, cmap="plasma", n_levels=8, **contourf_kw -): - H, xe, ye = np.histogram2d(x, y, bins=bins) - H = gaussian_filter(H.T, sigma=sigma) - xc = 0.5 * (xe[:-1] + xe[1:]) - yc = 0.5 * (ye[:-1] + ye[1:]) - levels = np.logspace(np.log10(H[H > 0].min()), np.log10(H.max()), n_levels) - ax.contourf(xc, yc, H, levels=levels, cmap=cmap, **contourf_kw) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 9d7c41d1..5a534f28 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run158 -model_nickname: run158_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run158/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run159 +model_nickname: run159_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run159/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -17,8 +17,8 @@ plot_hizels: False plots: num_halos : 3000 - plot_insitu_sm: False - plot_app_mag_funcs: True + plot_color_contours: True + plot_app_mag_funcs: False plot_color_pdfs: False plot_colors_mags: False plot_mags: False @@ -28,6 +28,7 @@ plots: plot_satquench: False plot_satquench_model: False plot_insitu_smhm: False + plot_insitu_sm: False plot_uvj: False plot_exsitu_frac: False plot_avpop: False diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index 9037767e..6bf1d9ab 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -29,6 +29,7 @@ plot_lgfburst_mh_z, ) from diffhtwo.experimental.diagnostics.plot_cen import plot_massive_cen_colors +from diffhtwo.experimental.diagnostics.plot_contour import plot_color_contours from diffhtwo.experimental.diagnostics.plot_halpha import ( plot_halpha, plot_halpha_insitu_exsitu, @@ -203,7 +204,25 @@ """ if cfg["plot_feniks"]: feniks_label = "feniks" # + cfg["model_nickname"].split("_")[0] - feniks = load_feniks.get_feniks_data(feniks_drn, ran_key, ssp_data) + feniks = load_feniks.get_feniks_data( + feniks_drn, + ran_key, + ssp_data, + num_halos_coarse_zbins=num_halos, + num_halos_fine_zbins=int(num_halos / 2), + ) + + if cfg["plots"]["plot_color_contours"]: + print("Generating FENIKS color contours plot...") + plot_color_contours( + ran_key, + param_collection_fit, + feniks.colors, + feniks.filter_info.mag_thresh, + feniks.frac_cat, + feniks_label, + fit_diagnostics_save_drn, + ) if cfg["plots"]["plot_app_mag_funcs"]: feniks_zbins = np.array( @@ -444,7 +463,13 @@ """ if cfg["plot_sdss"]: sdss_label = "sdss" # + cfg["model_nickname"].split("_")[0] - sdss = load_sdss.get_sdss_data(sdss_drn, ran_key, ssp_data) + sdss = load_sdss.get_sdss_data( + sdss_drn, + ran_key, + ssp_data, + num_halos_coarse_zbins=num_halos, + num_halos_fine_zbins=int(num_halos / 2), + ) sdss_zbins = np.array( [ [0.02, 0.06], @@ -453,6 +478,19 @@ [0.15, 0.2], ] ) + + if cfg["plots"]["plot_color_contours"]: + print("Generating SDSS color contours plot...") + plot_color_contours( + ran_key, + param_collection_fit, + sdss.colors, + sdss.filter_info.mag_thresh, + sdss.frac_cat, + sdss_label, + fit_diagnostics_save_drn, + ) + if cfg["plots"]["plot_app_mag_funcs"]: print("Generating SDSS app mag funcs plot...") plot_app_mag_funcs( From 6393e8121422edfda0a93bee563e71b8f412513a Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Wed, 10 Jun 2026 00:43:28 -0500 Subject: [PATCH 27/57] fit all app mag funcs feniks --- .../experimental/data_loaders/load_feniks.py | 54 ++++++++++++++----- .../experimental/data_loaders/load_sdss.py | 2 + scripts/config_diagnostics.yaml | 10 ++-- scripts/generate_diagnostic_plots.py | 4 +- 4 files changed, 51 insertions(+), 19 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 7fd51a08..166386ca 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -917,20 +917,20 @@ def get_feniks_data( ############################################################################## AppMagFuncs = namedtuple( "AppMagFuncs", - ["z_min", "z_max", "lc_data", "u", "r", "k"], + ["z_min", "z_max", "lc_data", "u", "g", "r", "i", "z", "J", "H", "K"], ) - U = namedtuple( - "U", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], - ) - R = namedtuple( - "R", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], - ) - K = namedtuple( - "K", + AppMagFunc = namedtuple( + "AppMagFunc", ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], ) + U = namedtuple("U", AppMagFunc._fields) + G = namedtuple("G", AppMagFunc._fields) + R = namedtuple("R", AppMagFunc._fields) + I = namedtuple("I", AppMagFunc._fields) + Z = namedtuple("Z", AppMagFunc._fields) + J = namedtuple("J", AppMagFunc._fields) + H = namedtuple("H", AppMagFunc._fields) + K = namedtuple("K", AppMagFunc._fields) app_mag_funcs = [] for zbin in range(0, len(fine_zbins)): @@ -963,19 +963,49 @@ def get_feniks_data( u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u) n_bins += bin_lo_u.size + # 1D (g) + mag_idx_g = 1 + N_1d_g, sig_g, bin_lo_g, bin_hi_g = get_N_1d(hsc_g[z_sel]) + g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g) + n_bins += bin_lo_g.size + # 1D (r) mag_idx_r = 2 N_1d_r, sig_r, bin_lo_r, bin_hi_r = get_N_1d(hsc_r[z_sel]) r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r) n_bins += bin_lo_r.size + # 1D (i) + mag_idx_i = 3 + N_1d_i, sig_i, bin_lo_i, bin_hi_i = get_N_1d(hsc_i[z_sel]) + i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i) + n_bins += bin_lo_i.size + + # 1D (z) + mag_idx_z = 4 + N_1d_z, sig_z, bin_lo_z, bin_hi_z = get_N_1d(hsc_z[z_sel]) + z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z) + n_bins += bin_lo_z.size + + # 1D (J) + mag_idx_j = 5 + N_1d_j, sig_j, bin_lo_j, bin_hi_j = get_N_1d(uds_J[z_sel]) + j = J(mag_idx_j, sig_j, bin_lo_j, bin_hi_j, N_1d_j) + n_bins += bin_lo_j.size + + # 1D (H) + mag_idx_h = 6 + N_1d_h, sig_h, bin_lo_h, bin_hi_h = get_N_1d(uds_H[z_sel]) + h = H(mag_idx_h, sig_h, bin_lo_h, bin_hi_h, N_1d_h) + n_bins += bin_lo_h.size + # 1D (K) mag_idx_k = 7 N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) k = K(mag_idx_k, sig_k, bin_lo_k, bin_hi_k, N_1d_k) n_bins += bin_lo_k.size - app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, r, k)) + app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, g, r, i, z, j, h, k)) ############################################################################## diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index a41e80c8..fb4d6623 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -266,6 +266,7 @@ def get_sdss_data( # 1D (u - r | r) rbins = np.arange(sdss_r[z_sel].min(), sdss_r[z_sel].max(), 2) + print(rbins) col_idx = [0, 2] cond_idx = 2 @@ -336,6 +337,7 @@ def get_sdss_data( "R", ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], ) + I = namedtuple( "I", ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 5a534f28..b8389907 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run159 -model_nickname: run159_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run159/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run148 +model_nickname: run148_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run148/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,8 +11,8 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: True -plot_feniks: False +plot_sdss: False +plot_feniks: True plot_hizels: False plots: diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index 6bf1d9ab..5a144ac2 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -213,7 +213,7 @@ ) if cfg["plots"]["plot_color_contours"]: - print("Generating FENIKS color contours plot...") + print("Generating FENIKS color contour plots...") plot_color_contours( ran_key, param_collection_fit, @@ -480,7 +480,7 @@ ) if cfg["plots"]["plot_color_contours"]: - print("Generating SDSS color contours plot...") + print("Generating SDSS color contour plots...") plot_color_contours( ran_key, param_collection_fit, From 46eaae55154ed6384fe43c099bff332241a1c8c1 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Wed, 10 Jun 2026 12:50:15 -0500 Subject: [PATCH 28/57] 2D color-mag diagrams in fitting sdss and feniks --- .../experimental/data_loaders/load_feniks.py | 463 ++++-------------- .../experimental/data_loaders/load_sdss.py | 98 +--- diffhtwo/experimental/defaults.py | 11 + .../experimental/diagnostics/plot_contour.py | 17 +- diffhtwo/experimental/kernels/N_phot.py | 59 ++- 5 files changed, 179 insertions(+), 469 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 166386ca..e8634640 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -14,8 +14,11 @@ FENIKS_MAGK_THRESH, FENIKS_Z_MAX, FENIKS_Z_MIN, + AppMagFunc, + ColorColor, Dataset, FilterInfo, + MagColor, ) from ..latin_hypercube import latin_hypercube as lh from ..lightcone_generators import generate_lc_data @@ -347,16 +350,14 @@ def get_feniks_data( ############################################################################## # Z1 spaces: # 2D (g - r, r - i) - # 2D (u - g, r - K) - # 1D (u - g | K) - # 1D (r − i | K): residual quenching scatter at fixed stellar mass - # 1D (i - z | K): completely unconstrained so including it here - # 1D (J − H | K) + # 2D (K, g - r) + # 2D (K, r - i) + # 2D (K, J - H) colors = [] Z1 = namedtuple( "Z1", - ["z_min", "z_max", "lc_data", "gr_ri", "ug_rK", "ug", "ri", "iz", "jh"], + ["z_min", "z_max", "lc_data", "gr_ri", "K_ri", "K_gr", "K_JH"], ) zbin = 0 z_min = zbins[zbin][0] @@ -383,7 +384,7 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) # 2D (g - r, r - i) - Gr_ri = namedtuple("Gr_ri", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + Gr_ri = namedtuple("Gr_ri", ColorColor._fields) mag_sel_gr_ri = ( (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) @@ -396,191 +397,57 @@ def get_feniks_data( gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri) n_bins += bin_lo_gr_ri.size - # 2D (u - g, r - K) - Ug_rK = namedtuple("Ug_rK", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) - mag_sel_ugr = ( - (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) - & (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) - & (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) - ) - N_ug_rK, sig_ug_rK, bin_lo_ug_rK, bin_hi_ug_rK = get_N_2d( - megacam_hsc_uSg[z_sel][mag_sel_ugr], hsc_uds_rK[z_sel][mag_sel_ugr] - ) - col_idx = [0, 1, 7] - ug_rK = Ug_rK(col_idx, sig_ug_rK, bin_lo_ug_rK, bin_hi_ug_rK, N_ug_rK) - n_bins += bin_lo_ug_rK.size - - # 1D (u - g | K) - Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) - - ug = [] - Ug_condK = namedtuple( - "Ug_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) - mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( - hsc_g[z_sel] < feniks_mag_thresh.HSC_G - ) - col_idx = [0, 1] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d( - megacam_hsc_uSg[z_sel][mag_sel_ug & K_sel] - ) - ug.append( - Ug_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_ug, - bin_lo_ug, - bin_hi_ug, - N_1d_ug, - ) - ) - n_bins += bin_lo_ug.size - - # 1D (r − i | K) - ri = [] - Ri_condK = namedtuple( - "Ri_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) + # 2D (K, r - i) + K_ri = namedtuple("K_ri", MagColor._fields) mag_sel_ri = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( hsc_i[z_sel] < feniks_mag_thresh.HSC_I ) + N_K_ri, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri = get_N_2d( + uds_K[z_sel][mag_sel_ri], hsc_ri[z_sel][mag_sel_ri] + ) + mag_idx = 7 col_idx = [2, 3] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_ri, sig_ri, bin_lo_ri, bin_hi_ri = get_N_1d( - hsc_ri[z_sel][mag_sel_ri & K_sel] - ) - ri.append( - Ri_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_ri, - bin_lo_ri, - bin_hi_ri, - N_1d_ri, - ) - ) - n_bins += bin_lo_ri.size + K_ri = K_ri(mag_idx, col_idx, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri, N_K_ri) + n_bins += bin_lo_K_ri.size - # 1D (i − z | K) - iz = [] - Iz_condK = namedtuple( - "Iz_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], + # 2D (K, g - r) + K_gr = namedtuple("K_gr", MagColor._fields) + mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( + hsc_r[z_sel] < feniks_mag_thresh.HSC_R ) - mag_sel_iz = (hsc_i[z_sel] < feniks_mag_thresh.HSC_I) & ( - hsc_z[z_sel] < feniks_mag_thresh.HSC_Z + N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( + uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr] ) - col_idx = [3, 4] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_iz, sig_iz, bin_lo_iz, bin_hi_iz = get_N_1d( - hsc_iz[z_sel][mag_sel_iz & K_sel] - ) - iz.append( - Iz_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_iz, - bin_lo_iz, - bin_hi_iz, - N_1d_iz, - ) - ) - n_bins += bin_lo_iz.size + mag_idx = 7 + col_idx = [1, 2] + K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr) + n_bins += bin_lo_K_gr.size - # 1D (J − H | K) - jh = [] - JH_condK = namedtuple( - "JH_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) - mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( + # 2D (K, J - H) + K_JH = namedtuple("K_JH", MagColor._fields) + mag_sel_JH = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( uds_H[z_sel] < feniks_mag_thresh.UDS_H ) + N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( + uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH] + ) + mag_idx = 7 col_idx = [5, 6] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_jh, sig_jh, bin_lo_jh, bin_hi_jh = get_N_1d( - uds_JH[z_sel][mag_sel_jh & K_sel] - ) - jh.append( - JH_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_jh, - bin_lo_jh, - bin_hi_jh, - N_1d_jh, - ) - ) - n_bins += bin_lo_jh.size + K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH) + n_bins += bin_lo_K_JH.size - z1 = Z1(z_min, z_max, lc_data, gr_ri, ug_rK, ug, ri, iz, jh) + z1 = Z1(z_min, z_max, lc_data, gr_ri, K_ri, K_gr, K_JH) colors.append(z1) ############################################################################## # Z2 spaces: # 2D (r - z, z - J) - # 1D (u - g | K) - # 1D (r − z | K): residual quenching scatter at fixed stellar mass - # 1D (J − H | K) + # 2D (K, u - g) + # 2D (K, r - z) Z2 = namedtuple( "Z2", - ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "rz", "jh"], + ["z_min", "z_max", "lc_data", "rz_zJ", "K_ug", "K_rz"], ) zbin = 1 z_min = zbins[zbin][0] @@ -607,7 +474,7 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) # 2D (r - z, z - J) - Rz_zJ = namedtuple("Rz_zJ", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + Rz_zJ = namedtuple("Rz_zJ", ColorColor._fields) mag_sel_rz_zJ = ( (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & (hsc_z[z_sel] < feniks_mag_thresh.HSC_Z) @@ -620,126 +487,46 @@ def get_feniks_data( rz_zJ = Rz_zJ(col_idx, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ, N_rz_zJ) n_bins += bin_lo_rz_zJ.size - # 1D (u - g | K) - Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) - - ug = [] + # 2D (K, u - g) + K_ug = namedtuple("K_ug", MagColor._fields) mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( hsc_g[z_sel] < feniks_mag_thresh.HSC_G ) + N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( + uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug] + ) + mag_idx = 7 col_idx = [0, 1] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d( - megacam_hsc_uSg[z_sel][mag_sel_ug & K_sel] - ) - ug.append( - Ug_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_ug, - bin_lo_ug, - bin_hi_ug, - N_1d_ug, - ) - ) - n_bins += bin_lo_ug.size + K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug) + n_bins += bin_lo_K_ug.size - # 1D (r - z | K) - rz = [] - Rz_condK = namedtuple( - "Rz_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) + # 2D (K, r - z) + K_rz = namedtuple("K_rz", MagColor._fields) mag_sel_rz = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( hsc_z[z_sel] < feniks_mag_thresh.HSC_Z ) - col_idx = [2, 4] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_rz, sig_rz, bin_lo_rz, bin_hi_rz = get_N_1d( - hsc_rz[z_sel][mag_sel_rz & K_sel] - ) - rz.append( - Rz_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_rz, - bin_lo_rz, - bin_hi_rz, - N_1d_rz, - ) - ) - n_bins += bin_lo_rz.size - - # 1D (J − H | K) - jh = [] - JH_condK = namedtuple( - "JH_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) - mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( - uds_H[z_sel] < feniks_mag_thresh.UDS_H + N_K_rz, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz = get_N_2d( + uds_K[z_sel][mag_sel_rz], hsc_rz[z_sel][mag_sel_rz] ) - col_idx = [5, 6] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_jh, sig_jh, bin_lo_jh, bin_hi_jh = get_N_1d( - uds_JH[z_sel][mag_sel_jh & K_sel] - ) - jh.append( - JH_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_jh, - bin_lo_jh, - bin_hi_jh, - N_1d_jh, - ) - ) - n_bins += bin_lo_jh.size + mag_idx = 7 + col_idx = [2, 4] + K_rz = K_rz(mag_idx, col_idx, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz, N_K_rz) + n_bins += bin_lo_K_rz.size - z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, rz, jh) + z2 = Z2(z_min, z_max, lc_data, rz_zJ, K_ug, K_rz) colors.append(z2) ############################################################################## # Z3 spaces: # 2D (z - J, J - H) # 2D (u - g, g - r) - # 1D (u - g | K) - # 1D (g - r | K) - # 1D (J − H | K): residual quenching scatter at fixed stellar mass + # 2D (K, u - g) + # 2D (K, g - r) + # 2D (K, J − H): residual quenching scatter at fixed stellar mass Z3 = namedtuple( "Z3", - ["z_min", "z_max", "lc_data", "zJ_JH", "ug_gr", "ug", "gr", "jh"], + ["z_min", "z_max", "lc_data", "zJ_JH", "ug_gr", "K_ug", "K_gr", "K_JH"], ) zbin = 2 z_min = zbins[zbin][0] @@ -766,7 +553,7 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) # 2D (z - J, J - H) - zJ_JH = namedtuple("zJ_JH", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + zJ_JH = namedtuple("zJ_JH", ColorColor._fields) mag_sel_zJ_JH = ( (hsc_z[z_sel] < feniks_mag_thresh.HSC_Z) & (uds_J[z_sel] < feniks_mag_thresh.UDS_J) @@ -780,7 +567,7 @@ def get_feniks_data( n_bins += bin_lo_zJ_JH.size # 2D (u - g, g - r) - Ug_gr = namedtuple("Ug_gr", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + Ug_gr = namedtuple("Ug_gr", ColorColor._fields) mag_sel_ugr = ( (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) @@ -793,113 +580,46 @@ def get_feniks_data( ug_gr = Ug_gr(col_idx, sig_ug_gr, bin_lo_ug_gr, bin_hi_ug_gr, N_ug_gr) n_bins += bin_lo_ug_gr.size - # 1D (u - g | K) - Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 4) - - ug = [] + # 2D (K, u - g) + K_ug = namedtuple("K_ug", MagColor._fields) mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( hsc_g[z_sel] < feniks_mag_thresh.HSC_G ) + N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( + uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug] + ) + mag_idx = 7 col_idx = [0, 1] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d( - megacam_hsc_uSg[z_sel][mag_sel_ug & K_sel] - ) - ug.append( - Ug_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_ug, - bin_lo_ug, - bin_hi_ug, - N_1d_ug, - ) - ) - n_bins += bin_lo_ug.size + K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug) + n_bins += bin_lo_K_ug.size - # 1D (g - r | K) - gr = [] - Gr_condK = namedtuple( - "Gr_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) + # 2D (K, g - r) + K_gr = namedtuple("K_gr", MagColor._fields) mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( hsc_r[z_sel] < feniks_mag_thresh.HSC_R ) + N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( + uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr] + ) + mag_idx = 7 col_idx = [1, 2] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_gr, sig_gr, bin_lo_gr, bin_hi_gr = get_N_1d( - hsc_gr[z_sel][mag_sel_gr & K_sel] - ) - gr.append( - Gr_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_gr, - bin_lo_gr, - bin_hi_gr, - N_1d_gr, - ) - ) - n_bins += bin_lo_gr.size + K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr) + n_bins += bin_lo_K_gr.size - # 1D (J − H | K) - jh = [] - JH_condK = namedtuple( - "JH_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) - mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( + # 2D (K, J - H) + K_JH = namedtuple("K_JH", MagColor._fields) + mag_sel_JH = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( uds_H[z_sel] < feniks_mag_thresh.UDS_H ) + N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( + uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH] + ) + mag_idx = 7 col_idx = [5, 6] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_jh, sig_jh, bin_lo_jh, bin_hi_jh = get_N_1d( - uds_JH[z_sel][mag_sel_jh & K_sel] - ) - jh.append( - JH_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_jh, - bin_lo_jh, - bin_hi_jh, - N_1d_jh, - ) - ) - n_bins += bin_lo_jh.size + K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH) + n_bins += bin_lo_K_JH.size - z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr, jh) + z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, K_ug, K_gr, K_JH) colors.append(z3) ############################################################################## @@ -919,10 +639,7 @@ def get_feniks_data( "AppMagFuncs", ["z_min", "z_max", "lc_data", "u", "g", "r", "i", "z", "J", "H", "K"], ) - AppMagFunc = namedtuple( - "AppMagFunc", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], - ) + U = namedtuple("U", AppMagFunc._fields) G = namedtuple("G", AppMagFunc._fields) R = namedtuple("R", AppMagFunc._fields) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index fb4d6623..aa5a1089 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -11,8 +11,11 @@ SDSS_MAGR_THRESH, SDSS_Z_MAX, SDSS_Z_MIN, + AppMagFunc, + ColorColor, Dataset, FilterInfo, + MagColor, ) from ..latin_hypercube import latin_hypercube as lh from ..lightcone_generators import generate_lc_data @@ -187,43 +190,19 @@ def get_sdss_data( ############################################################################## Colors = namedtuple( "Colors", - ["z_min", "z_max", "lc_data", "ur_ri", "gr_ri", "ur", "ri"], + ["z_min", "z_max", "lc_data", "ur_ri", "gr_ri", "R_ur", "R_ri"], ) # 2D (u - r, r - i) - Ur_ri = namedtuple("Ur_ri", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + Ur_ri = namedtuple("Ur_ri", ColorColor._fields) # 2D (g - r, r - i) - Gr_ri = namedtuple("Gr_ri", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + Gr_ri = namedtuple("Gr_ri", ColorColor._fields) - # 1D (u - r | r) - Ur_condr = namedtuple( - "Ur_condr", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) + # 2D (r, u - r) + R_ur = namedtuple("R_ur", MagColor._fields) - # 1D (r - i | r) - Ri_condr = namedtuple( - "Ri_condr", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) + # 2D (r, r - i) + R_ri = namedtuple("R_ri", MagColor._fields) colors = [] for zbin in range(0, len(zbins)): @@ -264,50 +243,23 @@ def get_sdss_data( col_idx = [1, 2, 3] gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri) - # 1D (u - r | r) - rbins = np.arange(sdss_r[z_sel].min(), sdss_r[z_sel].max(), 2) - print(rbins) - + # 2D (r, u - r) + N_r_ur, sig_r_ur, bin_lo_r_ur, bin_hi_r_ur = get_N_2d( + sdss_r[z_sel], sdss_ur[z_sel] + ) + mag_idx = 2 col_idx = [0, 2] - cond_idx = 2 - ur = [] - for r in range(len(rbins) - 1): - r_sel = (sdss_r[z_sel] > rbins[r]) & (sdss_r[z_sel] <= rbins[r + 1]) - N_1d_ur, sig_ur, bin_lo_ur, bin_hi_ur = get_N_1d(sdss_ur[z_sel][r_sel]) - ur.append( - Ur_condr( - col_idx, - cond_idx, - rbins[r], - rbins[r + 1], - sig_ur, - bin_lo_ur, - bin_hi_ur, - N_1d_ur, - ) - ) - - # 1D (r - i | r) + r_ur = R_ur(mag_idx, col_idx, sig_r_ur, bin_lo_r_ur, bin_hi_r_ur, N_r_ur) + + # 2D (r, r - i) + N_r_ri, sig_r_ri, bin_lo_r_ri, bin_hi_r_ri = get_N_2d( + sdss_r[z_sel], sdss_ri[z_sel] + ) + mag_idx = 2 col_idx = [2, 3] - cond_idx = 2 - ri = [] - for r in range(len(rbins) - 1): - r_sel = (sdss_r[z_sel] > rbins[r]) & (sdss_r[z_sel] <= rbins[r + 1]) - N_1d_ri, sig_ri, bin_lo_ri, bin_hi_ri = get_N_1d(sdss_ri[z_sel][r_sel]) - ri.append( - Ri_condr( - col_idx, - cond_idx, - rbins[r], - rbins[r + 1], - sig_ri, - bin_lo_ri, - bin_hi_ri, - N_1d_ri, - ) - ) - - colors.append(Colors(z_min, z_max, lc_data, ur_ri, gr_ri, ur, ri)) + r_ri = R_ri(mag_idx, col_idx, sig_r_ri, bin_lo_r_ri, bin_hi_r_ri, N_r_ri) + + colors.append(Colors(z_min, z_max, lc_data, ur_ri, gr_ri, r_ur, r_ri)) ############################################################################## ############################################################################## diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index 4b78e6b2..015cfca4 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -47,3 +47,14 @@ "data_sky_area_degsq", ], ) + +ColorColor = namedtuple("ColorColor", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) + +MagColor = namedtuple( + "MagColor", ["mag_idx", "col_idx", "sig", "bin_lo", "bin_hi", "N_data"] +) + +AppMagFunc = namedtuple( + "AppMagFunc", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], +) diff --git a/diffhtwo/experimental/diagnostics/plot_contour.py b/diffhtwo/experimental/diagnostics/plot_contour.py index 2507b435..e3bfe5cf 100644 --- a/diffhtwo/experimental/diagnostics/plot_contour.py +++ b/diffhtwo/experimental/diagnostics/plot_contour.py @@ -30,7 +30,7 @@ def plot_density( - bin_lo, bin_hi, N, ax, xlabel, ylabel, cmap, N_model=None, sigma=0.55, n_levels=8 + bin_lo, bin_hi, N, ax, xlabel, ylabel, cmap, N_model=None, sigma=0.5, n_levels=10 ): x_edges = np.unique(np.append(bin_lo[:, 0], bin_hi[-1, 0])) y_edges = np.unique(np.append(bin_lo[:, 1], bin_hi[-1, 1])) @@ -102,6 +102,8 @@ def plot_color_contours( frac_cat, data_label, savedir, + sigma=0.5, + n_levels=10, ): for z in range(0, len(data)): z_data = data[z] @@ -137,6 +139,8 @@ def plot_color_contours( ylabel, dusk, N_model=space.N_model, + sigma=sigma, + n_levels=n_levels, ) fig.savefig( savedir @@ -154,11 +158,10 @@ def plot_color_contours( plt.close() -def parse_color_labels(name): - # "Ur_ri" → ["u-r", "r-i"] - x_str, y_str = name.lower().split("_") +def parse_axis_label(s): + return f"${s[0]}-{s[1]}$" if len(s) == 2 else f"${s}$" - def to_label(s): - return f"${s[0]} - {s[1]}$" - return to_label(x_str), to_label(y_str) +def parse_color_labels(name): + x_str, y_str = name.lower().split("_") + return parse_axis_label(x_str), parse_axis_label(y_str) diff --git a/diffhtwo/experimental/kernels/N_phot.py b/diffhtwo/experimental/kernels/N_phot.py index 1c0ce554..a969b9ab 100644 --- a/diffhtwo/experimental/kernels/N_phot.py +++ b/diffhtwo/experimental/kernels/N_phot.py @@ -65,25 +65,52 @@ def N_colors_mags( z_data = z_data._replace(**{fields[f]: new_list}) elif "mag_idx" in space._fields: - mag_idx = space.mag_idx - obs_mag = obs_mags[:, mag_idx] - obs_mag = obs_mag.reshape(obs_mag.size, 1) + if "col_idx" in space._fields: + col_idx = space.col_idx + mag_idx = space.mag_idx - # get mag_sel weight - mag_sel = obs_mags[:, mag_idx] < mag_thresh[mag_idx] - weight = jnp.where(mag_sel, gal_weight, 0.0) + mag = obs_mags[:, mag_idx] + obs_color = obs_mags[:, col_idx[0]] - obs_mags[:, col_idx[1]] + obs_mag_color = jnp.vstack((mag, obs_color)).T - N_model = diffndhist_lomem.tw_ndhist_weighted( - obs_mag, - space.sig, - weight, - space.bin_lo, - space.bin_hi, - ) + mag_sel = ( + (obs_mags[:, mag_idx] < mag_thresh[mag_idx]) + & (obs_mags[:, col_idx[0]] < mag_thresh[col_idx[0]]) + & (obs_mags[:, col_idx[1]] < mag_thresh[col_idx[1]]) + ) + weight = jnp.where(mag_sel, gal_weight, 0.0) - NewTuple = namedtuple(type(space).__name__, [*space._fields, "N_model"]) - new = NewTuple(*space, N_model) - z_data = z_data._replace(**{fields[f]: new}) + N_model = diffndhist_lomem.tw_ndhist_weighted( + obs_mag_color, + space.sig, + weight, + space.bin_lo, + space.bin_hi, + ) + + NewTuple = namedtuple(type(space).__name__, [*space._fields, "N_model"]) + new = NewTuple(*space, N_model) + z_data = z_data._replace(**{fields[f]: new}) + else: + mag_idx = space.mag_idx + obs_mag = obs_mags[:, mag_idx] + obs_mag = obs_mag.reshape(obs_mag.size, 1) + + # get mag_sel weight + mag_sel = obs_mags[:, mag_idx] < mag_thresh[mag_idx] + weight = jnp.where(mag_sel, gal_weight, 0.0) + + N_model = diffndhist_lomem.tw_ndhist_weighted( + obs_mag, + space.sig, + weight, + space.bin_lo, + space.bin_hi, + ) + + NewTuple = namedtuple(type(space).__name__, [*space._fields, "N_model"]) + new = NewTuple(*space, N_model) + z_data = z_data._replace(**{fields[f]: new}) else: col_idx = space.col_idx From c1c655c42ed7edd352e555b3b63a7c82ab2f192e Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Wed, 10 Jun 2026 12:54:33 -0500 Subject: [PATCH 29/57] update N_phot and loss --- diffhtwo/experimental/kernels/N_phot.py | 40 ++----------------- .../experimental/loss_kernels/phot_loss.py | 27 ++++--------- 2 files changed, 11 insertions(+), 56 deletions(-) diff --git a/diffhtwo/experimental/kernels/N_phot.py b/diffhtwo/experimental/kernels/N_phot.py index a969b9ab..46824121 100644 --- a/diffhtwo/experimental/kernels/N_phot.py +++ b/diffhtwo/experimental/kernels/N_phot.py @@ -29,43 +29,9 @@ def N_colors_mags( for f in range(0, len(fields)): space = getattr(z_data, fields[f]) - if isinstance(space, list): - new_list = [] - for s in range(0, len(space)): - space_n = space[s] - col_idx = space_n.col_idx - - # get cond weight - obs_mags_cond = obs_mags[:, space_n.cond_idx] - cond = (obs_mags_cond > space_n.cond_min) & ( - obs_mags_cond <= space_n.cond_max - ) - weight = jnp.where(cond, gal_weight, 0.0) - - # get mag_sel weight - for c in range(0, len(col_idx)): - mag_sel = obs_mags[:, col_idx[c]] < mag_thresh[col_idx[c]] - weight *= jnp.where(mag_sel, 1.0, 0.0) - - obs_color = obs_mags[:, col_idx[0]] - obs_mags[:, col_idx[1]] - obs_color = obs_color.reshape(obs_color.size, 1) - - N_model = diffndhist_lomem.tw_ndhist_weighted( - obs_color, - space_n.sig, - weight, - space_n.bin_lo, - space_n.bin_hi, - ) - - NewTuple = namedtuple( - type(space_n).__name__, [*space_n._fields, "N_model"] - ) - new_list.append(NewTuple(*space_n, N_model)) - z_data = z_data._replace(**{fields[f]: new_list}) - - elif "mag_idx" in space._fields: + if "mag_idx" in space._fields: if "col_idx" in space._fields: + # Magnitude-Color space col_idx = space.col_idx mag_idx = space.mag_idx @@ -92,6 +58,7 @@ def N_colors_mags( new = NewTuple(*space, N_model) z_data = z_data._replace(**{fields[f]: new}) else: + # Apparent Magnitude space mag_idx = space.mag_idx obs_mag = obs_mags[:, mag_idx] obs_mag = obs_mag.reshape(obs_mag.size, 1) @@ -113,6 +80,7 @@ def N_colors_mags( z_data = z_data._replace(**{fields[f]: new}) else: + # Color-Color space col_idx = space.col_idx obs_colors = [] for c in range(0, len(col_idx) - 1): diff --git a/diffhtwo/experimental/loss_kernels/phot_loss.py b/diffhtwo/experimental/loss_kernels/phot_loss.py index aee02125..08856726 100644 --- a/diffhtwo/experimental/loss_kernels/phot_loss.py +++ b/diffhtwo/experimental/loss_kernels/phot_loss.py @@ -29,26 +29,13 @@ def get_phot_loss_2d_multiz( for f in range(0, len(fields)): space = getattr(z_data_model, fields[f]) - if isinstance(space, list): - for s in range(0, len(space)): - space_n = space[s] - - N_model = space_n.N_model - N_data = space_n.N_data - - N_model = N_model * ( - data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq - ) - phot_loss_2d += poisson_loss(N_model, N_data) - - else: - N_model = space.N_model - N_data = space.N_data - - N_model = N_model * ( - data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq - ) - phot_loss_2d += poisson_loss(N_model, N_data) + N_model = space.N_model + N_data = space.N_data + + N_model = N_model * ( + data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq + ) + phot_loss_2d += poisson_loss(N_model, N_data) return phot_loss_2d From c6b5bf0242ac6c85f3475fe18cc9229bb13f5430 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Wed, 10 Jun 2026 13:01:45 -0500 Subject: [PATCH 30/57] Update load_sdss.py --- .../experimental/data_loaders/load_sdss.py | 26 ++++--------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index aa5a1089..7a886daf 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -277,27 +277,11 @@ def get_sdss_data( "AppMagFuncs", ["z_min", "z_max", "lc_data", "u", "g", "r", "i", "z"], ) - U = namedtuple( - "U", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], - ) - G = namedtuple( - "G", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], - ) - R = namedtuple( - "R", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], - ) - - I = namedtuple( - "I", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], - ) - Z = namedtuple( - "Z", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], - ) + U = namedtuple("U", AppMagFunc._fields) + G = namedtuple("G", AppMagFunc._fields) + R = namedtuple("R", AppMagFunc._fields) + I = namedtuple("I", AppMagFunc._fields) + Z = namedtuple("Z", AppMagFunc._fields) app_mag_funcs = [] for zbin in range(0, len(fine_zbins)): From c5abb1941302eed2a0ecec89dfd85682c96b7231 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Wed, 10 Jun 2026 16:08:42 -0500 Subject: [PATCH 31/57] enlarge mag bin width in CMD --- diffhtwo/experimental/data_loaders/N_utils.py | 8 ++- .../experimental/data_loaders/load_feniks.py | 17 +++--- .../experimental/data_loaders/load_sdss.py | 4 +- .../experimental/diagnostics/plot_contour.py | 56 +++++++++++++++++-- scripts/config_diagnostics.yaml | 40 ++++++------- 5 files changed, 86 insertions(+), 39 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/N_utils.py b/diffhtwo/experimental/data_loaders/N_utils.py index e6dd4241..6fd72041 100644 --- a/diffhtwo/experimental/data_loaders/N_utils.py +++ b/diffhtwo/experimental/data_loaders/N_utils.py @@ -28,10 +28,14 @@ def get_N_1d(dim1, dim1_bin_edges=None, dmag=0.2, sig_scale=0.5): ) -def get_N_2d(dim1, dim2, sig_scale=0.5): +def get_N_2d(dim1, dim2, sig_scale=0.5, dim1_is_mag=False): dataset = np.vstack((dim1, dim2)).T - dim1_bin_edges = np.linspace(dim1.min(), dim1.max(), 11) + if dim1_is_mag: + dim1_bin_edges = np.linspace(dim1.min(), dim1.max(), 4) + else: + dim1_bin_edges = np.linspace(dim1.min(), dim1.max(), 11) + dim2_bin_edges = np.linspace(dim2.min(), dim2.max(), 11) dim1_lo = dim1_bin_edges[:-1] diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index e8634640..a7729f3a 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -283,7 +283,6 @@ def get_feniks_data( hsc_ri = hsc_r - hsc_i hsc_iz = hsc_i - hsc_z hsc_uds_zJ = hsc_z - uds_J - hsc_uds_rK = hsc_r - uds_K uds_JH = uds_J - uds_H uds_HK = uds_H - uds_K @@ -403,7 +402,7 @@ def get_feniks_data( hsc_i[z_sel] < feniks_mag_thresh.HSC_I ) N_K_ri, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri = get_N_2d( - uds_K[z_sel][mag_sel_ri], hsc_ri[z_sel][mag_sel_ri] + uds_K[z_sel][mag_sel_ri], hsc_ri[z_sel][mag_sel_ri], dim1_is_mag=True ) mag_idx = 7 col_idx = [2, 3] @@ -416,7 +415,7 @@ def get_feniks_data( hsc_r[z_sel] < feniks_mag_thresh.HSC_R ) N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( - uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr] + uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr], dim1_is_mag=True ) mag_idx = 7 col_idx = [1, 2] @@ -429,7 +428,7 @@ def get_feniks_data( uds_H[z_sel] < feniks_mag_thresh.UDS_H ) N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( - uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH] + uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH], dim1_is_mag=True ) mag_idx = 7 col_idx = [5, 6] @@ -493,7 +492,7 @@ def get_feniks_data( hsc_g[z_sel] < feniks_mag_thresh.HSC_G ) N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( - uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug] + uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug], dim1_is_mag=True ) mag_idx = 7 col_idx = [0, 1] @@ -506,7 +505,7 @@ def get_feniks_data( hsc_z[z_sel] < feniks_mag_thresh.HSC_Z ) N_K_rz, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz = get_N_2d( - uds_K[z_sel][mag_sel_rz], hsc_rz[z_sel][mag_sel_rz] + uds_K[z_sel][mag_sel_rz], hsc_rz[z_sel][mag_sel_rz], dim1_is_mag=True ) mag_idx = 7 col_idx = [2, 4] @@ -586,7 +585,7 @@ def get_feniks_data( hsc_g[z_sel] < feniks_mag_thresh.HSC_G ) N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( - uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug] + uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug], dim1_is_mag=True ) mag_idx = 7 col_idx = [0, 1] @@ -599,7 +598,7 @@ def get_feniks_data( hsc_r[z_sel] < feniks_mag_thresh.HSC_R ) N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( - uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr] + uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr], dim1_is_mag=True ) mag_idx = 7 col_idx = [1, 2] @@ -612,7 +611,7 @@ def get_feniks_data( uds_H[z_sel] < feniks_mag_thresh.UDS_H ) N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( - uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH] + uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH], dim1_is_mag=True ) mag_idx = 7 col_idx = [5, 6] diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 7a886daf..88927a75 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -245,7 +245,7 @@ def get_sdss_data( # 2D (r, u - r) N_r_ur, sig_r_ur, bin_lo_r_ur, bin_hi_r_ur = get_N_2d( - sdss_r[z_sel], sdss_ur[z_sel] + sdss_r[z_sel], sdss_ur[z_sel], dim1_is_mag=True ) mag_idx = 2 col_idx = [0, 2] @@ -253,7 +253,7 @@ def get_sdss_data( # 2D (r, r - i) N_r_ri, sig_r_ri, bin_lo_r_ri, bin_hi_r_ri = get_N_2d( - sdss_r[z_sel], sdss_ri[z_sel] + sdss_r[z_sel], sdss_ri[z_sel], dim1_is_mag=True ) mag_idx = 2 col_idx = [2, 3] diff --git a/diffhtwo/experimental/diagnostics/plot_contour.py b/diffhtwo/experimental/diagnostics/plot_contour.py index e3bfe5cf..9337092a 100644 --- a/diffhtwo/experimental/diagnostics/plot_contour.py +++ b/diffhtwo/experimental/diagnostics/plot_contour.py @@ -1,3 +1,5 @@ +import matplotlib.lines as mlines +import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import LinearSegmentedColormap @@ -30,7 +32,17 @@ def plot_density( - bin_lo, bin_hi, N, ax, xlabel, ylabel, cmap, N_model=None, sigma=0.5, n_levels=10 + bin_lo, + bin_hi, + N, + ax, + xlabel, + ylabel, + cmap, + data_label, + N_model=None, + sigma=0.5, + n_levels=10, ): x_edges = np.unique(np.append(bin_lo[:, 0], bin_hi[-1, 0])) y_edges = np.unique(np.append(bin_lo[:, 1], bin_hi[-1, 1])) @@ -45,6 +57,9 @@ def plot_density( levels = np.linspace(Z.min(), Z.max(), n_levels) qm = ax.contourf(xc, yc, Z, levels=levels, cmap=cmap, alpha=0.5) ax.get_figure().colorbar(qm, ax=ax, label=r"$\log_{10}(N / N_{\rm tot})$") + + legend_handles = [mpatches.Patch(color=cmap(0.7), alpha=0.5, label=data_label)] + if N_model is not None: Z_model = np.log10( gaussian_filter( @@ -64,8 +79,28 @@ def plot_density( alpha=0.9, linestyles="dashed", ) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) + legend_handles.append( + mlines.Line2D( + [], + [], + color=cmap(0.7), + linewidth=1.5, + linestyle="dashed", + alpha=0.9, + label="diffsky", + ) + ) + + ax.legend( + handles=legend_handles, + loc="upper center", + bbox_to_anchor=(0.5, 1.1), + ncol=len(legend_handles), + frameon=False, + fontsize=12, + ) + ax.set_xlabel(xlabel, fontsize=14) + ax.set_ylabel(ylabel, fontsize=14) def plot_density_raw(bin_lo, bin_hi, N, ax, xlabel, ylabel, cmap, N_model=None): @@ -126,8 +161,9 @@ def plot_color_contours( pass else: - fig, ax = plt.subplots(constrained_layout=True) - fig.suptitle(str(z_min) + " < z < " + str(z_max)) + fig, ax = plt.subplots(figsize=(6.4, 5.2), constrained_layout=True) + fig.get_layout_engine().set(h_pad=0.0, hspace=0.0, rect=(0, 0, 1, 0.98)) + fig.suptitle(str(z_min) + " < z < " + str(z_max), fontsize=14) name = type(space).__name__ xlabel, ylabel = parse_color_labels(name) plot_density( @@ -138,6 +174,7 @@ def plot_color_contours( xlabel, ylabel, dusk, + data_label, N_model=space.N_model, sigma=sigma, n_levels=n_levels, @@ -159,7 +196,14 @@ def plot_color_contours( def parse_axis_label(s): - return f"${s[0]}-{s[1]}$" if len(s) == 2 else f"${s}$" + nir_bands = {"j", "h", "k"} + + def fmt(b): + return b.upper() if b in nir_bands else b + + if len(s) == 2: + return f"${fmt(s[0])}-{fmt(s[1])}$" + return f"${fmt(s)}$" def parse_color_labels(name): diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index b8389907..68316ea6 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run148 -model_nickname: run148_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run148/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run163 +model_nickname: run163_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run163/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,26 +11,26 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: False -plot_feniks: True +plot_sdss: True +plot_feniks: False plot_hizels: False plots: num_halos : 3000 plot_color_contours: True - plot_app_mag_funcs: False - plot_color_pdfs: False - plot_colors_mags: False - plot_mags: False - plot_ssperr: False - plot_massive_cen_colors: False - plot_merging_sat_colors: False + plot_app_mag_funcs: True + plot_color_pdfs: True + plot_colors_mags: True + plot_mags: True + plot_ssperr: True + plot_massive_cen_colors: True + plot_merging_sat_colors: True plot_satquench: False - plot_satquench_model: False - plot_insitu_smhm: False - plot_insitu_sm: False - plot_uvj: False - plot_exsitu_frac: False - plot_avpop: False - plot_burstpop: False - plot_fburst_mh_z: False \ No newline at end of file + plot_satquench_model: True + plot_insitu_smhm: True + plot_insitu_sm: True + plot_uvj: True + plot_exsitu_frac: True + plot_avpop: True + plot_burstpop: True + plot_fburst_mh_z: True \ No newline at end of file From 37a03214c7aaac293b11bbbc561158a25f31b9da Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Wed, 10 Jun 2026 17:08:15 -0500 Subject: [PATCH 32/57] reintroduce conditional colors --- .../experimental/data_loaders/load_feniks.py | 522 +++++++++++++++--- .../experimental/diagnostics/plot_contour.py | 2 +- diffhtwo/experimental/kernels/N_phot.py | 38 +- .../experimental/loss_kernels/phot_loss.py | 27 +- scripts/config_diagnostics.yaml | 10 +- 5 files changed, 509 insertions(+), 90 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index a7729f3a..b00eedae 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -356,7 +356,7 @@ def get_feniks_data( colors = [] Z1 = namedtuple( "Z1", - ["z_min", "z_max", "lc_data", "gr_ri", "K_ri", "K_gr", "K_JH"], + ["z_min", "z_max", "lc_data", "gr_ri", "ug", "ri", "iz", "jh"], ) zbin = 0 z_min = zbins[zbin][0] @@ -396,46 +396,204 @@ def get_feniks_data( gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri) n_bins += bin_lo_gr_ri.size - # 2D (K, r - i) - K_ri = namedtuple("K_ri", MagColor._fields) + # 1D (u - g | K) + Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) + + ug = [] + Ug_condK = namedtuple( + "Ug_condK", + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], + ) + mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( + hsc_g[z_sel] < feniks_mag_thresh.HSC_G + ) + col_idx = [0, 1] + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d( + megacam_hsc_uSg[z_sel][mag_sel_ug & K_sel] + ) + ug.append( + Ug_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_ug, + bin_lo_ug, + bin_hi_ug, + N_1d_ug, + ) + ) + n_bins += bin_lo_ug.size + + # 1D (r − i | K) + ri = [] + Ri_condK = namedtuple( + "Ri_condK", + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], + ) mag_sel_ri = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( hsc_i[z_sel] < feniks_mag_thresh.HSC_I ) - N_K_ri, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri = get_N_2d( - uds_K[z_sel][mag_sel_ri], hsc_ri[z_sel][mag_sel_ri], dim1_is_mag=True - ) - mag_idx = 7 col_idx = [2, 3] - K_ri = K_ri(mag_idx, col_idx, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri, N_K_ri) - n_bins += bin_lo_K_ri.size + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_ri, sig_ri, bin_lo_ri, bin_hi_ri = get_N_1d( + hsc_ri[z_sel][mag_sel_ri & K_sel] + ) + ri.append( + Ri_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_ri, + bin_lo_ri, + bin_hi_ri, + N_1d_ri, + ) + ) + n_bins += bin_lo_ri.size - # 2D (K, g - r) - K_gr = namedtuple("K_gr", MagColor._fields) - mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( - hsc_r[z_sel] < feniks_mag_thresh.HSC_R + # 1D (i − z | K) + iz = [] + Iz_condK = namedtuple( + "Iz_condK", + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) - N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( - uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr], dim1_is_mag=True + mag_sel_iz = (hsc_i[z_sel] < feniks_mag_thresh.HSC_I) & ( + hsc_z[z_sel] < feniks_mag_thresh.HSC_Z ) - mag_idx = 7 - col_idx = [1, 2] - K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr) - n_bins += bin_lo_K_gr.size + col_idx = [3, 4] + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_iz, sig_iz, bin_lo_iz, bin_hi_iz = get_N_1d( + hsc_iz[z_sel][mag_sel_iz & K_sel] + ) + iz.append( + Iz_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_iz, + bin_lo_iz, + bin_hi_iz, + N_1d_iz, + ) + ) + n_bins += bin_lo_iz.size - # 2D (K, J - H) - K_JH = namedtuple("K_JH", MagColor._fields) - mag_sel_JH = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( - uds_H[z_sel] < feniks_mag_thresh.UDS_H + # 1D (J − H | K) + jh = [] + JH_condK = namedtuple( + "JH_condK", + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) - N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( - uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH], dim1_is_mag=True + mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( + uds_H[z_sel] < feniks_mag_thresh.UDS_H ) - mag_idx = 7 col_idx = [5, 6] - K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH) - n_bins += bin_lo_K_JH.size - - z1 = Z1(z_min, z_max, lc_data, gr_ri, K_ri, K_gr, K_JH) + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_jh, sig_jh, bin_lo_jh, bin_hi_jh = get_N_1d( + uds_JH[z_sel][mag_sel_jh & K_sel] + ) + jh.append( + JH_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_jh, + bin_lo_jh, + bin_hi_jh, + N_1d_jh, + ) + ) + n_bins += bin_lo_jh.size + + # # 2D (K, r - i) + # K_ri = namedtuple("K_ri", MagColor._fields) + # mag_sel_ri = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( + # hsc_i[z_sel] < feniks_mag_thresh.HSC_I + # ) + # N_K_ri, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri = get_N_2d( + # uds_K[z_sel][mag_sel_ri], hsc_ri[z_sel][mag_sel_ri], dim1_is_mag=True + # ) + # mag_idx = 7 + # col_idx = [2, 3] + # K_ri = K_ri(mag_idx, col_idx, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri, N_K_ri) + # n_bins += bin_lo_K_ri.size + + # # 2D (K, g - r) + # K_gr = namedtuple("K_gr", MagColor._fields) + # mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( + # hsc_r[z_sel] < feniks_mag_thresh.HSC_R + # ) + # N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( + # uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr], dim1_is_mag=True + # ) + # mag_idx = 7 + # col_idx = [1, 2] + # K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr) + # n_bins += bin_lo_K_gr.size + + # # 2D (K, J - H) + # K_JH = namedtuple("K_JH", MagColor._fields) + # mag_sel_JH = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( + # uds_H[z_sel] < feniks_mag_thresh.UDS_H + # ) + # N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( + # uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH], dim1_is_mag=True + # ) + # mag_idx = 7 + # col_idx = [5, 6] + # K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH) + # n_bins += bin_lo_K_JH.size + + z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, ri, iz, jh) colors.append(z1) ############################################################################## @@ -446,7 +604,7 @@ def get_feniks_data( Z2 = namedtuple( "Z2", - ["z_min", "z_max", "lc_data", "rz_zJ", "K_ug", "K_rz"], + ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "rz", "jh"], ) zbin = 1 z_min = zbins[zbin][0] @@ -486,33 +644,139 @@ def get_feniks_data( rz_zJ = Rz_zJ(col_idx, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ, N_rz_zJ) n_bins += bin_lo_rz_zJ.size - # 2D (K, u - g) - K_ug = namedtuple("K_ug", MagColor._fields) + # 1D (u - g | K) + Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) + + ug = [] mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( hsc_g[z_sel] < feniks_mag_thresh.HSC_G ) - N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( - uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug], dim1_is_mag=True - ) - mag_idx = 7 col_idx = [0, 1] - K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug) - n_bins += bin_lo_K_ug.size + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d( + megacam_hsc_uSg[z_sel][mag_sel_ug & K_sel] + ) + ug.append( + Ug_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_ug, + bin_lo_ug, + bin_hi_ug, + N_1d_ug, + ) + ) + n_bins += bin_lo_ug.size - # 2D (K, r - z) - K_rz = namedtuple("K_rz", MagColor._fields) + # 1D (r - z | K) + rz = [] + Rz_condK = namedtuple( + "Rz_condK", + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], + ) mag_sel_rz = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( hsc_z[z_sel] < feniks_mag_thresh.HSC_Z ) - N_K_rz, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz = get_N_2d( - uds_K[z_sel][mag_sel_rz], hsc_rz[z_sel][mag_sel_rz], dim1_is_mag=True - ) - mag_idx = 7 col_idx = [2, 4] - K_rz = K_rz(mag_idx, col_idx, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz, N_K_rz) - n_bins += bin_lo_K_rz.size + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_rz, sig_rz, bin_lo_rz, bin_hi_rz = get_N_1d( + hsc_rz[z_sel][mag_sel_rz & K_sel] + ) + rz.append( + Rz_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_rz, + bin_lo_rz, + bin_hi_rz, + N_1d_rz, + ) + ) + n_bins += bin_lo_rz.size - z2 = Z2(z_min, z_max, lc_data, rz_zJ, K_ug, K_rz) + # 1D (J − H | K) + jh = [] + JH_condK = namedtuple( + "JH_condK", + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], + ) + mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( + uds_H[z_sel] < feniks_mag_thresh.UDS_H + ) + col_idx = [5, 6] + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_jh, sig_jh, bin_lo_jh, bin_hi_jh = get_N_1d( + uds_JH[z_sel][mag_sel_jh & K_sel] + ) + jh.append( + JH_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_jh, + bin_lo_jh, + bin_hi_jh, + N_1d_jh, + ) + ) + n_bins += bin_lo_jh.size + + # # 2D (K, u - g) + # K_ug = namedtuple("K_ug", MagColor._fields) + # mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( + # hsc_g[z_sel] < feniks_mag_thresh.HSC_G + # ) + # N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( + # uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug], dim1_is_mag=True + # ) + # mag_idx = 7 + # col_idx = [0, 1] + # K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug) + # n_bins += bin_lo_K_ug.size + + # # 2D (K, r - z) + # K_rz = namedtuple("K_rz", MagColor._fields) + # mag_sel_rz = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( + # hsc_z[z_sel] < feniks_mag_thresh.HSC_Z + # ) + # N_K_rz, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz = get_N_2d( + # uds_K[z_sel][mag_sel_rz], hsc_rz[z_sel][mag_sel_rz], dim1_is_mag=True + # ) + # mag_idx = 7 + # col_idx = [2, 4] + # K_rz = K_rz(mag_idx, col_idx, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz, N_K_rz) + # n_bins += bin_lo_K_rz.size + + z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, rz, jh) colors.append(z2) ############################################################################## @@ -525,7 +789,7 @@ def get_feniks_data( Z3 = namedtuple( "Z3", - ["z_min", "z_max", "lc_data", "zJ_JH", "ug_gr", "K_ug", "K_gr", "K_JH"], + ["z_min", "z_max", "lc_data", "zJ_JH", "ug_gr", "ug", "gr", "jh"], ) zbin = 2 z_min = zbins[zbin][0] @@ -579,46 +843,152 @@ def get_feniks_data( ug_gr = Ug_gr(col_idx, sig_ug_gr, bin_lo_ug_gr, bin_hi_ug_gr, N_ug_gr) n_bins += bin_lo_ug_gr.size - # 2D (K, u - g) - K_ug = namedtuple("K_ug", MagColor._fields) + # 1D (u - g | K) + Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 4) + + ug = [] mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( hsc_g[z_sel] < feniks_mag_thresh.HSC_G ) - N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( - uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug], dim1_is_mag=True - ) - mag_idx = 7 col_idx = [0, 1] - K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug) - n_bins += bin_lo_K_ug.size + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d( + megacam_hsc_uSg[z_sel][mag_sel_ug & K_sel] + ) + ug.append( + Ug_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_ug, + bin_lo_ug, + bin_hi_ug, + N_1d_ug, + ) + ) + n_bins += bin_lo_ug.size - # 2D (K, g - r) - K_gr = namedtuple("K_gr", MagColor._fields) + # 1D (g - r | K) + gr = [] + Gr_condK = namedtuple( + "Gr_condK", + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], + ) mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( hsc_r[z_sel] < feniks_mag_thresh.HSC_R ) - N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( - uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr], dim1_is_mag=True - ) - mag_idx = 7 col_idx = [1, 2] - K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr) - n_bins += bin_lo_K_gr.size + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_gr, sig_gr, bin_lo_gr, bin_hi_gr = get_N_1d( + hsc_gr[z_sel][mag_sel_gr & K_sel] + ) + gr.append( + Gr_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_gr, + bin_lo_gr, + bin_hi_gr, + N_1d_gr, + ) + ) + n_bins += bin_lo_gr.size - # 2D (K, J - H) - K_JH = namedtuple("K_JH", MagColor._fields) - mag_sel_JH = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( - uds_H[z_sel] < feniks_mag_thresh.UDS_H + # 1D (J − H | K) + jh = [] + JH_condK = namedtuple( + "JH_condK", + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + ], ) - N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( - uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH], dim1_is_mag=True + mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( + uds_H[z_sel] < feniks_mag_thresh.UDS_H ) - mag_idx = 7 col_idx = [5, 6] - K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH) - n_bins += bin_lo_K_JH.size - - z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, K_ug, K_gr, K_JH) + cond_idx = 7 + for k in range(len(Kbins) - 1): + K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) + N_1d_jh, sig_jh, bin_lo_jh, bin_hi_jh = get_N_1d( + uds_JH[z_sel][mag_sel_jh & K_sel] + ) + jh.append( + JH_condK( + col_idx, + cond_idx, + Kbins[k], + Kbins[k + 1], + sig_jh, + bin_lo_jh, + bin_hi_jh, + N_1d_jh, + ) + ) + n_bins += bin_lo_jh.size + + # # 2D (K, u - g) + # K_ug = namedtuple("K_ug", MagColor._fields) + # mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( + # hsc_g[z_sel] < feniks_mag_thresh.HSC_G + # ) + # N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( + # uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug], dim1_is_mag=True + # ) + # mag_idx = 7 + # col_idx = [0, 1] + # K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug) + # n_bins += bin_lo_K_ug.size + + # # 2D (K, g - r) + # K_gr = namedtuple("K_gr", MagColor._fields) + # mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( + # hsc_r[z_sel] < feniks_mag_thresh.HSC_R + # ) + # N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( + # uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr], dim1_is_mag=True + # ) + # mag_idx = 7 + # col_idx = [1, 2] + # K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr) + # n_bins += bin_lo_K_gr.size + + # # 2D (K, J - H) + # K_JH = namedtuple("K_JH", MagColor._fields) + # mag_sel_JH = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( + # uds_H[z_sel] < feniks_mag_thresh.UDS_H + # ) + # N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( + # uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH], dim1_is_mag=True + # ) + # mag_idx = 7 + # col_idx = [5, 6] + # K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH) + # n_bins += bin_lo_K_JH.size + + z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr, jh) colors.append(z3) ############################################################################## diff --git a/diffhtwo/experimental/diagnostics/plot_contour.py b/diffhtwo/experimental/diagnostics/plot_contour.py index 9337092a..0b931e9b 100644 --- a/diffhtwo/experimental/diagnostics/plot_contour.py +++ b/diffhtwo/experimental/diagnostics/plot_contour.py @@ -162,7 +162,7 @@ def plot_color_contours( else: fig, ax = plt.subplots(figsize=(6.4, 5.2), constrained_layout=True) - fig.get_layout_engine().set(h_pad=0.0, hspace=0.0, rect=(0, 0, 1, 0.98)) + fig.get_layout_engine().set(h_pad=0.0, hspace=0.0, rect=(0, 0, 1, 0.96)) fig.suptitle(str(z_min) + " < z < " + str(z_max), fontsize=14) name = type(space).__name__ xlabel, ylabel = parse_color_labels(name) diff --git a/diffhtwo/experimental/kernels/N_phot.py b/diffhtwo/experimental/kernels/N_phot.py index 46824121..3429e357 100644 --- a/diffhtwo/experimental/kernels/N_phot.py +++ b/diffhtwo/experimental/kernels/N_phot.py @@ -29,7 +29,43 @@ def N_colors_mags( for f in range(0, len(fields)): space = getattr(z_data, fields[f]) - if "mag_idx" in space._fields: + if isinstance(space, list): + # Colors conditioned on mag space + new_list = [] + for s in range(0, len(space)): + space_n = space[s] + col_idx = space_n.col_idx + + # get cond weight + obs_mags_cond = obs_mags[:, space_n.cond_idx] + cond = (obs_mags_cond > space_n.cond_min) & ( + obs_mags_cond <= space_n.cond_max + ) + weight = jnp.where(cond, gal_weight, 0.0) + + # get mag_sel weight + for c in range(0, len(col_idx)): + mag_sel = obs_mags[:, col_idx[c]] < mag_thresh[col_idx[c]] + weight *= jnp.where(mag_sel, 1.0, 0.0) + + obs_color = obs_mags[:, col_idx[0]] - obs_mags[:, col_idx[1]] + obs_color = obs_color.reshape(obs_color.size, 1) + + N_model = diffndhist_lomem.tw_ndhist_weighted( + obs_color, + space_n.sig, + weight, + space_n.bin_lo, + space_n.bin_hi, + ) + + NewTuple = namedtuple( + type(space_n).__name__, [*space_n._fields, "N_model"] + ) + new_list.append(NewTuple(*space_n, N_model)) + z_data = z_data._replace(**{fields[f]: new_list}) + + elif "mag_idx" in space._fields: if "col_idx" in space._fields: # Magnitude-Color space col_idx = space.col_idx diff --git a/diffhtwo/experimental/loss_kernels/phot_loss.py b/diffhtwo/experimental/loss_kernels/phot_loss.py index 08856726..aee02125 100644 --- a/diffhtwo/experimental/loss_kernels/phot_loss.py +++ b/diffhtwo/experimental/loss_kernels/phot_loss.py @@ -29,13 +29,26 @@ def get_phot_loss_2d_multiz( for f in range(0, len(fields)): space = getattr(z_data_model, fields[f]) - N_model = space.N_model - N_data = space.N_data - - N_model = N_model * ( - data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq - ) - phot_loss_2d += poisson_loss(N_model, N_data) + if isinstance(space, list): + for s in range(0, len(space)): + space_n = space[s] + + N_model = space_n.N_model + N_data = space_n.N_data + + N_model = N_model * ( + data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq + ) + phot_loss_2d += poisson_loss(N_model, N_data) + + else: + N_model = space.N_model + N_data = space.N_data + + N_model = N_model * ( + data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq + ) + phot_loss_2d += poisson_loss(N_model, N_data) return phot_loss_2d diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 68316ea6..504ab7a1 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run163 -model_nickname: run163_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run163/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run164 +model_nickname: run164_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run164/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,8 +11,8 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: True -plot_feniks: False +plot_sdss: False +plot_feniks: True plot_hizels: False plots: From 59c0398f794d04c4b6e11a3aadf103d8afec86f6 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Wed, 10 Jun 2026 17:56:47 -0500 Subject: [PATCH 33/57] fit only u,rK app mag func for feniks --- diffhtwo/experimental/data_loaders/N_utils.py | 8 +- .../experimental/data_loaders/load_feniks.py | 389 ++++++++---------- diffhtwo/experimental/defaults.py | 23 +- .../experimental/loss_kernels/phot_loss.py | 23 +- scripts/config_diagnostics.yaml | 6 +- 5 files changed, 200 insertions(+), 249 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/N_utils.py b/diffhtwo/experimental/data_loaders/N_utils.py index 6fd72041..e6dd4241 100644 --- a/diffhtwo/experimental/data_loaders/N_utils.py +++ b/diffhtwo/experimental/data_loaders/N_utils.py @@ -28,14 +28,10 @@ def get_N_1d(dim1, dim1_bin_edges=None, dmag=0.2, sig_scale=0.5): ) -def get_N_2d(dim1, dim2, sig_scale=0.5, dim1_is_mag=False): +def get_N_2d(dim1, dim2, sig_scale=0.5): dataset = np.vstack((dim1, dim2)).T - if dim1_is_mag: - dim1_bin_edges = np.linspace(dim1.min(), dim1.max(), 4) - else: - dim1_bin_edges = np.linspace(dim1.min(), dim1.max(), 11) - + dim1_bin_edges = np.linspace(dim1.min(), dim1.max(), 11) dim2_bin_edges = np.linspace(dim2.min(), dim2.max(), 11) dim1_lo = dim1_bin_edges[:-1] diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index b00eedae..de8e9a6d 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -16,6 +16,7 @@ FENIKS_Z_MIN, AppMagFunc, ColorColor, + ColorCondMag, Dataset, FilterInfo, MagColor, @@ -356,7 +357,19 @@ def get_feniks_data( colors = [] Z1 = namedtuple( "Z1", - ["z_min", "z_max", "lc_data", "gr_ri", "ug", "ri", "iz", "jh"], + [ + "z_min", + "z_max", + "lc_data", + "gr_ri", + "ug", + "ri", + "iz", + "jh", + "K_ri", + "K_gr", + "K_JH", + ], ) zbin = 0 z_min = zbins[zbin][0] @@ -393,26 +406,14 @@ def get_feniks_data( hsc_gr[z_sel][mag_sel_gr_ri], hsc_ri[z_sel][mag_sel_gr_ri] ) col_idx = [1, 2, 3] - gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri) + gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri, True) n_bins += bin_lo_gr_ri.size # 1D (u - g | K) Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) ug = [] - Ug_condK = namedtuple( - "Ug_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) + Ug_condK = namedtuple("Ug_condK", ColorCondMag._fields) mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( hsc_g[z_sel] < feniks_mag_thresh.HSC_G ) @@ -433,25 +434,14 @@ def get_feniks_data( bin_lo_ug, bin_hi_ug, N_1d_ug, + True, ) ) n_bins += bin_lo_ug.size # 1D (r − i | K) ri = [] - Ri_condK = namedtuple( - "Ri_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) + Ri_condK = namedtuple("Ri_condK", ColorCondMag._fields) mag_sel_ri = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( hsc_i[z_sel] < feniks_mag_thresh.HSC_I ) @@ -472,25 +462,14 @@ def get_feniks_data( bin_lo_ri, bin_hi_ri, N_1d_ri, + True, ) ) n_bins += bin_lo_ri.size # 1D (i − z | K) iz = [] - Iz_condK = namedtuple( - "Iz_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) + Iz_condK = namedtuple("Iz_condK", ColorCondMag._fields) mag_sel_iz = (hsc_i[z_sel] < feniks_mag_thresh.HSC_I) & ( hsc_z[z_sel] < feniks_mag_thresh.HSC_Z ) @@ -511,25 +490,14 @@ def get_feniks_data( bin_lo_iz, bin_hi_iz, N_1d_iz, + True, ) ) n_bins += bin_lo_iz.size # 1D (J − H | K) jh = [] - JH_condK = namedtuple( - "JH_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) + JH_condK = namedtuple("JH_condK", ColorCondMag._fields) mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( uds_H[z_sel] < feniks_mag_thresh.UDS_H ) @@ -550,50 +518,51 @@ def get_feniks_data( bin_lo_jh, bin_hi_jh, N_1d_jh, + True, ) ) n_bins += bin_lo_jh.size - # # 2D (K, r - i) - # K_ri = namedtuple("K_ri", MagColor._fields) - # mag_sel_ri = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( - # hsc_i[z_sel] < feniks_mag_thresh.HSC_I - # ) - # N_K_ri, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri = get_N_2d( - # uds_K[z_sel][mag_sel_ri], hsc_ri[z_sel][mag_sel_ri], dim1_is_mag=True - # ) - # mag_idx = 7 - # col_idx = [2, 3] - # K_ri = K_ri(mag_idx, col_idx, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri, N_K_ri) - # n_bins += bin_lo_K_ri.size - - # # 2D (K, g - r) - # K_gr = namedtuple("K_gr", MagColor._fields) - # mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( - # hsc_r[z_sel] < feniks_mag_thresh.HSC_R - # ) - # N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( - # uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr], dim1_is_mag=True - # ) - # mag_idx = 7 - # col_idx = [1, 2] - # K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr) - # n_bins += bin_lo_K_gr.size - - # # 2D (K, J - H) - # K_JH = namedtuple("K_JH", MagColor._fields) - # mag_sel_JH = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( - # uds_H[z_sel] < feniks_mag_thresh.UDS_H - # ) - # N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( - # uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH], dim1_is_mag=True - # ) - # mag_idx = 7 - # col_idx = [5, 6] - # K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH) - # n_bins += bin_lo_K_JH.size - - z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, ri, iz, jh) + # 2D (K, r - i) + K_ri = namedtuple("K_ri", MagColor._fields) + mag_sel_ri = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( + hsc_i[z_sel] < feniks_mag_thresh.HSC_I + ) + N_K_ri, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri = get_N_2d( + uds_K[z_sel][mag_sel_ri], hsc_ri[z_sel][mag_sel_ri] + ) + mag_idx = 7 + col_idx = [2, 3] + K_ri = K_ri(mag_idx, col_idx, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri, N_K_ri, False) + n_bins += bin_lo_K_ri.size + + # 2D (K, g - r) + K_gr = namedtuple("K_gr", MagColor._fields) + mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( + hsc_r[z_sel] < feniks_mag_thresh.HSC_R + ) + N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( + uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr] + ) + mag_idx = 7 + col_idx = [1, 2] + K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr, False) + n_bins += bin_lo_K_gr.size + + # 2D (K, J - H) + K_JH = namedtuple("K_JH", MagColor._fields) + mag_sel_JH = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( + uds_H[z_sel] < feniks_mag_thresh.UDS_H + ) + N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( + uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH] + ) + mag_idx = 7 + col_idx = [5, 6] + K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH, False) + n_bins += bin_lo_K_JH.size + + z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, ri, iz, jh, K_ri, K_gr, K_JH) colors.append(z1) ############################################################################## @@ -604,7 +573,7 @@ def get_feniks_data( Z2 = namedtuple( "Z2", - ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "rz", "jh"], + ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "rz", "jh", "K_ug", "K_rz"], ) zbin = 1 z_min = zbins[zbin][0] @@ -641,7 +610,7 @@ def get_feniks_data( hsc_rz[z_sel][mag_sel_rz_zJ], hsc_uds_zJ[z_sel][mag_sel_rz_zJ] ) col_idx = [2, 4, 5] - rz_zJ = Rz_zJ(col_idx, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ, N_rz_zJ) + rz_zJ = Rz_zJ(col_idx, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ, N_rz_zJ, True) n_bins += bin_lo_rz_zJ.size # 1D (u - g | K) @@ -668,25 +637,14 @@ def get_feniks_data( bin_lo_ug, bin_hi_ug, N_1d_ug, + True, ) ) n_bins += bin_lo_ug.size # 1D (r - z | K) rz = [] - Rz_condK = namedtuple( - "Rz_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) + Rz_condK = namedtuple("Rz_condK", ColorCondMag._fields) mag_sel_rz = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( hsc_z[z_sel] < feniks_mag_thresh.HSC_Z ) @@ -707,25 +665,13 @@ def get_feniks_data( bin_lo_rz, bin_hi_rz, N_1d_rz, + True, ) ) n_bins += bin_lo_rz.size # 1D (J − H | K) jh = [] - JH_condK = namedtuple( - "JH_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( uds_H[z_sel] < feniks_mag_thresh.UDS_H ) @@ -746,37 +692,38 @@ def get_feniks_data( bin_lo_jh, bin_hi_jh, N_1d_jh, + True, ) ) n_bins += bin_lo_jh.size - # # 2D (K, u - g) - # K_ug = namedtuple("K_ug", MagColor._fields) - # mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( - # hsc_g[z_sel] < feniks_mag_thresh.HSC_G - # ) - # N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( - # uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug], dim1_is_mag=True - # ) - # mag_idx = 7 - # col_idx = [0, 1] - # K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug) - # n_bins += bin_lo_K_ug.size - - # # 2D (K, r - z) - # K_rz = namedtuple("K_rz", MagColor._fields) - # mag_sel_rz = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( - # hsc_z[z_sel] < feniks_mag_thresh.HSC_Z - # ) - # N_K_rz, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz = get_N_2d( - # uds_K[z_sel][mag_sel_rz], hsc_rz[z_sel][mag_sel_rz], dim1_is_mag=True - # ) - # mag_idx = 7 - # col_idx = [2, 4] - # K_rz = K_rz(mag_idx, col_idx, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz, N_K_rz) - # n_bins += bin_lo_K_rz.size - - z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, rz, jh) + # 2D (K, u - g) + K_ug = namedtuple("K_ug", MagColor._fields) + mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( + hsc_g[z_sel] < feniks_mag_thresh.HSC_G + ) + N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( + uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug] + ) + mag_idx = 7 + col_idx = [0, 1] + K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug, False) + n_bins += bin_lo_K_ug.size + + # 2D (K, r - z) + K_rz = namedtuple("K_rz", MagColor._fields) + mag_sel_rz = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( + hsc_z[z_sel] < feniks_mag_thresh.HSC_Z + ) + N_K_rz, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz = get_N_2d( + uds_K[z_sel][mag_sel_rz], hsc_rz[z_sel][mag_sel_rz] + ) + mag_idx = 7 + col_idx = [2, 4] + K_rz = K_rz(mag_idx, col_idx, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz, N_K_rz, False) + n_bins += bin_lo_K_rz.size + + z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, rz, jh, K_ug, K_rz) colors.append(z2) ############################################################################## @@ -789,7 +736,19 @@ def get_feniks_data( Z3 = namedtuple( "Z3", - ["z_min", "z_max", "lc_data", "zJ_JH", "ug_gr", "ug", "gr", "jh"], + [ + "z_min", + "z_max", + "lc_data", + "zJ_JH", + "ug_gr", + "ug", + "gr", + "jh", + "K_ug", + "K_gr", + "K_JH", + ], ) zbin = 2 z_min = zbins[zbin][0] @@ -826,7 +785,7 @@ def get_feniks_data( hsc_uds_zJ[z_sel][mag_sel_zJ_JH], uds_JH[z_sel][mag_sel_zJ_JH] ) col_idx = [4, 5, 6] - zJ_JH = zJ_JH(col_idx, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH, N_zJ_JH) + zJ_JH = zJ_JH(col_idx, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH, N_zJ_JH, True) n_bins += bin_lo_zJ_JH.size # 2D (u - g, g - r) @@ -840,7 +799,7 @@ def get_feniks_data( megacam_hsc_uSg[z_sel][mag_sel_ugr], hsc_gr[z_sel][mag_sel_ugr] ) col_idx = [0, 1, 2] - ug_gr = Ug_gr(col_idx, sig_ug_gr, bin_lo_ug_gr, bin_hi_ug_gr, N_ug_gr) + ug_gr = Ug_gr(col_idx, sig_ug_gr, bin_lo_ug_gr, bin_hi_ug_gr, N_ug_gr, True) n_bins += bin_lo_ug_gr.size # 1D (u - g | K) @@ -867,25 +826,14 @@ def get_feniks_data( bin_lo_ug, bin_hi_ug, N_1d_ug, + True, ) ) n_bins += bin_lo_ug.size # 1D (g - r | K) gr = [] - Gr_condK = namedtuple( - "Gr_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) + Gr_condK = namedtuple("Gr_condK", ColorCondMag._fields) mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( hsc_r[z_sel] < feniks_mag_thresh.HSC_R ) @@ -906,25 +854,13 @@ def get_feniks_data( bin_lo_gr, bin_hi_gr, N_1d_gr, + True, ) ) n_bins += bin_lo_gr.size # 1D (J − H | K) jh = [] - JH_condK = namedtuple( - "JH_condK", - [ - "col_idx", - "cond_idx", - "cond_min", - "cond_max", - "sig", - "bin_lo", - "bin_hi", - "N_data", - ], - ) mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( uds_H[z_sel] < feniks_mag_thresh.UDS_H ) @@ -945,50 +881,51 @@ def get_feniks_data( bin_lo_jh, bin_hi_jh, N_1d_jh, + True, ) ) n_bins += bin_lo_jh.size - # # 2D (K, u - g) - # K_ug = namedtuple("K_ug", MagColor._fields) - # mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( - # hsc_g[z_sel] < feniks_mag_thresh.HSC_G - # ) - # N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( - # uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug], dim1_is_mag=True - # ) - # mag_idx = 7 - # col_idx = [0, 1] - # K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug) - # n_bins += bin_lo_K_ug.size - - # # 2D (K, g - r) - # K_gr = namedtuple("K_gr", MagColor._fields) - # mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( - # hsc_r[z_sel] < feniks_mag_thresh.HSC_R - # ) - # N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( - # uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr], dim1_is_mag=True - # ) - # mag_idx = 7 - # col_idx = [1, 2] - # K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr) - # n_bins += bin_lo_K_gr.size - - # # 2D (K, J - H) - # K_JH = namedtuple("K_JH", MagColor._fields) - # mag_sel_JH = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( - # uds_H[z_sel] < feniks_mag_thresh.UDS_H - # ) - # N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( - # uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH], dim1_is_mag=True - # ) - # mag_idx = 7 - # col_idx = [5, 6] - # K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH) - # n_bins += bin_lo_K_JH.size - - z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr, jh) + # 2D (K, u - g) + K_ug = namedtuple("K_ug", MagColor._fields) + mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( + hsc_g[z_sel] < feniks_mag_thresh.HSC_G + ) + N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( + uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug] + ) + mag_idx = 7 + col_idx = [0, 1] + K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug, False) + n_bins += bin_lo_K_ug.size + + # 2D (K, g - r) + K_gr = namedtuple("K_gr", MagColor._fields) + mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( + hsc_r[z_sel] < feniks_mag_thresh.HSC_R + ) + N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( + uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr] + ) + mag_idx = 7 + col_idx = [1, 2] + K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr, False) + n_bins += bin_lo_K_gr.size + + # 2D (K, J - H) + K_JH = namedtuple("K_JH", MagColor._fields) + mag_sel_JH = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( + uds_H[z_sel] < feniks_mag_thresh.UDS_H + ) + N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( + uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH] + ) + mag_idx = 7 + col_idx = [5, 6] + K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH, False) + n_bins += bin_lo_K_JH.size + + z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr, jh, K_ug, K_gr, K_JH) colors.append(z3) ############################################################################## @@ -1046,49 +983,49 @@ def get_feniks_data( # 1D (u) mag_idx_u = 0 N_1d_u, sig_u, bin_lo_u, bin_hi_u = get_N_1d(megacam_uS[z_sel]) - u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u) + u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u, True) n_bins += bin_lo_u.size # 1D (g) mag_idx_g = 1 N_1d_g, sig_g, bin_lo_g, bin_hi_g = get_N_1d(hsc_g[z_sel]) - g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g) + g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g, False) n_bins += bin_lo_g.size # 1D (r) mag_idx_r = 2 N_1d_r, sig_r, bin_lo_r, bin_hi_r = get_N_1d(hsc_r[z_sel]) - r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r) + r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r, True) n_bins += bin_lo_r.size # 1D (i) mag_idx_i = 3 N_1d_i, sig_i, bin_lo_i, bin_hi_i = get_N_1d(hsc_i[z_sel]) - i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i) + i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i, False) n_bins += bin_lo_i.size # 1D (z) mag_idx_z = 4 N_1d_z, sig_z, bin_lo_z, bin_hi_z = get_N_1d(hsc_z[z_sel]) - z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z) + z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z, False) n_bins += bin_lo_z.size # 1D (J) mag_idx_j = 5 N_1d_j, sig_j, bin_lo_j, bin_hi_j = get_N_1d(uds_J[z_sel]) - j = J(mag_idx_j, sig_j, bin_lo_j, bin_hi_j, N_1d_j) + j = J(mag_idx_j, sig_j, bin_lo_j, bin_hi_j, N_1d_j, False) n_bins += bin_lo_j.size # 1D (H) mag_idx_h = 6 N_1d_h, sig_h, bin_lo_h, bin_hi_h = get_N_1d(uds_H[z_sel]) - h = H(mag_idx_h, sig_h, bin_lo_h, bin_hi_h, N_1d_h) + h = H(mag_idx_h, sig_h, bin_lo_h, bin_hi_h, N_1d_h, False) n_bins += bin_lo_h.size # 1D (K) mag_idx_k = 7 N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) - k = K(mag_idx_k, sig_k, bin_lo_k, bin_hi_k, N_1d_k) + k = K(mag_idx_k, sig_k, bin_lo_k, bin_hi_k, N_1d_k, True) n_bins += bin_lo_k.size app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, g, r, i, z, j, h, k)) diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index 015cfca4..d5348f52 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -48,13 +48,30 @@ ], ) -ColorColor = namedtuple("ColorColor", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data"]) +ColorColor = namedtuple( + "ColorColor", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data", "fit"] +) + +ColorCondMag = namedtuple( + "ColorCondMag", + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + "fit", + ], +) MagColor = namedtuple( - "MagColor", ["mag_idx", "col_idx", "sig", "bin_lo", "bin_hi", "N_data"] + "MagColor", ["mag_idx", "col_idx", "sig", "bin_lo", "bin_hi", "N_data", "fit"] ) AppMagFunc = namedtuple( "AppMagFunc", - ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data"], + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data", "fit"], ) diff --git a/diffhtwo/experimental/loss_kernels/phot_loss.py b/diffhtwo/experimental/loss_kernels/phot_loss.py index aee02125..edc1fe25 100644 --- a/diffhtwo/experimental/loss_kernels/phot_loss.py +++ b/diffhtwo/experimental/loss_kernels/phot_loss.py @@ -32,24 +32,25 @@ def get_phot_loss_2d_multiz( if isinstance(space, list): for s in range(0, len(space)): space_n = space[s] + if space_n.fit: + N_model = space_n.N_model + N_data = space_n.N_data - N_model = space_n.N_model - N_data = space_n.N_data + N_model = N_model * ( + data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq + ) + phot_loss_2d += poisson_loss(N_model, N_data) + + else: + if space.fit: + N_model = space.N_model + N_data = space.N_data N_model = N_model * ( data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq ) phot_loss_2d += poisson_loss(N_model, N_data) - else: - N_model = space.N_model - N_data = space.N_data - - N_model = N_model * ( - data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq - ) - phot_loss_2d += poisson_loss(N_model, N_data) - return phot_loss_2d diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 504ab7a1..915245f3 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run164 -model_nickname: run164_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run164/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run150 +model_nickname: run150_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run150/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks From a0cde610f9f55a6ed8848311de360c8a93353590 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Wed, 10 Jun 2026 18:04:56 -0500 Subject: [PATCH 34/57] fit only u and r app mag funcs for sdss --- .../experimental/data_loaders/load_sdss.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 88927a75..356df3ae 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -234,30 +234,30 @@ def get_sdss_data( sdss_ur[z_sel], sdss_ri[z_sel] ) col_idx = [0, 2, 3] - ur_ri = Ur_ri(col_idx, sig_ur_ri, bin_lo_ur_ri, bin_hi_ur_ri, N_ur_ri) + ur_ri = Ur_ri(col_idx, sig_ur_ri, bin_lo_ur_ri, bin_hi_ur_ri, N_ur_ri, True) # 2D (g - r, r - i) N_gr_ri, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri = get_N_2d( sdss_gr[z_sel], sdss_ri[z_sel] ) col_idx = [1, 2, 3] - gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri) + gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri, True) # 2D (r, u - r) N_r_ur, sig_r_ur, bin_lo_r_ur, bin_hi_r_ur = get_N_2d( - sdss_r[z_sel], sdss_ur[z_sel], dim1_is_mag=True + sdss_r[z_sel], sdss_ur[z_sel] ) mag_idx = 2 col_idx = [0, 2] - r_ur = R_ur(mag_idx, col_idx, sig_r_ur, bin_lo_r_ur, bin_hi_r_ur, N_r_ur) + r_ur = R_ur(mag_idx, col_idx, sig_r_ur, bin_lo_r_ur, bin_hi_r_ur, N_r_ur, True) # 2D (r, r - i) N_r_ri, sig_r_ri, bin_lo_r_ri, bin_hi_r_ri = get_N_2d( - sdss_r[z_sel], sdss_ri[z_sel], dim1_is_mag=True + sdss_r[z_sel], sdss_ri[z_sel] ) mag_idx = 2 col_idx = [2, 3] - r_ri = R_ri(mag_idx, col_idx, sig_r_ri, bin_lo_r_ri, bin_hi_r_ri, N_r_ri) + r_ri = R_ri(mag_idx, col_idx, sig_r_ri, bin_lo_r_ri, bin_hi_r_ri, N_r_ri, True) colors.append(Colors(z_min, z_max, lc_data, ur_ri, gr_ri, r_ur, r_ri)) @@ -311,27 +311,27 @@ def get_sdss_data( # 1D (u) mag_idx_u = 0 N_1d_u, sig_u, bin_lo_u, bin_hi_u = get_N_1d(sdss_u[z_sel]) - u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u) + u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u, True) # 1D (g) mag_idx_g = 1 N_1d_g, sig_g, bin_lo_g, bin_hi_g = get_N_1d(sdss_g[z_sel]) - g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g) + g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g, False) # 1D (r) mag_idx_r = 2 N_1d_r, sig_r, bin_lo_r, bin_hi_r = get_N_1d(sdss_r[z_sel]) - r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r) + r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r, True) # 1D (i) mag_idx_i = 3 N_1d_i, sig_i, bin_lo_i, bin_hi_i = get_N_1d(sdss_i[z_sel]) - i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i) + i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i, False) # 1D (z) mag_idx_z = 4 N_1d_z, sig_z, bin_lo_z, bin_hi_z = get_N_1d(sdss_z[z_sel]) - z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z) + z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z, False) app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, g, r, i, z)) From 738f00a86985eb970582e521298adcee4f161a60 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Wed, 10 Jun 2026 18:23:07 -0500 Subject: [PATCH 35/57] Update phot_loss.py --- .../experimental/loss_kernels/phot_loss.py | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/diffhtwo/experimental/loss_kernels/phot_loss.py b/diffhtwo/experimental/loss_kernels/phot_loss.py index edc1fe25..5ebdece9 100644 --- a/diffhtwo/experimental/loss_kernels/phot_loss.py +++ b/diffhtwo/experimental/loss_kernels/phot_loss.py @@ -5,6 +5,26 @@ from .loss_functions import poisson_loss +def _collect_fit_spaces(z_data_model): + """ + Called at trace time (outside jit, or during first trace). + Returns a flat list of (N_model, N_data) pairs for spaces where fit=True. + fit must be a static Python bool on the namedtuple. + """ + fit_spaces = [] + fields = z_data_model._fields[3:] + for f in fields: + space = getattr(z_data_model, f) + if isinstance(space, list): + for space_n in space: + if space_n.fit: # static Python bool — fine at trace time + fit_spaces.append((space_n.N_model, space_n.N_data)) + else: + if space.fit: + fit_spaces.append((space.N_model, space.N_data)) + return fit_spaces + + @jjit def get_phot_loss_2d_multiz( ran_key, @@ -15,9 +35,7 @@ def get_phot_loss_2d_multiz( data_sky_area_degsq, ): phot_loss_2d = 0.0 - for z in range(0, len(data)): - z_data = data[z] - + for z_data in data: z_data_model = N_colors_mags( ran_key, param_collection, @@ -25,31 +43,13 @@ def get_phot_loss_2d_multiz( mag_thresh, frac_cat, ) - fields = z_data_model._fields[3:] - for f in range(0, len(fields)): - space = getattr(z_data_model, fields[f]) - - if isinstance(space, list): - for s in range(0, len(space)): - space_n = space[s] - if space_n.fit: - N_model = space_n.N_model - N_data = space_n.N_data - - N_model = N_model * ( - data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq - ) - phot_loss_2d += poisson_loss(N_model, N_data) - - else: - if space.fit: - N_model = space.N_model - N_data = space.N_data - - N_model = N_model * ( - data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq - ) - phot_loss_2d += poisson_loss(N_model, N_data) + sky_rescale = data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq + + # Structure is resolved at trace time — no JAX boolean issue + fit_spaces = _collect_fit_spaces(z_data_model) + + for N_model, N_data in fit_spaces: + phot_loss_2d += poisson_loss(N_model * sky_rescale, N_data) return phot_loss_2d From e7e8e496d5e23ca5b2cb56e87f1d65d602b0350b Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Wed, 10 Jun 2026 18:53:35 -0500 Subject: [PATCH 36/57] Update phot_loss.py --- .../experimental/loss_kernels/phot_loss.py | 50 ++++++++----------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/diffhtwo/experimental/loss_kernels/phot_loss.py b/diffhtwo/experimental/loss_kernels/phot_loss.py index 5ebdece9..01b46d21 100644 --- a/diffhtwo/experimental/loss_kernels/phot_loss.py +++ b/diffhtwo/experimental/loss_kernels/phot_loss.py @@ -1,30 +1,11 @@ from jax import jit as jjit +from jax import lax from ..kernels.N_phot import N_colors_mags, N_colors_mags_lh, N_mags_1d from ..param_utils import get_param_collection_from_u_theta from .loss_functions import poisson_loss -def _collect_fit_spaces(z_data_model): - """ - Called at trace time (outside jit, or during first trace). - Returns a flat list of (N_model, N_data) pairs for spaces where fit=True. - fit must be a static Python bool on the namedtuple. - """ - fit_spaces = [] - fields = z_data_model._fields[3:] - for f in fields: - space = getattr(z_data_model, f) - if isinstance(space, list): - for space_n in space: - if space_n.fit: # static Python bool — fine at trace time - fit_spaces.append((space_n.N_model, space_n.N_data)) - else: - if space.fit: - fit_spaces.append((space.N_model, space.N_data)) - return fit_spaces - - @jjit def get_phot_loss_2d_multiz( ran_key, @@ -35,7 +16,8 @@ def get_phot_loss_2d_multiz( data_sky_area_degsq, ): phot_loss_2d = 0.0 - for z_data in data: + for z in range(0, len(data)): + z_data = data[z] z_data_model = N_colors_mags( ran_key, param_collection, @@ -44,13 +26,25 @@ def get_phot_loss_2d_multiz( frac_cat, ) sky_rescale = data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq - - # Structure is resolved at trace time — no JAX boolean issue - fit_spaces = _collect_fit_spaces(z_data_model) - - for N_model, N_data in fit_spaces: - phot_loss_2d += poisson_loss(N_model * sky_rescale, N_data) - + fields = z_data_model._fields[3:] + for f in range(0, len(fields)): + space = getattr(z_data_model, fields[f]) + if isinstance(space, list): + for s in range(0, len(space)): + space_n = space[s] + phot_loss_2d += lax.cond( + space_n.fit, + lambda sp=space_n: poisson_loss( + sp.N_model * sky_rescale, sp.N_data + ), + lambda: 0.0, + ) + else: + phot_loss_2d += lax.cond( + space.fit, + lambda sp=space: poisson_loss(sp.N_model * sky_rescale, sp.N_data), + lambda: 0.0, + ) return phot_loss_2d From f695c7149faea846bed8b47bb0ee4078d3be8f26 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Wed, 10 Jun 2026 22:27:06 -0500 Subject: [PATCH 37/57] ugriz sdss app mag func all in fitting --- .../experimental/data_loaders/load_sdss.py | 6 +-- .../experimental/diagnostics/plot_contour.py | 5 ++- scripts/config_diagnostics.yaml | 40 +++++++++---------- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 356df3ae..f30c14a1 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -316,7 +316,7 @@ def get_sdss_data( # 1D (g) mag_idx_g = 1 N_1d_g, sig_g, bin_lo_g, bin_hi_g = get_N_1d(sdss_g[z_sel]) - g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g, False) + g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g, True) # 1D (r) mag_idx_r = 2 @@ -326,12 +326,12 @@ def get_sdss_data( # 1D (i) mag_idx_i = 3 N_1d_i, sig_i, bin_lo_i, bin_hi_i = get_N_1d(sdss_i[z_sel]) - i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i, False) + i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i, True) # 1D (z) mag_idx_z = 4 N_1d_z, sig_z, bin_lo_z, bin_hi_z = get_N_1d(sdss_z[z_sel]) - z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z, False) + z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z, True) app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, g, r, i, z)) diff --git a/diffhtwo/experimental/diagnostics/plot_contour.py b/diffhtwo/experimental/diagnostics/plot_contour.py index 0b931e9b..ffe814f0 100644 --- a/diffhtwo/experimental/diagnostics/plot_contour.py +++ b/diffhtwo/experimental/diagnostics/plot_contour.py @@ -162,8 +162,9 @@ def plot_color_contours( else: fig, ax = plt.subplots(figsize=(6.4, 5.2), constrained_layout=True) - fig.get_layout_engine().set(h_pad=0.0, hspace=0.0, rect=(0, 0, 1, 0.96)) - fig.suptitle(str(z_min) + " < z < " + str(z_max), fontsize=14) + fig.suptitle(str(z_min) + " < z < " + str(z_max), fontsize=14, y=0.98) + fig.get_layout_engine().set(h_pad=0.0, hspace=0.0, rect=(0, 0, 1, 0.95)) + name = type(space).__name__ xlabel, ylabel = parse_color_labels(name) plot_density( diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 915245f3..41740218 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run150 -model_nickname: run150_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run150/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run170 +model_nickname: run170_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run170/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,26 +11,26 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: False -plot_feniks: True +plot_sdss: True +plot_feniks: False plot_hizels: False plots: num_halos : 3000 plot_color_contours: True - plot_app_mag_funcs: True - plot_color_pdfs: True - plot_colors_mags: True - plot_mags: True - plot_ssperr: True - plot_massive_cen_colors: True - plot_merging_sat_colors: True + plot_app_mag_funcs: False + plot_color_pdfs: False + plot_colors_mags: False + plot_mags: False + plot_ssperr: False + plot_massive_cen_colors: False + plot_merging_sat_colors: False plot_satquench: False - plot_satquench_model: True - plot_insitu_smhm: True - plot_insitu_sm: True - plot_uvj: True - plot_exsitu_frac: True - plot_avpop: True - plot_burstpop: True - plot_fburst_mh_z: True \ No newline at end of file + plot_satquench_model: False + plot_insitu_smhm: False + plot_insitu_sm: False + plot_uvj: False + plot_exsitu_frac: False + plot_avpop: False + plot_burstpop: False + plot_fburst_mh_z: False \ No newline at end of file From 5665af2435b52b1d67acce635767b966c19ffa28 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Thu, 11 Jun 2026 10:50:59 -0500 Subject: [PATCH 38/57] exclude highest z-bins in feniks and sdss due to low counts --- diffhtwo/experimental/data_loaders/load_feniks.py | 3 +-- diffhtwo/experimental/data_loaders/load_sdss.py | 13 +++---------- diffhtwo/experimental/diagnostics/plot_contour.py | 4 +++- scripts/config_diagnostics.yaml | 6 +++--- scripts/generate_diagnostic_plots.py | 2 +- 5 files changed, 11 insertions(+), 17 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index de8e9a6d..50d7da6d 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -343,7 +343,7 @@ def get_feniks_data( [ [0.2, 0.7], [0.7, 1.5], - [1.5, 2.5], + [1.5, 2.0], ] ) @@ -937,7 +937,6 @@ def get_feniks_data( [0.7, 1.0], [1.0, 1.5], [1.5, 2.0], - [2.0, 2.5], ] ) ############################################################################## diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index f30c14a1..58832059 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -183,8 +183,8 @@ def get_sdss_data( # prepare 2D and 1D color spaces in coarse z-bins for fitting zbins = np.array( [ - [0.02, 0.1], - [0.1, 0.2], + [0.02, 0.08], + [0.08, 0.14], ] ) ############################################################################## @@ -264,14 +264,7 @@ def get_sdss_data( ############################################################################## ############################################################################## # prepare 1D app mag funcs in finer z-bins for fitting - fine_zbins = np.array( - [ - [0.02, 0.06], - [0.06, 0.1], - [0.1, 0.15], - [0.15, 0.20], - ] - ) + fine_zbins = np.array([[0.02, 0.05], [0.05, 0.08], [0.08, 0.11], [0.11, 0.14]]) ############################################################################## AppMagFuncs = namedtuple( "AppMagFuncs", diff --git a/diffhtwo/experimental/diagnostics/plot_contour.py b/diffhtwo/experimental/diagnostics/plot_contour.py index ffe814f0..f6332c8e 100644 --- a/diffhtwo/experimental/diagnostics/plot_contour.py +++ b/diffhtwo/experimental/diagnostics/plot_contour.py @@ -54,7 +54,9 @@ def plot_density( sigma=sigma, ).clip(min=np.finfo(float).tiny) ) - levels = np.linspace(Z.min(), Z.max(), n_levels) + Z_min = np.max((-10, Z.min())) + Z_max = Z.max() + levels = np.linspace(Z_min, Z_max, n_levels) qm = ax.contourf(xc, yc, Z, levels=levels, cmap=cmap, alpha=0.5) ax.get_figure().colorbar(qm, ax=ax, label=r"$\log_{10}(N / N_{\rm tot})$") diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 41740218..29432d81 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run170 -model_nickname: run170_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run170/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run171 +model_nickname: run171_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run171/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index 5a144ac2..3238724c 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -231,7 +231,7 @@ [0.5, 0.7], [0.7, 1.0], [1.0, 1.5], - [1.5, 2.0], + [1.5, 2.5], ] ) print("Generating FENIKS app mag funcs plot...") From 85201b768f70af9b65bd352a4958aa5195e19ab4 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Thu, 11 Jun 2026 11:18:50 -0500 Subject: [PATCH 39/57] fit all app mag funcs for feniks --- .../experimental/data_loaders/load_feniks.py | 10 ++--- scripts/generate_diagnostic_plots.py | 39 ++++++------------- 2 files changed, 16 insertions(+), 33 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 50d7da6d..69d1dba6 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -988,7 +988,7 @@ def get_feniks_data( # 1D (g) mag_idx_g = 1 N_1d_g, sig_g, bin_lo_g, bin_hi_g = get_N_1d(hsc_g[z_sel]) - g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g, False) + g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g, True) n_bins += bin_lo_g.size # 1D (r) @@ -1000,25 +1000,25 @@ def get_feniks_data( # 1D (i) mag_idx_i = 3 N_1d_i, sig_i, bin_lo_i, bin_hi_i = get_N_1d(hsc_i[z_sel]) - i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i, False) + i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i, True) n_bins += bin_lo_i.size # 1D (z) mag_idx_z = 4 N_1d_z, sig_z, bin_lo_z, bin_hi_z = get_N_1d(hsc_z[z_sel]) - z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z, False) + z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z, True) n_bins += bin_lo_z.size # 1D (J) mag_idx_j = 5 N_1d_j, sig_j, bin_lo_j, bin_hi_j = get_N_1d(uds_J[z_sel]) - j = J(mag_idx_j, sig_j, bin_lo_j, bin_hi_j, N_1d_j, False) + j = J(mag_idx_j, sig_j, bin_lo_j, bin_hi_j, N_1d_j, True) n_bins += bin_lo_j.size # 1D (H) mag_idx_h = 6 N_1d_h, sig_h, bin_lo_h, bin_hi_h = get_N_1d(uds_H[z_sel]) - h = H(mag_idx_h, sig_h, bin_lo_h, bin_hi_h, N_1d_h, False) + h = H(mag_idx_h, sig_h, bin_lo_h, bin_hi_h, N_1d_h, True) n_bins += bin_lo_h.size # 1D (K) diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index 3238724c..50f5dce5 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -212,6 +212,16 @@ num_halos_fine_zbins=int(num_halos / 2), ) + feniks_zbins = np.array( + [ + [0.2, 0.5], + [0.5, 0.7], + [0.7, 1.0], + [1.0, 1.5], + [1.5, 2.0], + ] + ) + if cfg["plots"]["plot_color_contours"]: print("Generating FENIKS color contour plots...") plot_color_contours( @@ -225,15 +235,6 @@ ) if cfg["plots"]["plot_app_mag_funcs"]: - feniks_zbins = np.array( - [ - [0.2, 0.5], - [0.5, 0.7], - [0.7, 1.0], - [1.0, 1.5], - [1.5, 2.5], - ] - ) print("Generating FENIKS app mag funcs plot...") plot_app_mag_funcs( feniks, @@ -294,17 +295,6 @@ plt_show=False, ) - feniks_zbins = np.array( - [ - [0.2, 0.5], - [0.5, 0.7], - [0.7, 1.0], - [1.0, 1.5], - [1.5, 2.0], - [2.0, 2.5], - ] - ) - for zbin in range(0, len(feniks_zbins)): z_min = feniks_zbins[zbin][0] z_max = feniks_zbins[zbin][1] @@ -470,14 +460,7 @@ num_halos_coarse_zbins=num_halos, num_halos_fine_zbins=int(num_halos / 2), ) - sdss_zbins = np.array( - [ - [0.02, 0.06], - [0.06, 0.1], - [0.1, 0.15], - [0.15, 0.2], - ] - ) + sdss_zbins = np.array([[0.02, 0.05], [0.05, 0.08], [0.08, 0.11], [0.11, 0.14]]) if cfg["plots"]["plot_color_contours"]: print("Generating SDSS color contour plots...") From 8d1ad2e4d70a65d99855c64ff022cbbcbe007046 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Thu, 11 Jun 2026 13:51:14 -0500 Subject: [PATCH 40/57] even more conservative estimates of complete mags in sdss and feniks --- .../experimental/data_loaders/load_feniks.py | 14 +++---- .../experimental/data_loaders/load_sdss.py | 10 ++--- diffhtwo/experimental/defaults.py | 2 +- .../experimental/diagnostics/plot_burstpop.py | 2 +- scripts/config_diagnostics.yaml | 40 +++++++++---------- 5 files changed, 34 insertions(+), 34 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 69d1dba6..a13501ba 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -193,13 +193,13 @@ def get_feniks_data( uds_K = get_mag_ab(phot, "fcol_UDS_K") feniks_mag_thresh = FeniksFilters( - MegaCam_uS=24.9, - HSC_G=25.1, - HSC_R=25.3, - HSC_I=25.1, - HSC_Z=24.9, - UDS_J=24.5, - UDS_H=24.3, + MegaCam_uS=24.5, + HSC_G=24.5, + HSC_R=24.5, + HSC_I=24.5, + HSC_Z=24.5, + UDS_J=24.0, + UDS_H=24.0, UDS_K=FENIKS_MAGK_THRESH, ) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 58832059..e27d0b29 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -126,11 +126,11 @@ def get_sdss_data( n_z_phot_table=30, ): sdss_mag_thresh = SdssFilters( - sdss_u=19.7, - sdss_g=18.0, - sdss_r=SDSS_MAGR_THRESH, - sdss_i=17.0, - sdss_z=17.0, + sdss_u=19.0, + sdss_g=17.0, + sdss_r=17.0, + sdss_i=16.5, + sdss_z=16.5, ) sdss, frac_cat = load_sdss_cuts_applied(drn, sdss_mag_thresh) diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index d5348f52..a37503db 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -20,7 +20,7 @@ FENIKS_AREA_DEG2 = 2828.247933129912 / 3600 FENIKS_Z_MIN = 0.2 FENIKS_Z_MAX = 2.5 -FENIKS_MAGK_THRESH = 24.3 # col mag +FENIKS_MAGK_THRESH = 24.0 # col mag SDSS_AREA_DEG2 = 7199 SDSS_Z_MIN = 0.02 diff --git a/diffhtwo/experimental/diagnostics/plot_burstpop.py b/diffhtwo/experimental/diagnostics/plot_burstpop.py index f3267ea8..0e035424 100644 --- a/diffhtwo/experimental/diagnostics/plot_burstpop.py +++ b/diffhtwo/experimental/diagnostics/plot_burstpop.py @@ -160,7 +160,7 @@ def plot_lgfburst_mh_z( ) ax[0].set_xlabel("redshift") - ax[0].set_ylabel("log$_{10}$ (M$_{h, peak}$ [M\u2609])") + ax[0].set_ylabel(r"log$_{10}$ (M$_{h, peak}$ [M${_\odot}$])") ax[0].set_xlim(z_min, z_max) ax[0].set_ylim(10, 15) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 29432d81..f984fc34 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run171 -model_nickname: run171_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run171/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run174 +model_nickname: run174_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run174/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,26 +11,26 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: True -plot_feniks: False +plot_sdss: False +plot_feniks: True plot_hizels: False plots: num_halos : 3000 plot_color_contours: True - plot_app_mag_funcs: False - plot_color_pdfs: False - plot_colors_mags: False - plot_mags: False - plot_ssperr: False - plot_massive_cen_colors: False - plot_merging_sat_colors: False + plot_app_mag_funcs: True + plot_color_pdfs: True + plot_colors_mags: True + plot_mags: True + plot_ssperr: True + plot_massive_cen_colors: True + plot_merging_sat_colors: True plot_satquench: False - plot_satquench_model: False - plot_insitu_smhm: False - plot_insitu_sm: False - plot_uvj: False - plot_exsitu_frac: False - plot_avpop: False - plot_burstpop: False - plot_fburst_mh_z: False \ No newline at end of file + plot_satquench_model: True + plot_insitu_smhm: True + plot_insitu_sm: True + plot_uvj: True + plot_exsitu_frac: True + plot_avpop: True + plot_burstpop: True + plot_fburst_mh_z: True \ No newline at end of file From bcb754d191186f08ec25be82794f146a1d8c812d Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Thu, 11 Jun 2026 21:09:54 -0500 Subject: [PATCH 41/57] condr in sdss --- .../experimental/data_loaders/load_sdss.py | 83 +++++++++++-- .../experimental/diagnostics/plot_phot.py | 2 +- .../{plot_insitu_sm.py => plot_sm.py} | 117 +++++++++++++++++- scripts/config_diagnostics.yaml | 43 +++---- scripts/generate_diagnostic_plots.py | 42 ++++++- 5 files changed, 248 insertions(+), 39 deletions(-) rename diffhtwo/experimental/diagnostics/{plot_insitu_sm.py => plot_sm.py} (52%) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index e27d0b29..110f91f6 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -13,6 +13,7 @@ SDSS_Z_MIN, AppMagFunc, ColorColor, + ColorCondMag, Dataset, FilterInfo, MagColor, @@ -126,11 +127,11 @@ def get_sdss_data( n_z_phot_table=30, ): sdss_mag_thresh = SdssFilters( - sdss_u=19.0, - sdss_g=17.0, - sdss_r=17.0, - sdss_i=16.5, - sdss_z=16.5, + sdss_u=19.7, + sdss_g=18.0, + sdss_r=SDSS_MAGR_THRESH, + sdss_i=17.0, + sdss_z=17.0, ) sdss, frac_cat = load_sdss_cuts_applied(drn, sdss_mag_thresh) @@ -183,14 +184,14 @@ def get_sdss_data( # prepare 2D and 1D color spaces in coarse z-bins for fitting zbins = np.array( [ - [0.02, 0.08], - [0.08, 0.14], + [0.02, 0.1], + [0.1, 0.2], ] ) ############################################################################## Colors = namedtuple( "Colors", - ["z_min", "z_max", "lc_data", "ur_ri", "gr_ri", "R_ur", "R_ri"], + ["z_min", "z_max", "lc_data", "ur_ri", "gr_ri", "ur", "ri", "r_ur", "r_ri"], ) # 2D (u - r, r - i) Ur_ri = namedtuple("Ur_ri", ColorColor._fields) @@ -204,6 +205,12 @@ def get_sdss_data( # 2D (r, r - i) R_ri = namedtuple("R_ri", MagColor._fields) + # 1D (u - r | r) + Ur_condr = namedtuple("Ug_condK", ColorCondMag._fields) + + # 1D (r - i | r) + Ri_condr = namedtuple("Ri_condr", ColorCondMag._fields) + colors = [] for zbin in range(0, len(zbins)): z_min = zbins[zbin][0] @@ -243,13 +250,58 @@ def get_sdss_data( col_idx = [1, 2, 3] gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri, True) + # 1D (u - r | r) + rbins = np.arange(sdss_r[z_sel].min(), sdss_r[z_sel].max(), 2) + print(rbins) + + col_idx = [0, 2] + cond_idx = 2 + ur = [] + for r in range(len(rbins) - 1): + r_sel = (sdss_r[z_sel] > rbins[r]) & (sdss_r[z_sel] <= rbins[r + 1]) + N_1d_ur, sig_ur, bin_lo_ur, bin_hi_ur = get_N_1d(sdss_ur[z_sel][r_sel]) + ur.append( + Ur_condr( + col_idx, + cond_idx, + rbins[r], + rbins[r + 1], + sig_ur, + bin_lo_ur, + bin_hi_ur, + N_1d_ur, + True, + ) + ) + + # 1D (r - i | r) + col_idx = [2, 3] + cond_idx = 2 + ri = [] + for r in range(len(rbins) - 1): + r_sel = (sdss_r[z_sel] > rbins[r]) & (sdss_r[z_sel] <= rbins[r + 1]) + N_1d_ri, sig_ri, bin_lo_ri, bin_hi_ri = get_N_1d(sdss_ri[z_sel][r_sel]) + ri.append( + Ri_condr( + col_idx, + cond_idx, + rbins[r], + rbins[r + 1], + sig_ri, + bin_lo_ri, + bin_hi_ri, + N_1d_ri, + True, + ) + ) + # 2D (r, u - r) N_r_ur, sig_r_ur, bin_lo_r_ur, bin_hi_r_ur = get_N_2d( sdss_r[z_sel], sdss_ur[z_sel] ) mag_idx = 2 col_idx = [0, 2] - r_ur = R_ur(mag_idx, col_idx, sig_r_ur, bin_lo_r_ur, bin_hi_r_ur, N_r_ur, True) + r_ur = R_ur(mag_idx, col_idx, sig_r_ur, bin_lo_r_ur, bin_hi_r_ur, N_r_ur, False) # 2D (r, r - i) N_r_ri, sig_r_ri, bin_lo_r_ri, bin_hi_r_ri = get_N_2d( @@ -257,14 +309,21 @@ def get_sdss_data( ) mag_idx = 2 col_idx = [2, 3] - r_ri = R_ri(mag_idx, col_idx, sig_r_ri, bin_lo_r_ri, bin_hi_r_ri, N_r_ri, True) + r_ri = R_ri(mag_idx, col_idx, sig_r_ri, bin_lo_r_ri, bin_hi_r_ri, N_r_ri, False) - colors.append(Colors(z_min, z_max, lc_data, ur_ri, gr_ri, r_ur, r_ri)) + colors.append(Colors(z_min, z_max, lc_data, ur_ri, gr_ri, ur, ri, r_ur, r_ri)) ############################################################################## ############################################################################## # prepare 1D app mag funcs in finer z-bins for fitting - fine_zbins = np.array([[0.02, 0.05], [0.05, 0.08], [0.08, 0.11], [0.11, 0.14]]) + fine_zbins = fine_zbins = np.array( + [ + [0.02, 0.06], + [0.06, 0.1], + [0.1, 0.15], + [0.15, 0.20], + ] + ) ############################################################################## AppMagFuncs = namedtuple( "AppMagFuncs", diff --git a/diffhtwo/experimental/diagnostics/plot_phot.py b/diffhtwo/experimental/diagnostics/plot_phot.py index 3d372e07..488fb55b 100644 --- a/diffhtwo/experimental/diagnostics/plot_phot.py +++ b/diffhtwo/experimental/diagnostics/plot_phot.py @@ -721,7 +721,7 @@ def plot_app_mag_funcs( bin_centers, np.log10(n_diffsky), c=colors_z[zbin], alpha=alpha ) - axes[row, col].set_xticks(np.arange(15, 30, 2)) + axes[row, col].set_xticks(np.arange(10, 30, 2)) axes[row, col].minorticks_on() axes[row, col].tick_params( which="major", diff --git a/diffhtwo/experimental/diagnostics/plot_insitu_sm.py b/diffhtwo/experimental/diagnostics/plot_sm.py similarity index 52% rename from diffhtwo/experimental/diagnostics/plot_insitu_sm.py rename to diffhtwo/experimental/diagnostics/plot_sm.py index fac609f8..caf228c5 100644 --- a/diffhtwo/experimental/diagnostics/plot_insitu_sm.py +++ b/diffhtwo/experimental/diagnostics/plot_sm.py @@ -7,7 +7,7 @@ plt.rc("font", family="serif", serif=["Times New Roman"]) -def plot_insitu_sm( +def plot_insitu_sm_obs( ran_key, param_collection, z_min, @@ -105,7 +105,120 @@ def plot_insitu_sm( fig.savefig( savedir - + "/insitu_sm_" + + "/insitu_sm_obs_" + + model_nickname + + "_z" + + z_min_label + + "-" + + z_max_label + + ".png", + bbox_inches="tight", + dpi=200, + ) + if plt_show: + plt.show() + plt.close() + + +def plot_sm_obs( + ran_key, + param_collection, + z_min, + z_max, + dimension_labels, + ssp_data, + tcurves, + model_nickname, + savedir, + mag_thresh=None, + frac_cat=None, + num_halos=1000, + plt_show=True, +): + fig, ax = plt.subplots(1, figsize=(5, 5)) + + z_min_label = str(np.round(z_min, 2)) + z_max_label = str(np.round(z_max, 2)) + + fig.suptitle(z_min_label + " < z < " + z_max_label) + + """fit""" + lc_data, phot_kern_results, weights = multiband_lc_phot_kern( + ran_key, + param_collection, + z_min, + z_max, + num_halos, + ssp_data, + tcurves, + mag_thresh=mag_thresh, + frac_cat=frac_cat, + ) + + bins = np.linspace( + phot_kern_results.logsm_obs.min(), + phot_kern_results.logsm_obs.max(), + 50, + ) + + cen = lc_data.is_central == 1 + sat = lc_data.is_central != 1 + + ax.hist( + phot_kern_results.logsm_obs[sat], + bins=bins, + label="fit sat", + color="tab:orange", + histtype="step", + ) + ax.hist( + phot_kern_results.logsm_obs[cen], + bins=bins, + label="fit cen", + color="tab:blue", + histtype="step", + ) + + """default""" + lc_data, phot_kern_results, weights = multiband_lc_phot_kern( + ran_key, + DEFAULT_PARAM_COLLECTION, + z_min, + z_max, + num_halos, + ssp_data, + tcurves, + ) + + cen = lc_data.is_central == 1 + sat = lc_data.is_central != 1 + + ax.hist( + phot_kern_results.logsm_obs[sat], + bins=bins, + label="default sat", + color="tab:orange", + histtype="step", + ls="--", + ) + ax.hist( + phot_kern_results.logsm_obs[cen], + bins=bins, + label="default cen", + color="tab:blue", + histtype="step", + ls="--", + ) + + ax.set_xlabel("logsm_obs") + ax.set_ylabel("#") + ax.set_xlim(0, 14) + ax.set_yscale("log") + ax.legend() + + fig.savefig( + savedir + + "/sm_obs_" + model_nickname + "_z" + z_min_label diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index f984fc34..2b53211b 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run174 -model_nickname: run174_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run174/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run173 +model_nickname: run173_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run173/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,26 +11,27 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: False -plot_feniks: True +plot_sdss: True +plot_feniks: False plot_hizels: False plots: num_halos : 3000 - plot_color_contours: True - plot_app_mag_funcs: True - plot_color_pdfs: True - plot_colors_mags: True - plot_mags: True - plot_ssperr: True - plot_massive_cen_colors: True - plot_merging_sat_colors: True + plot_color_contours: False + plot_app_mag_funcs: False + plot_color_pdfs: False + plot_colors_mags: False + plot_mags: False + plot_ssperr: False + plot_massive_cen_colors: False + plot_merging_sat_colors: False plot_satquench: False - plot_satquench_model: True - plot_insitu_smhm: True - plot_insitu_sm: True - plot_uvj: True - plot_exsitu_frac: True - plot_avpop: True - plot_burstpop: True - plot_fburst_mh_z: True \ No newline at end of file + plot_satquench_model: False + plot_insitu_smhm: False + plot_insitu_sm: False + plot_sm: True + plot_uvj: False + plot_exsitu_frac: False + plot_avpop: False + plot_burstpop: False + plot_fburst_mh_z: False \ No newline at end of file diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index 50f5dce5..e1550226 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -21,6 +21,7 @@ SDSS_Z_MAX, SDSS_Z_MIN, ) +from diffhtwo.experimental.diagnostics import plot_sm from diffhtwo.experimental.diagnostics.plot_avpop_mono import ( make_avpop_mono_comparison_plots, ) @@ -37,7 +38,6 @@ plot_halpha_sfr, plot_halpha_ssfr, ) -from diffhtwo.experimental.diagnostics.plot_insitu_sm import plot_insitu_sm from diffhtwo.experimental.diagnostics.plot_phot import ( plot_app_mag_funcs, plot_color_pdfs, @@ -303,7 +303,25 @@ print( f"Generating FENIKS in-situ sm plot for {zbin+1}/{len(feniks_zbins)} z-bin..." ) - plot_insitu_sm( + plot_sm.plot_insitu_sm_obs( + ran_key, + param_collection_fit, + z_min, + z_max, + feniks.dataset_dim_labels, + ssp_data, + feniks.filter_info.tcurves, + feniks_label, + fit_diagnostics_save_drn, + num_halos=num_halos, + plt_show=False, + ) + + if cfg["plots"]["plot_sm"]: + print( + f"Generating FENIKS in+ex-situ sm plot for {zbin+1}/{len(feniks_zbins)} z-bin..." + ) + plot_sm.plot_sm_obs( ran_key, param_collection_fit, z_min, @@ -531,7 +549,25 @@ print( f"Generating SDSS in-situ sm plot for {zbin+1}/{len(sdss_zbins)} z-bin..." ) - plot_insitu_sm( + plot_sm.plot_insitu_sm_obs( + ran_key, + param_collection_fit, + z_min, + z_max, + sdss.dataset_dim_labels, + ssp_data, + sdss.filter_info.tcurves, + sdss_label, + fit_diagnostics_save_drn, + num_halos=num_halos, + plt_show=False, + ) + + if cfg["plots"]["plot_sm"]: + print( + f"Generating SDSS in+ex-situ sm plot for {zbin+1}/{len(sdss_zbins)} z-bin..." + ) + plot_sm.plot_sm_obs( ran_key, param_collection_fit, z_min, From 7915cd26d039405564ddcb8cd2b93e025c7fa911 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Thu, 11 Jun 2026 22:09:28 -0500 Subject: [PATCH 42/57] Update load_feniks.py --- diffhtwo/experimental/data_loaders/load_feniks.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index a13501ba..69d1dba6 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -193,13 +193,13 @@ def get_feniks_data( uds_K = get_mag_ab(phot, "fcol_UDS_K") feniks_mag_thresh = FeniksFilters( - MegaCam_uS=24.5, - HSC_G=24.5, - HSC_R=24.5, - HSC_I=24.5, - HSC_Z=24.5, - UDS_J=24.0, - UDS_H=24.0, + MegaCam_uS=24.9, + HSC_G=25.1, + HSC_R=25.3, + HSC_I=25.1, + HSC_Z=24.9, + UDS_J=24.5, + UDS_H=24.3, UDS_K=FENIKS_MAGK_THRESH, ) From a100b3ede4d5bf88602fa9a3344c34f4afeb4f21 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Fri, 12 Jun 2026 01:01:13 -0500 Subject: [PATCH 43/57] sdss until z=0.16 --- .../experimental/data_loaders/load_sdss.py | 10 +- .../experimental/diagnostics/plot_smhm.py | 157 ++++++++++++++++++ scripts/config_diagnostics.yaml | 39 ++--- scripts/generate_diagnostic_plots.py | 41 ++++- 4 files changed, 221 insertions(+), 26 deletions(-) create mode 100644 diffhtwo/experimental/diagnostics/plot_smhm.py diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 110f91f6..59d80907 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -184,8 +184,8 @@ def get_sdss_data( # prepare 2D and 1D color spaces in coarse z-bins for fitting zbins = np.array( [ - [0.02, 0.1], - [0.1, 0.2], + [0.02, 0.08], + [0.08, 0.16], ] ) ############################################################################## @@ -252,7 +252,6 @@ def get_sdss_data( # 1D (u - r | r) rbins = np.arange(sdss_r[z_sel].min(), sdss_r[z_sel].max(), 2) - print(rbins) col_idx = [0, 2] cond_idx = 2 @@ -316,12 +315,11 @@ def get_sdss_data( ############################################################################## ############################################################################## # prepare 1D app mag funcs in finer z-bins for fitting - fine_zbins = fine_zbins = np.array( + fine_zbins = np.array( [ [0.02, 0.06], [0.06, 0.1], - [0.1, 0.15], - [0.15, 0.20], + [0.1, 0.16], ] ) ############################################################################## diff --git a/diffhtwo/experimental/diagnostics/plot_smhm.py b/diffhtwo/experimental/diagnostics/plot_smhm.py new file mode 100644 index 00000000..374369c9 --- /dev/null +++ b/diffhtwo/experimental/diagnostics/plot_smhm.py @@ -0,0 +1,157 @@ +import matplotlib.pyplot as plt +import numpy as np +from diffsky.param_utils.diffsky_param_wrapper_merging import DEFAULT_PARAM_COLLECTION +from scipy.stats import binned_statistic + +from ..kernels.lc_phot_kern import multiband_lc_phot_kern + +plt.rc("font", family="serif", serif=["Times New Roman"]) + + +def get_median_logsm_obs(logmp_obs, logsm_obs): + logmp_bins = np.arange(11.0, logmp_obs.max() + 0.25, 0.25) + logmp_bin_centers = (logmp_bins[:-1] + logmp_bins[1:]) / 2 + + logsm_16, __, __ = binned_statistic( + logmp_obs, logsm_obs, bins=logmp_bins, statistic=lambda x: np.percentile(x, 16) + ) + logsm_50, __, __ = binned_statistic( + logmp_obs, logsm_obs, bins=logmp_bins, statistic="median" + ) + logsm_84, __, __ = binned_statistic( + logmp_obs, logsm_obs, bins=logmp_bins, statistic=lambda x: np.percentile(x, 84) + ) + return logmp_bin_centers, logsm_16, logsm_50, logsm_84 + + +def plot_smhm( + ran_key, + param_collection, + zbins, + num_halos, + ssp_data, + tcurves, + mag_thresh, + frac_cat, + data_label, + savedir, + plt_show=True, +): + n_z_bins = len(zbins) + fig_width = 1.42 * n_z_bins + fig_height = 2 + fig, ax = plt.subplots( + 1, len(zbins), figsize=(fig_width, fig_height), constrained_layout=True + ) + + labelsize = 10 + fontsize = 10 + labelsize = 10 + alpha = 0.25 + + for zbin in range(n_z_bins): + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] + z_min_label = str(np.round(z_min, 2)) + z_max_label = str(np.round(z_max, 2)) + ax[zbin].set_title(z_min_label + " < z < " + z_max_label) + + """default""" + lc_data, phot_kern_results, gal_weight = multiband_lc_phot_kern( + ran_key, + DEFAULT_PARAM_COLLECTION, + z_min, + z_max, + num_halos, + ssp_data, + tcurves, + mag_thresh=mag_thresh, + frac_cat=frac_cat, + ) + ( + logmp_bin_centers_default, + logsm_16_default, + logsm_50_default, + logsm_84_default, + ) = get_median_logsm_obs(lc_data.logmp_obs, phot_kern_results.logsm_obs) + + ax[zbin].plot( + logmp_bin_centers_default, + logsm_50_default, + label="default", + color="#FFB689", + ) + ax[zbin].fill_between( + logmp_bin_centers_default, + logsm_16_default, + logsm_84_default, + alpha=alpha, + color="#FFB689", + ) + + """fit""" + lc_data, phot_kern_results, gal_weight = multiband_lc_phot_kern( + ran_key, + param_collection, + z_min, + z_max, + num_halos, + ssp_data, + tcurves, + mag_thresh=mag_thresh, + frac_cat=frac_cat, + ) + + ( + logmp_bin_centers_fit, + logsm_16_fit, + logsm_50_fit, + logsm_84_fit, + ) = get_median_logsm_obs(lc_data.logmp_obs, phot_kern_results.logsm_obs) + + ax[zbin].plot(logmp_bin_centers_fit, logsm_50_fit, label="fit", color="#61C0BF") + ax[zbin].fill_between( + logmp_bin_centers_fit, + logsm_16_fit, + logsm_84_fit, + alpha=alpha, + color="#61C0BF", + ) + + ax[zbin].set_xlabel("logmp_obs", fontsize=fontsize) + ax[zbin].set_xlim(11, lc_data.logmp_obs.max()) + ax[zbin].set_ylim(8, 13) + ax[zbin].set_xticks([11, 12, 13, 14, 15]) + ax[zbin].set_yticks([8, 9, 10, 11, 12]) + + ax[zbin].minorticks_on() + ax[zbin].tick_params( + which="major", + direction="in", + top=True, + right=True, + length=6, + width=1, + labelsize=labelsize, + ) + ax[zbin].tick_params( + which="minor", + direction="in", + top=True, + right=True, + length=3, + width=0.8, + labelsize=labelsize, + ) + + ax[0].set_ylabel("logsm_obs", fontsize=fontsize) + ax[-1].legend(fontsize=7, loc="lower right") + + fig.savefig( + savedir + "/" + data_label + "_smhm.png", + dpi=300, + ) + + if plt_show: + plt.show() + plt.close() diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 2b53211b..cb03c531 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run173 -model_nickname: run173_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run173/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run177 +model_nickname: run177_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run177/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -17,21 +17,22 @@ plot_hizels: False plots: num_halos : 3000 - plot_color_contours: False - plot_app_mag_funcs: False - plot_color_pdfs: False - plot_colors_mags: False - plot_mags: False - plot_ssperr: False - plot_massive_cen_colors: False - plot_merging_sat_colors: False + plot_color_contours: True + plot_app_mag_funcs: True + plot_color_pdfs: True + plot_colors_mags: True + plot_mags: True + plot_ssperr: True + plot_massive_cen_colors: True + plot_merging_sat_colors: True plot_satquench: False - plot_satquench_model: False - plot_insitu_smhm: False - plot_insitu_sm: False + plot_satquench_model: True + plot_smhm: True + plot_insitu_smhm: True + plot_insitu_sm: True plot_sm: True - plot_uvj: False - plot_exsitu_frac: False - plot_avpop: False - plot_burstpop: False - plot_fburst_mh_z: False \ No newline at end of file + plot_uvj: True + plot_exsitu_frac: True + plot_avpop: True + plot_burstpop: True + plot_fburst_mh_z: True \ No newline at end of file diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index e1550226..bc3e138b 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -50,6 +50,7 @@ generate_sat_plots, plot_satquench_model, ) +from diffhtwo.experimental.diagnostics.plot_smhm import plot_smhm if __name__ == "__main__": p = argparse.ArgumentParser() @@ -222,6 +223,22 @@ ] ) + if cfg["plots"]["plot_smhm"]: + print("Generating FENIKS SMHM plots...") + plot_smhm( + ran_key, + param_collection_fit, + feniks_zbins, + num_halos, + ssp_data, + feniks.filter_info.tcurves, + feniks.filter_info.mag_thresh, + feniks.frac_cat, + feniks_label, + fit_diagnostics_save_drn, + plt_show=False, + ) + if cfg["plots"]["plot_color_contours"]: print("Generating FENIKS color contour plots...") plot_color_contours( @@ -478,7 +495,29 @@ num_halos_coarse_zbins=num_halos, num_halos_fine_zbins=int(num_halos / 2), ) - sdss_zbins = np.array([[0.02, 0.05], [0.05, 0.08], [0.08, 0.11], [0.11, 0.14]]) + sdss_zbins = np.array( + [ + [0.02, 0.06], + [0.06, 0.1], + [0.1, 0.16], + ] + ) + + if cfg["plots"]["plot_smhm"]: + print("Generating SDSS SMHM plots...") + plot_smhm( + ran_key, + param_collection_fit, + sdss_zbins, + num_halos, + ssp_data, + sdss.filter_info.tcurves, + sdss.filter_info.mag_thresh, + sdss.frac_cat, + sdss_label, + fit_diagnostics_save_drn, + plt_show=False, + ) if cfg["plots"]["plot_color_contours"]: print("Generating SDSS color contour plots...") From 20e2b48d243fedd732c270766017a2b989c17663 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Fri, 12 Jun 2026 13:22:03 -0500 Subject: [PATCH 44/57] plot_smhm_in_situ --- diffhtwo/experimental/defaults.py | 4 +- .../experimental/diagnostics/plot_phot.py | 11 ++++- .../experimental/diagnostics/plot_smhm.py | 47 ++++++++++++++----- scripts/config_diagnostics.yaml | 10 ++-- scripts/generate_diagnostic_plots.py | 29 ++++++++++-- 5 files changed, 77 insertions(+), 24 deletions(-) diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index a37503db..2c53e829 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -19,8 +19,8 @@ FENIKS_AREA_DEG2 = 2828.247933129912 / 3600 FENIKS_Z_MIN = 0.2 -FENIKS_Z_MAX = 2.5 -FENIKS_MAGK_THRESH = 24.0 # col mag +FENIKS_Z_MAX = 2.0 +FENIKS_MAGK_THRESH = 24.3 # col mag SDSS_AREA_DEG2 = 7199 SDSS_Z_MIN = 0.02 diff --git a/diffhtwo/experimental/diagnostics/plot_phot.py b/diffhtwo/experimental/diagnostics/plot_phot.py index 488fb55b..f8f96970 100644 --- a/diffhtwo/experimental/diagnostics/plot_phot.py +++ b/diffhtwo/experimental/diagnostics/plot_phot.py @@ -592,7 +592,14 @@ def plot_app_mag_funcs( zbins = np.array(zbins) labels_z = [" z = " + str(np.round(np.median(z), 2)) for z in zbins] - if len(labels_z) == 4: + if len(labels_z) == 3: + colors_z = [ + "#001219", + "#0a7a80", + "#c87820", + ] + + elif len(labels_z) == 4: colors_z = [ "#001219", "#0a7a80", @@ -600,7 +607,7 @@ def plot_app_mag_funcs( "#c87820", ] - else: + elif len(labels_z) == 5: colors_z = [ "#001219", "#0a7a80", diff --git a/diffhtwo/experimental/diagnostics/plot_smhm.py b/diffhtwo/experimental/diagnostics/plot_smhm.py index 374369c9..2c69a1e9 100644 --- a/diffhtwo/experimental/diagnostics/plot_smhm.py +++ b/diffhtwo/experimental/diagnostics/plot_smhm.py @@ -9,7 +9,7 @@ def get_median_logsm_obs(logmp_obs, logsm_obs): - logmp_bins = np.arange(11.0, logmp_obs.max() + 0.25, 0.25) + logmp_bins = np.arange(logmp_obs.min(), logmp_obs.max() + 0.25, 0.25) logmp_bin_centers = (logmp_bins[:-1] + logmp_bins[1:]) / 2 logsm_16, __, __ = binned_statistic( @@ -31,10 +31,11 @@ def plot_smhm( num_halos, ssp_data, tcurves, - mag_thresh, - frac_cat, data_label, savedir, + mag_thresh=None, + frac_cat=None, + in_situ=False, plt_show=True, ): n_z_bins = len(zbins) @@ -68,12 +69,16 @@ def plot_smhm( mag_thresh=mag_thresh, frac_cat=frac_cat, ) + if in_situ: + logsm_obs = phot_kern_results.logsm_obs_in_situ + else: + logsm_obs = phot_kern_results.logsm_obs ( logmp_bin_centers_default, logsm_16_default, logsm_50_default, logsm_84_default, - ) = get_median_logsm_obs(lc_data.logmp_obs, phot_kern_results.logsm_obs) + ) = get_median_logsm_obs(lc_data.logmp_obs, logsm_obs) ax[zbin].plot( logmp_bin_centers_default, @@ -101,13 +106,17 @@ def plot_smhm( mag_thresh=mag_thresh, frac_cat=frac_cat, ) + if in_situ: + logsm_obs = phot_kern_results.logsm_obs_in_situ + else: + logsm_obs = phot_kern_results.logsm_obs ( logmp_bin_centers_fit, logsm_16_fit, logsm_50_fit, logsm_84_fit, - ) = get_median_logsm_obs(lc_data.logmp_obs, phot_kern_results.logsm_obs) + ) = get_median_logsm_obs(lc_data.logmp_obs, logsm_obs) ax[zbin].plot(logmp_bin_centers_fit, logsm_50_fit, label="fit", color="#61C0BF") ax[zbin].fill_between( @@ -118,12 +127,19 @@ def plot_smhm( color="#61C0BF", ) - ax[zbin].set_xlabel("logmp_obs", fontsize=fontsize) + # if in_situ: + # ax[zbin].set_xlim(10, lc_data.logmp_obs.max()) + # ax[zbin].set_ylim(5, 13) + # ax[zbin].set_xticks([10, 11, 12, 13, 14, 15]) + # ax[zbin].set_yticks([5, 6, 7, 8, 9, 10, 11, 12]) + # else: ax[zbin].set_xlim(11, lc_data.logmp_obs.max()) ax[zbin].set_ylim(8, 13) ax[zbin].set_xticks([11, 12, 13, 14, 15]) ax[zbin].set_yticks([8, 9, 10, 11, 12]) + ax[zbin].set_xlabel("logmp_obs", fontsize=fontsize) + ax[zbin].minorticks_on() ax[zbin].tick_params( which="major", @@ -144,13 +160,22 @@ def plot_smhm( labelsize=labelsize, ) - ax[0].set_ylabel("logsm_obs", fontsize=fontsize) + if in_situ: + ax[0].set_ylabel("logsm_obs_in_situ", fontsize=fontsize) + else: + ax[0].set_ylabel("logsm_obs", fontsize=fontsize) ax[-1].legend(fontsize=7, loc="lower right") - fig.savefig( - savedir + "/" + data_label + "_smhm.png", - dpi=300, - ) + if in_situ: + fig.savefig( + savedir + "/" + data_label + "_insitu_smhm.png", + dpi=300, + ) + else: + fig.savefig( + savedir + "/" + data_label + "_smhm.png", + dpi=300, + ) if plt_show: plt.show() diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index cb03c531..11989d3e 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -11,8 +11,8 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: True -plot_feniks: False +plot_sdss: False +plot_feniks: True plot_hizels: False plots: @@ -25,8 +25,6 @@ plots: plot_ssperr: True plot_massive_cen_colors: True plot_merging_sat_colors: True - plot_satquench: False - plot_satquench_model: True plot_smhm: True plot_insitu_smhm: True plot_insitu_sm: True @@ -35,4 +33,6 @@ plots: plot_exsitu_frac: True plot_avpop: True plot_burstpop: True - plot_fburst_mh_z: True \ No newline at end of file + plot_fburst_mh_z: True + plot_satquench_model: True + plot_satquench: False \ No newline at end of file diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index bc3e138b..ff7d914a 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -232,12 +232,22 @@ num_halos, ssp_data, feniks.filter_info.tcurves, - feniks.filter_info.mag_thresh, - feniks.frac_cat, feniks_label, fit_diagnostics_save_drn, plt_show=False, ) + plot_smhm( + ran_key, + param_collection_fit, + feniks_zbins, + num_halos, + ssp_data, + feniks.filter_info.tcurves, + feniks_label, + fit_diagnostics_save_drn, + in_situ=True, + plt_show=False, + ) if cfg["plots"]["plot_color_contours"]: print("Generating FENIKS color contour plots...") @@ -500,6 +510,7 @@ [0.02, 0.06], [0.06, 0.1], [0.1, 0.16], + [0.16, 0.20], ] ) @@ -512,12 +523,22 @@ num_halos, ssp_data, sdss.filter_info.tcurves, - sdss.filter_info.mag_thresh, - sdss.frac_cat, sdss_label, fit_diagnostics_save_drn, plt_show=False, ) + plot_smhm( + ran_key, + param_collection_fit, + sdss_zbins, + num_halos, + ssp_data, + sdss.filter_info.tcurves, + sdss_label, + fit_diagnostics_save_drn, + in_situ=True, + plt_show=False, + ) if cfg["plots"]["plot_color_contours"]: print("Generating SDSS color contour plots...") From 44431ceaec51a2cc4df149bba0d0f6b12815bc80 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Fri, 12 Jun 2026 13:32:45 -0500 Subject: [PATCH 45/57] Update load_sdss.py --- diffhtwo/experimental/data_loaders/load_sdss.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 59d80907..83e152b3 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -184,8 +184,8 @@ def get_sdss_data( # prepare 2D and 1D color spaces in coarse z-bins for fitting zbins = np.array( [ - [0.02, 0.08], - [0.08, 0.16], + [0.02, 0.1], + [0.1, 0.2], ] ) ############################################################################## @@ -315,13 +315,7 @@ def get_sdss_data( ############################################################################## ############################################################################## # prepare 1D app mag funcs in finer z-bins for fitting - fine_zbins = np.array( - [ - [0.02, 0.06], - [0.06, 0.1], - [0.1, 0.16], - ] - ) + fine_zbins = np.array([[0.02, 0.06], [0.06, 0.1], [0.1, 0.14], [0.14, 0.2]]) ############################################################################## AppMagFuncs = namedtuple( "AppMagFuncs", From c9c41690a00fd4745e7b680767c6638817e6f1ac Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Fri, 12 Jun 2026 16:12:54 -0500 Subject: [PATCH 46/57] feniks: include CMDs in loss along with the conditional 1D colors --- .../experimental/data_loaders/load_feniks.py | 16 ++++++++-------- scripts/config_diagnostics.yaml | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 69d1dba6..f3a300c2 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -533,7 +533,7 @@ def get_feniks_data( ) mag_idx = 7 col_idx = [2, 3] - K_ri = K_ri(mag_idx, col_idx, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri, N_K_ri, False) + K_ri = K_ri(mag_idx, col_idx, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri, N_K_ri, True) n_bins += bin_lo_K_ri.size # 2D (K, g - r) @@ -546,7 +546,7 @@ def get_feniks_data( ) mag_idx = 7 col_idx = [1, 2] - K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr, False) + K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr, True) n_bins += bin_lo_K_gr.size # 2D (K, J - H) @@ -559,7 +559,7 @@ def get_feniks_data( ) mag_idx = 7 col_idx = [5, 6] - K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH, False) + K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH, True) n_bins += bin_lo_K_JH.size z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, ri, iz, jh, K_ri, K_gr, K_JH) @@ -707,7 +707,7 @@ def get_feniks_data( ) mag_idx = 7 col_idx = [0, 1] - K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug, False) + K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug, True) n_bins += bin_lo_K_ug.size # 2D (K, r - z) @@ -720,7 +720,7 @@ def get_feniks_data( ) mag_idx = 7 col_idx = [2, 4] - K_rz = K_rz(mag_idx, col_idx, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz, N_K_rz, False) + K_rz = K_rz(mag_idx, col_idx, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz, N_K_rz, True) n_bins += bin_lo_K_rz.size z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, rz, jh, K_ug, K_rz) @@ -896,7 +896,7 @@ def get_feniks_data( ) mag_idx = 7 col_idx = [0, 1] - K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug, False) + K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug, True) n_bins += bin_lo_K_ug.size # 2D (K, g - r) @@ -909,7 +909,7 @@ def get_feniks_data( ) mag_idx = 7 col_idx = [1, 2] - K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr, False) + K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr, True) n_bins += bin_lo_K_gr.size # 2D (K, J - H) @@ -922,7 +922,7 @@ def get_feniks_data( ) mag_idx = 7 col_idx = [5, 6] - K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH, False) + K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH, True) n_bins += bin_lo_K_JH.size z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr, jh, K_ug, K_gr, K_JH) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 11989d3e..3dd4d8b6 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run177 -model_nickname: run177_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run177/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run179 +model_nickname: run179_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run179/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,7 +11,7 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: False +plot_sdss: True plot_feniks: True plot_hizels: False From 65603f523b572a00759fd6a9099b45a152e5f927 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Fri, 12 Jun 2026 18:50:14 -0500 Subject: [PATCH 47/57] feniks until z=2.5 --- .../experimental/data_loaders/load_feniks.py | 4 +- .../experimental/data_loaders/load_sdss.py | 1 + diffhtwo/experimental/defaults.py | 1 + scripts/config_diagnostics.yaml | 42 +++++++++---------- 4 files changed, 26 insertions(+), 22 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index f3a300c2..13270537 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -343,7 +343,7 @@ def get_feniks_data( [ [0.2, 0.7], [0.7, 1.5], - [1.5, 2.0], + [1.5, 2.5], ] ) @@ -937,6 +937,7 @@ def get_feniks_data( [0.7, 1.0], [1.0, 1.5], [1.5, 2.0], + [2.0, 2.5], ] ) ############################################################################## @@ -1052,6 +1053,7 @@ def get_feniks_data( mags_labels, colors, app_mag_funcs, + fine_zbins, filter_info, frac_cat, lh_centroids, diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 83e152b3..bfce178b 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -401,6 +401,7 @@ def get_sdss_data( mag_labels, colors, app_mag_funcs, + fine_zbins, filter_info, frac_cat, lh_centroids, diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index 2c53e829..776a2529 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -37,6 +37,7 @@ "mags_labels", "colors", "app_mag_funcs", + "fine_zbins", "filter_info", "frac_cat", "lh_centroids", diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 3dd4d8b6..19c361fe 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run179 -model_nickname: run179_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run179/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run169 +model_nickname: run169_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run169/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,28 +11,28 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: True +plot_sdss: False plot_feniks: True plot_hizels: False plots: num_halos : 3000 - plot_color_contours: True - plot_app_mag_funcs: True - plot_color_pdfs: True - plot_colors_mags: True - plot_mags: True - plot_ssperr: True - plot_massive_cen_colors: True - plot_merging_sat_colors: True + plot_color_contours: False + plot_app_mag_funcs: False + plot_color_pdfs: False + plot_colors_mags: False + plot_mags: False + plot_ssperr: False + plot_massive_cen_colors: False + plot_merging_sat_colors: False plot_smhm: True - plot_insitu_smhm: True - plot_insitu_sm: True - plot_sm: True - plot_uvj: True - plot_exsitu_frac: True - plot_avpop: True - plot_burstpop: True - plot_fburst_mh_z: True - plot_satquench_model: True + plot_insitu_smhm: False + plot_insitu_sm: False + plot_sm: False + plot_uvj: False + plot_exsitu_frac: False + plot_avpop: False + plot_burstpop: False + plot_fburst_mh_z: False + plot_satquench_model: False plot_satquench: False \ No newline at end of file From 57168a4a6bd3f37d04d7988843f55ae6c08af331 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Fri, 12 Jun 2026 23:10:00 -0500 Subject: [PATCH 48/57] use narrow-band colors with broadbands to detect absorption depth, targeting metallicity --- diffhtwo/experimental/data_loaders/N_utils.py | 53 ++ .../experimental/data_loaders/load_feniks.py | 650 +++++++----------- .../experimental/diagnostics/plot_phot.py | 2 + scripts/config_diagnostics.yaml | 40 +- scripts/generate_diagnostic_plots.py | 19 +- 5 files changed, 308 insertions(+), 456 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/N_utils.py b/diffhtwo/experimental/data_loaders/N_utils.py index e6dd4241..f7873b8b 100644 --- a/diffhtwo/experimental/data_loaders/N_utils.py +++ b/diffhtwo/experimental/data_loaders/N_utils.py @@ -1,7 +1,15 @@ +from collections import namedtuple + import jax.numpy as jnp import numpy as np from diffsky import diffndhist_lomem +from ..defaults import ( + ColorColor, + ColorCondMag, + MagColor, +) + def get_N_1d(dim1, dim1_bin_edges=None, dmag=0.2, sig_scale=0.5): dataset = dim1.reshape(dim1.size, 1) @@ -57,3 +65,48 @@ def get_N_2d(dim1, dim2, sig_scale=0.5): ) return N_2d, sig, bin_lo, bin_hi + + +def get_colorcolor_space(name, color1, color2, col_idx, z_sel, fit=True): + ColorColorSpace = namedtuple(name, ColorColor._fields) + + N_2d, sig, bin_lo, bin_hi = get_N_2d(color1[z_sel], color2[z_sel]) + + return ColorColorSpace(col_idx, sig, bin_lo, bin_hi, N_2d, fit) + + +def get_color_cond_space_list( + name, color, cond_mag, col_idx, cond_idx, z_sel, cond_dmag=2, fit=True +): + ColorCondSpace = namedtuple(name, ColorCondMag._fields) + cond_mag_bins = np.arange(cond_mag[z_sel].min(), cond_mag[z_sel].max(), cond_dmag) + + color_cond_list = [] + for b in range(len(cond_mag_bins) - 1): + cond_sel = (cond_mag[z_sel] > cond_mag_bins[b]) & ( + cond_mag[z_sel] <= cond_mag_bins[b + 1] + ) + N_1d, sig, bin_lo, bin_hi = get_N_1d(color[z_sel][cond_sel]) + color_cond_list.append( + ColorCondSpace( + col_idx, + cond_idx, + cond_mag_bins[b], + cond_mag_bins[b + 1], + sig, + bin_lo, + bin_hi, + N_1d, + fit, + ) + ) + + return color_cond_list + + +def get_mag_color_space(name, mag, color, mag_idx, col_idx, z_sel, fit=True): + MagColorSpace = namedtuple(name, MagColor._fields) + + N_2d, sig, bin_lo, bin_hi = get_N_2d(mag[z_sel], color[z_sel]) + + return MagColorSpace(mag_idx, col_idx, sig, bin_lo, bin_hi, N_2d, fit) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 13270537..5ba6f28e 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -15,16 +15,13 @@ FENIKS_Z_MAX, FENIKS_Z_MIN, AppMagFunc, - ColorColor, - ColorCondMag, Dataset, FilterInfo, - MagColor, ) from ..latin_hypercube import latin_hypercube as lh from ..lightcone_generators import generate_lc_data from ..utils import load_feniks_tcurve -from .N_utils import get_N_1d, get_N_2d +from . import N_utils BASE_PATH = Path(__file__).resolve().parent.parent FENIKS_FILTERS_PATH = BASE_PATH / "data" / "feniks_filters" @@ -166,10 +163,11 @@ def get_feniks_data( HSC_R=False, HSC_I=False, HSC_Z=False, - # VIDEO_Y=False, UDS_J=False, UDS_H=False, UDS_K=True, + NB0816=False, + NB0921=False, ) tcurves = [] for feniks_filter in FeniksFilters._fields: @@ -191,6 +189,8 @@ def get_feniks_data( uds_J = get_mag_ab(phot, "fcol_UDS_J") uds_H = get_mag_ab(phot, "fcol_UDS_H") uds_K = get_mag_ab(phot, "fcol_UDS_K") + nb816 = get_mag_ab(phot, "fcol_NB0816") + nb921 = get_mag_ab(phot, "fcol_NB0921") feniks_mag_thresh = FeniksFilters( MegaCam_uS=24.9, @@ -201,6 +201,8 @@ def get_feniks_data( UDS_J=24.5, UDS_H=24.3, UDS_K=FENIKS_MAGK_THRESH, + NB0816=25.3, + NB0921=25.3, ) filter_info = FilterInfo(feniks_mag_thresh, feniks_in_lh, tcurves) @@ -215,6 +217,8 @@ def get_feniks_data( & (uds_J < feniks_mag_thresh.UDS_J) & (uds_H < feniks_mag_thresh.UDS_H) & (uds_K < feniks_mag_thresh.UDS_K) + & (nb816 < feniks_mag_thresh.NB0816) + & (nb921 < feniks_mag_thresh.NB0921) ) # apply mag_thresh cuts and record n_gals. @@ -227,10 +231,11 @@ def get_feniks_data( hsc_r = hsc_r[mag_thresh] hsc_i = hsc_i[mag_thresh] hsc_z = hsc_z[mag_thresh] - # video_Y = video_Y[mag_thresh] uds_J = uds_J[mag_thresh] uds_H = uds_H[mag_thresh] uds_K = uds_K[mag_thresh] + nb816 = nb816[mag_thresh] + nb921 = nb921[mag_thresh] n_gals_pre_cuts = len(zout) @@ -241,10 +246,11 @@ def get_feniks_data( & (hsc_r != -99) & (hsc_i != -99) & (hsc_z != -99) - # & (video_Y != -99) & (uds_J != -99) & (uds_H != -99) & (uds_K != -99) + & (nb816 != -99) + & (nb921 != -99) ) phot = phot[clean] @@ -254,10 +260,11 @@ def get_feniks_data( hsc_r = hsc_r[clean] hsc_i = hsc_i[clean] hsc_z = hsc_z[clean] - # video_Y = video_Y[clean] uds_J = uds_J[clean] uds_H = uds_H[clean] uds_K = uds_K[clean] + nb816 = nb816[clean] + nb921 = nb921[clean] n_gals_post_cuts = len(zout) frac_cat = n_gals_post_cuts / n_gals_pre_cuts @@ -269,14 +276,28 @@ def get_feniks_data( hsc_r, hsc_i, hsc_z, - # video_Y, uds_J, uds_H, uds_K, + nb816, + nb921, zout["z_phot"], ) ).T + mags_labels = [ + r"$uS_{MegaCam}$", + r"$g_{HSC}$", + r"$r_{HSC}$", + r"$i_{HSC}$", + r"$z_{HSC}$", + r"$J_{UDS}$", + r"$H_{UDS}$", + r"$K_{UDS}$", + r"$NB816$", + r"$NB921$", + ] + # derive colors from mags megacam_hsc_uSg = megacam_uS - hsc_g hsc_gr = hsc_g - hsc_r @@ -286,6 +307,9 @@ def get_feniks_data( hsc_uds_zJ = hsc_z - uds_J uds_JH = uds_J - uds_H uds_HK = uds_H - uds_K + hsc_i816 = hsc_i - nb816 + hsc_z921 = hsc_z - nb921 + hsc_uds_rK = hsc_r - uds_K # stack colors_mag dataset = np.vstack( @@ -299,6 +323,8 @@ def get_feniks_data( uds_HK, megacam_uS, uds_K, + hsc_i816, + hsc_z921, zout["z_phot"], ) ).T @@ -314,39 +340,22 @@ def get_feniks_data( r"$H_{UDS} - K_{UDS}$", r"$uS_{MegaCam}$", r"$K_{UDS}$", + r"$i_{HSC} - NB816_{HSC}$", + r"$z_{HSC} - NB921_{HSC}$", r"$redshift$", ] - mags_labels = [ - r"$uS_{MegaCam}$", - r"$g_{HSC}$", - r"$r_{HSC}$", - r"$i_{HSC}$", - r"$z_{HSC}$", - # r"$Y_{VIDEO}$", - r"$J_{UDS}$", - r"$H_{UDS}$", - r"$K_{UDS}$", - ] - - # mask redshift - # z_mask = (zout["z_phot"] > FENIKS_Z_MIN) & (zout["z_phot"] <= FENIKS_Z_MAX) - # dataset = dataset[z_mask] - # mags = mags[z_mask] - # zout = zout[z_mask] - n_bins = 0 - ############################################################################## # prepare 2D and 1D color spaces in coarse z-bins for fitting zbins = np.array( [ [0.2, 0.7], - [0.7, 1.5], + [0.7, 1.0], + [1.0, 1.5], [1.5, 2.5], ] ) - ############################################################################## # Z1 spaces: # 2D (g - r, r - i) @@ -396,184 +405,75 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) # 2D (g - r, r - i) - Gr_ri = namedtuple("Gr_ri", ColorColor._fields) - mag_sel_gr_ri = ( - (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) - & (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) - & (hsc_i[z_sel] < feniks_mag_thresh.HSC_I) - ) - N_gr_ri, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri = get_N_2d( - hsc_gr[z_sel][mag_sel_gr_ri], hsc_ri[z_sel][mag_sel_gr_ri] + gr_ri = N_utils.get_colorcolor_space( + "Gr_ri", hsc_gr, hsc_ri, [1, 2, 3], z_sel, fit=True ) - col_idx = [1, 2, 3] - gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri, True) - n_bins += bin_lo_gr_ri.size # 1D (u - g | K) - Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) - - ug = [] - Ug_condK = namedtuple("Ug_condK", ColorCondMag._fields) - mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( - hsc_g[z_sel] < feniks_mag_thresh.HSC_G + ug = N_utils.get_color_cond_space_list( + "Ug_condK", megacam_hsc_uSg, uds_K, [0, 1], 7, z_sel, cond_dmag=2, fit=True ) - col_idx = [0, 1] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d( - megacam_hsc_uSg[z_sel][mag_sel_ug & K_sel] - ) - ug.append( - Ug_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_ug, - bin_lo_ug, - bin_hi_ug, - N_1d_ug, - True, - ) - ) - n_bins += bin_lo_ug.size # 1D (r − i | K) - ri = [] - Ri_condK = namedtuple("Ri_condK", ColorCondMag._fields) - mag_sel_ri = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( - hsc_i[z_sel] < feniks_mag_thresh.HSC_I + ri = N_utils.get_color_cond_space_list( + "Ri_condK", hsc_ri, uds_K, [2, 3], 7, z_sel, cond_dmag=2, fit=True ) - col_idx = [2, 3] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_ri, sig_ri, bin_lo_ri, bin_hi_ri = get_N_1d( - hsc_ri[z_sel][mag_sel_ri & K_sel] - ) - ri.append( - Ri_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_ri, - bin_lo_ri, - bin_hi_ri, - N_1d_ri, - True, - ) - ) - n_bins += bin_lo_ri.size # 1D (i − z | K) - iz = [] - Iz_condK = namedtuple("Iz_condK", ColorCondMag._fields) - mag_sel_iz = (hsc_i[z_sel] < feniks_mag_thresh.HSC_I) & ( - hsc_z[z_sel] < feniks_mag_thresh.HSC_Z + iz = N_utils.get_color_cond_space_list( + "Iz_condK", hsc_iz, uds_K, [3, 4], 7, z_sel, cond_dmag=2, fit=True ) - col_idx = [3, 4] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_iz, sig_iz, bin_lo_iz, bin_hi_iz = get_N_1d( - hsc_iz[z_sel][mag_sel_iz & K_sel] - ) - iz.append( - Iz_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_iz, - bin_lo_iz, - bin_hi_iz, - N_1d_iz, - True, - ) - ) - n_bins += bin_lo_iz.size # 1D (J − H | K) - jh = [] - JH_condK = namedtuple("JH_condK", ColorCondMag._fields) - mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( - uds_H[z_sel] < feniks_mag_thresh.UDS_H + jh = N_utils.get_color_cond_space_list( + "JH_condK", uds_JH, uds_K, [5, 6], 7, z_sel, cond_dmag=2, fit=True ) - col_idx = [5, 6] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_jh, sig_jh, bin_lo_jh, bin_hi_jh = get_N_1d( - uds_JH[z_sel][mag_sel_jh & K_sel] - ) - jh.append( - JH_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_jh, - bin_lo_jh, - bin_hi_jh, - N_1d_jh, - True, - ) - ) - n_bins += bin_lo_jh.size # 2D (K, r - i) - K_ri = namedtuple("K_ri", MagColor._fields) - mag_sel_ri = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( - hsc_i[z_sel] < feniks_mag_thresh.HSC_I - ) - N_K_ri, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri = get_N_2d( - uds_K[z_sel][mag_sel_ri], hsc_ri[z_sel][mag_sel_ri] + K_ri = N_utils.get_mag_color_space( + "K_ri", uds_K, hsc_ri, 7, [2, 3], z_sel, fit=True ) - mag_idx = 7 - col_idx = [2, 3] - K_ri = K_ri(mag_idx, col_idx, sig_K_ri, bin_lo_K_ri, bin_hi_K_ri, N_K_ri, True) - n_bins += bin_lo_K_ri.size # 2D (K, g - r) - K_gr = namedtuple("K_gr", MagColor._fields) - mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( - hsc_r[z_sel] < feniks_mag_thresh.HSC_R + K_gr = N_utils.get_mag_color_space( + "K_gr", uds_K, hsc_gr, 7, [1, 2], z_sel, fit=True ) - N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( - uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr] - ) - mag_idx = 7 - col_idx = [1, 2] - K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr, True) - n_bins += bin_lo_K_gr.size # 2D (K, J - H) - K_JH = namedtuple("K_JH", MagColor._fields) - mag_sel_JH = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( - uds_H[z_sel] < feniks_mag_thresh.UDS_H - ) - N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( - uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH] + K_JH = N_utils.get_mag_color_space( + "K_JH", uds_K, uds_JH, 7, [5, 6], z_sel, fit=True ) - mag_idx = 7 - col_idx = [5, 6] - K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH, True) - n_bins += bin_lo_K_JH.size z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, ri, iz, jh, K_ri, K_gr, K_JH) colors.append(z1) ############################################################################## - # Z2 spaces: + # Z2a spaces: # 2D (r - z, z - J) # 2D (K, u - g) # 2D (K, r - z) + # 2D (i - NB816, g - r) -- metallicity vs age + # 2D (i - NB816, r - K) -- metallicity vs mass-to-light + # 1D (i - NB816 | K) -- metallicity at fixed mass + # 1D (z - NB921 | K) -- cross-check metallicity at fixed mass - Z2 = namedtuple( - "Z2", - ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "rz", "jh", "K_ug", "K_rz"], + Z2a = namedtuple( + "Z2a", + [ + "z_min", + "z_max", + "lc_data", + "rz_zJ", + "ug", + "rz", + "jh", + "K_ug", + "K_rz", + "iNB816_gr", + "iNB816_rK", + "iNB816_condK", + "zNB921_condK", + ], ) zbin = 1 z_min = zbins[zbin][0] @@ -600,131 +500,152 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) # 2D (r - z, z - J) - Rz_zJ = namedtuple("Rz_zJ", ColorColor._fields) - mag_sel_rz_zJ = ( - (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) - & (hsc_z[z_sel] < feniks_mag_thresh.HSC_Z) - & (uds_J[z_sel] < feniks_mag_thresh.UDS_J) - ) - N_rz_zJ, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ = get_N_2d( - hsc_rz[z_sel][mag_sel_rz_zJ], hsc_uds_zJ[z_sel][mag_sel_rz_zJ] + rz_zJ = N_utils.get_colorcolor_space( + "Rz_zJ", hsc_rz, hsc_uds_zJ, [2, 4, 5], z_sel, fit=True ) - col_idx = [2, 4, 5] - rz_zJ = Rz_zJ(col_idx, sig_rz_zJ, bin_lo_rz_zJ, bin_hi_rz_zJ, N_rz_zJ, True) - n_bins += bin_lo_rz_zJ.size # 1D (u - g | K) - Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 2) - - ug = [] - mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( - hsc_g[z_sel] < feniks_mag_thresh.HSC_G + ug = N_utils.get_color_cond_space_list( + "Ug_condK", megacam_hsc_uSg, uds_K, [0, 1], 7, z_sel, cond_dmag=2, fit=True ) - col_idx = [0, 1] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d( - megacam_hsc_uSg[z_sel][mag_sel_ug & K_sel] - ) - ug.append( - Ug_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_ug, - bin_lo_ug, - bin_hi_ug, - N_1d_ug, - True, - ) - ) - n_bins += bin_lo_ug.size # 1D (r - z | K) - rz = [] - Rz_condK = namedtuple("Rz_condK", ColorCondMag._fields) - mag_sel_rz = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( - hsc_z[z_sel] < feniks_mag_thresh.HSC_Z + rz = N_utils.get_color_cond_space_list( + "Rz_condK", hsc_rz, uds_K, [2, 4], 7, z_sel, cond_dmag=2, fit=True ) - col_idx = [2, 4] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_rz, sig_rz, bin_lo_rz, bin_hi_rz = get_N_1d( - hsc_rz[z_sel][mag_sel_rz & K_sel] - ) - rz.append( - Rz_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_rz, - bin_lo_rz, - bin_hi_rz, - N_1d_rz, - True, - ) - ) - n_bins += bin_lo_rz.size # 1D (J − H | K) - jh = [] - mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( - uds_H[z_sel] < feniks_mag_thresh.UDS_H + jh = N_utils.get_color_cond_space_list( + "JH_condK", uds_JH, uds_K, [5, 6], 7, z_sel, cond_dmag=2, fit=True ) - col_idx = [5, 6] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_jh, sig_jh, bin_lo_jh, bin_hi_jh = get_N_1d( - uds_JH[z_sel][mag_sel_jh & K_sel] - ) - jh.append( - JH_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_jh, - bin_lo_jh, - bin_hi_jh, - N_1d_jh, - True, - ) - ) - n_bins += bin_lo_jh.size # 2D (K, u - g) - K_ug = namedtuple("K_ug", MagColor._fields) - mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( - hsc_g[z_sel] < feniks_mag_thresh.HSC_G + K_ug = N_utils.get_mag_color_space( + "K_ug", uds_K, megacam_hsc_uSg, 7, [0, 1], z_sel, fit=True + ) + + # 2D (K, r - z) + K_rz = N_utils.get_mag_color_space( + "K_rz", uds_K, hsc_rz, 7, [2, 4], z_sel, fit=True + ) + + # 2D (i - NB816, g - r) + i816_gr = N_utils.get_colorcolor_space( + "I816_gr", hsc_i816, hsc_gr, [3, 8, 1, 2], z_sel, fit=True + ) + + # 2D (i - NB816, r - K) + i816_rK = N_utils.get_colorcolor_space( + "I816_rK", hsc_i816, hsc_uds_rK, [3, 8, 2, 7], z_sel, fit=True ) - N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( - uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug] + + # 1D (i - NB816 | K) + i816_condK = N_utils.get_color_cond_space_list( + "I816_condK", + hsc_i816, + uds_K, + [3, 8], + 7, + z_sel, + cond_dmag=2, + fit=True, + ) + + # 1D (z - NB921 | K) + z921_condK = N_utils.get_color_cond_space_list( + "Z921_condK", + hsc_z921, + uds_K, + [4, 9], + 7, + z_sel, + cond_dmag=2, + fit=True, ) - mag_idx = 7 - col_idx = [0, 1] - K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug, True) - n_bins += bin_lo_K_ug.size + z2a = Z2a( + z_min, + z_max, + lc_data, + rz_zJ, + ug, + rz, + jh, + K_ug, + K_rz, + i816_gr, + i816_rK, + i816_condK, + z921_condK, + ) + colors.append(z2a) + + ############################################################################## + # Z2b spaces: + # 2D (r - z, z - J) + # 2D (K, u - g) # 2D (K, r - z) - K_rz = namedtuple("K_rz", MagColor._fields) - mag_sel_rz = (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) & ( - hsc_z[z_sel] < feniks_mag_thresh.HSC_Z + + Z2b = namedtuple( + "Z2b", + ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "rz", "jh", "K_ug", "K_rz"], + ) + zbin = 2 + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos_coarse_zbins, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data = generate_lc_data(*lc_args) + + z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) + + # 2D (r - z, z - J) + rz_zJ = N_utils.get_colorcolor_space( + "Rz_zJ", hsc_rz, hsc_uds_zJ, [2, 4, 5], z_sel, fit=True + ) + + # 1D (u - g | K) + ug = N_utils.get_color_cond_space_list( + "Ug_condK", megacam_hsc_uSg, uds_K, [0, 1], 7, z_sel, cond_dmag=2, fit=True + ) + + # 1D (r - z | K) + rz = N_utils.get_color_cond_space_list( + "Rz_condK", hsc_rz, uds_K, [2, 4], 7, z_sel, cond_dmag=2, fit=True + ) + + # 1D (J − H | K) + jh = N_utils.get_color_cond_space_list( + "JH_condK", uds_JH, uds_K, [5, 6], 7, z_sel, cond_dmag=2, fit=True + ) + + # 2D (K, u - g) + K_ug = N_utils.get_mag_color_space( + "K_ug", uds_K, megacam_hsc_uSg, 7, [0, 1], z_sel, fit=True ) - N_K_rz, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz = get_N_2d( - uds_K[z_sel][mag_sel_rz], hsc_rz[z_sel][mag_sel_rz] + + # 2D (K, r - z) + K_rz = N_utils.get_mag_color_space( + "K_rz", uds_K, hsc_rz, 7, [2, 4], z_sel, fit=True ) - mag_idx = 7 - col_idx = [2, 4] - K_rz = K_rz(mag_idx, col_idx, sig_K_rz, bin_lo_K_rz, bin_hi_K_rz, N_K_rz, True) - n_bins += bin_lo_K_rz.size - z2 = Z2(z_min, z_max, lc_data, rz_zJ, ug, rz, jh, K_ug, K_rz) - colors.append(z2) + z2b = Z2b(z_min, z_max, lc_data, rz_zJ, ug, rz, jh, K_ug, K_rz) + colors.append(z2b) ############################################################################## # Z3 spaces: @@ -750,7 +671,7 @@ def get_feniks_data( "K_JH", ], ) - zbin = 2 + zbin = 3 z_min = zbins[zbin][0] z_max = zbins[zbin][1] @@ -775,155 +696,44 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) # 2D (z - J, J - H) - zJ_JH = namedtuple("zJ_JH", ColorColor._fields) - mag_sel_zJ_JH = ( - (hsc_z[z_sel] < feniks_mag_thresh.HSC_Z) - & (uds_J[z_sel] < feniks_mag_thresh.UDS_J) - & (uds_H[z_sel] < feniks_mag_thresh.UDS_H) + zJ_JH = N_utils.get_colorcolor_space( + "ZJ_JH", hsc_uds_zJ, uds_JH, [4, 5, 6], z_sel, fit=True ) - N_zJ_JH, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH = get_N_2d( - hsc_uds_zJ[z_sel][mag_sel_zJ_JH], uds_JH[z_sel][mag_sel_zJ_JH] - ) - col_idx = [4, 5, 6] - zJ_JH = zJ_JH(col_idx, sig_zJ_JH, bin_lo_zJ_JH, bin_hi_zJ_JH, N_zJ_JH, True) - n_bins += bin_lo_zJ_JH.size # 2D (u - g, g - r) - Ug_gr = namedtuple("Ug_gr", ColorColor._fields) - mag_sel_ugr = ( - (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) - & (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) - & (hsc_r[z_sel] < feniks_mag_thresh.HSC_R) - ) - N_ug_gr, sig_ug_gr, bin_lo_ug_gr, bin_hi_ug_gr = get_N_2d( - megacam_hsc_uSg[z_sel][mag_sel_ugr], hsc_gr[z_sel][mag_sel_ugr] + ug_gr = N_utils.get_colorcolor_space( + "Ug_gr", megacam_hsc_uSg, hsc_gr, [0, 1, 2], z_sel, fit=True ) - col_idx = [0, 1, 2] - ug_gr = Ug_gr(col_idx, sig_ug_gr, bin_lo_ug_gr, bin_hi_ug_gr, N_ug_gr, True) - n_bins += bin_lo_ug_gr.size # 1D (u - g | K) - Kbins = np.arange(uds_K[z_sel].min(), uds_K[z_sel].max(), 4) - - ug = [] - mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( - hsc_g[z_sel] < feniks_mag_thresh.HSC_G + ug = N_utils.get_color_cond_space_list( + "Ug_condK", megacam_hsc_uSg, uds_K, [0, 1], 7, z_sel, cond_dmag=2, fit=True ) - col_idx = [0, 1] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_ug, sig_ug, bin_lo_ug, bin_hi_ug = get_N_1d( - megacam_hsc_uSg[z_sel][mag_sel_ug & K_sel] - ) - ug.append( - Ug_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_ug, - bin_lo_ug, - bin_hi_ug, - N_1d_ug, - True, - ) - ) - n_bins += bin_lo_ug.size # 1D (g - r | K) - gr = [] - Gr_condK = namedtuple("Gr_condK", ColorCondMag._fields) - mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( - hsc_r[z_sel] < feniks_mag_thresh.HSC_R + gr = N_utils.get_color_cond_space_list( + "Gr_condK", hsc_gr, uds_K, [1, 2], 7, z_sel, cond_dmag=2, fit=True ) - col_idx = [1, 2] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_gr, sig_gr, bin_lo_gr, bin_hi_gr = get_N_1d( - hsc_gr[z_sel][mag_sel_gr & K_sel] - ) - gr.append( - Gr_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_gr, - bin_lo_gr, - bin_hi_gr, - N_1d_gr, - True, - ) - ) - n_bins += bin_lo_gr.size # 1D (J − H | K) - jh = [] - mag_sel_jh = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( - uds_H[z_sel] < feniks_mag_thresh.UDS_H + jh = N_utils.get_color_cond_space_list( + "JH_condK", uds_JH, uds_K, [5, 6], 7, z_sel, cond_dmag=2, fit=True ) - col_idx = [5, 6] - cond_idx = 7 - for k in range(len(Kbins) - 1): - K_sel = (uds_K[z_sel] > Kbins[k]) & (uds_K[z_sel] <= Kbins[k + 1]) - N_1d_jh, sig_jh, bin_lo_jh, bin_hi_jh = get_N_1d( - uds_JH[z_sel][mag_sel_jh & K_sel] - ) - jh.append( - JH_condK( - col_idx, - cond_idx, - Kbins[k], - Kbins[k + 1], - sig_jh, - bin_lo_jh, - bin_hi_jh, - N_1d_jh, - True, - ) - ) - n_bins += bin_lo_jh.size # 2D (K, u - g) - K_ug = namedtuple("K_ug", MagColor._fields) - mag_sel_ug = (megacam_uS[z_sel] < feniks_mag_thresh.MegaCam_uS) & ( - hsc_g[z_sel] < feniks_mag_thresh.HSC_G - ) - N_K_ug, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug = get_N_2d( - uds_K[z_sel][mag_sel_ug], megacam_hsc_uSg[z_sel][mag_sel_ug] + K_ug = N_utils.get_mag_color_space( + "K_ug", uds_K, megacam_hsc_uSg, 7, [0, 1], z_sel, fit=True ) - mag_idx = 7 - col_idx = [0, 1] - K_ug = K_ug(mag_idx, col_idx, sig_K_ug, bin_lo_K_ug, bin_hi_K_ug, N_K_ug, True) - n_bins += bin_lo_K_ug.size # 2D (K, g - r) - K_gr = namedtuple("K_gr", MagColor._fields) - mag_sel_gr = (hsc_g[z_sel] < feniks_mag_thresh.HSC_G) & ( - hsc_r[z_sel] < feniks_mag_thresh.HSC_R + K_gr = N_utils.get_mag_color_space( + "K_gr", uds_K, hsc_gr, 7, [1, 2], z_sel, fit=True ) - N_K_gr, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr = get_N_2d( - uds_K[z_sel][mag_sel_gr], hsc_gr[z_sel][mag_sel_gr] - ) - mag_idx = 7 - col_idx = [1, 2] - K_gr = K_gr(mag_idx, col_idx, sig_K_gr, bin_lo_K_gr, bin_hi_K_gr, N_K_gr, True) - n_bins += bin_lo_K_gr.size # 2D (K, J - H) - K_JH = namedtuple("K_JH", MagColor._fields) - mag_sel_JH = (uds_J[z_sel] < feniks_mag_thresh.UDS_J) & ( - uds_H[z_sel] < feniks_mag_thresh.UDS_H - ) - N_K_JH, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH = get_N_2d( - uds_K[z_sel][mag_sel_JH], uds_JH[z_sel][mag_sel_JH] + K_JH = N_utils.get_mag_color_space( + "K_JH", uds_K, uds_JH, 7, [5, 6], z_sel, fit=True ) - mag_idx = 7 - col_idx = [5, 6] - K_JH = K_JH(mag_idx, col_idx, sig_K_JH, bin_lo_K_JH, bin_hi_K_JH, N_K_JH, True) - n_bins += bin_lo_K_JH.size z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr, jh, K_ug, K_gr, K_JH) colors.append(z3) @@ -982,49 +792,49 @@ def get_feniks_data( # 1D (u) mag_idx_u = 0 - N_1d_u, sig_u, bin_lo_u, bin_hi_u = get_N_1d(megacam_uS[z_sel]) + N_1d_u, sig_u, bin_lo_u, bin_hi_u = N_utils.get_N_1d(megacam_uS[z_sel]) u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u, True) n_bins += bin_lo_u.size # 1D (g) mag_idx_g = 1 - N_1d_g, sig_g, bin_lo_g, bin_hi_g = get_N_1d(hsc_g[z_sel]) + N_1d_g, sig_g, bin_lo_g, bin_hi_g = N_utils.get_N_1d(hsc_g[z_sel]) g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g, True) n_bins += bin_lo_g.size # 1D (r) mag_idx_r = 2 - N_1d_r, sig_r, bin_lo_r, bin_hi_r = get_N_1d(hsc_r[z_sel]) + N_1d_r, sig_r, bin_lo_r, bin_hi_r = N_utils.get_N_1d(hsc_r[z_sel]) r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r, True) n_bins += bin_lo_r.size # 1D (i) mag_idx_i = 3 - N_1d_i, sig_i, bin_lo_i, bin_hi_i = get_N_1d(hsc_i[z_sel]) + N_1d_i, sig_i, bin_lo_i, bin_hi_i = N_utils.get_N_1d(hsc_i[z_sel]) i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i, True) n_bins += bin_lo_i.size # 1D (z) mag_idx_z = 4 - N_1d_z, sig_z, bin_lo_z, bin_hi_z = get_N_1d(hsc_z[z_sel]) + N_1d_z, sig_z, bin_lo_z, bin_hi_z = N_utils.get_N_1d(hsc_z[z_sel]) z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z, True) n_bins += bin_lo_z.size # 1D (J) mag_idx_j = 5 - N_1d_j, sig_j, bin_lo_j, bin_hi_j = get_N_1d(uds_J[z_sel]) + N_1d_j, sig_j, bin_lo_j, bin_hi_j = N_utils.get_N_1d(uds_J[z_sel]) j = J(mag_idx_j, sig_j, bin_lo_j, bin_hi_j, N_1d_j, True) n_bins += bin_lo_j.size # 1D (H) mag_idx_h = 6 - N_1d_h, sig_h, bin_lo_h, bin_hi_h = get_N_1d(uds_H[z_sel]) + N_1d_h, sig_h, bin_lo_h, bin_hi_h = N_utils.get_N_1d(uds_H[z_sel]) h = H(mag_idx_h, sig_h, bin_lo_h, bin_hi_h, N_1d_h, True) n_bins += bin_lo_h.size # 1D (K) mag_idx_k = 7 - N_1d_k, sig_k, bin_lo_k, bin_hi_k = get_N_1d(uds_K[z_sel]) + N_1d_k, sig_k, bin_lo_k, bin_hi_k = N_utils.get_N_1d(uds_K[z_sel]) k = K(mag_idx_k, sig_k, bin_lo_k, bin_hi_k, N_1d_k, True) n_bins += bin_lo_k.size @@ -1076,5 +886,7 @@ def get_feniks_data( "UDS_J", "UDS_H", "UDS_K", + "NB0816", + "NB0921", ], ) diff --git a/diffhtwo/experimental/diagnostics/plot_phot.py b/diffhtwo/experimental/diagnostics/plot_phot.py index f8f96970..1c82a78c 100644 --- a/diffhtwo/experimental/diagnostics/plot_phot.py +++ b/diffhtwo/experimental/diagnostics/plot_phot.py @@ -615,6 +615,8 @@ def plot_app_mag_funcs( "#c8b44a", "#c87820", ] + elif len(labels_z) == 6: + colors_z = ["#001219", "#0a7a80", "#80cca8", "#c8b44a", "#c87820", "#9b1d20"] n_bands = dataset_mags.shape[1] - 1 if n_bands <= 5: diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 19c361fe..c822d62d 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run169 -model_nickname: run169_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run169/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run182 +model_nickname: run182_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run182/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -17,22 +17,22 @@ plot_hizels: False plots: num_halos : 3000 - plot_color_contours: False - plot_app_mag_funcs: False - plot_color_pdfs: False - plot_colors_mags: False - plot_mags: False - plot_ssperr: False - plot_massive_cen_colors: False - plot_merging_sat_colors: False + plot_color_contours: True + plot_app_mag_funcs: True + plot_color_pdfs: True + plot_colors_mags: True + plot_mags: True + plot_ssperr: True + plot_massive_cen_colors: True + plot_merging_sat_colors: True plot_smhm: True - plot_insitu_smhm: False - plot_insitu_sm: False - plot_sm: False - plot_uvj: False - plot_exsitu_frac: False - plot_avpop: False - plot_burstpop: False - plot_fburst_mh_z: False - plot_satquench_model: False + plot_insitu_smhm: True + plot_insitu_sm: True + plot_sm: True + plot_uvj: True + plot_exsitu_frac: True + plot_avpop: True + plot_burstpop: True + plot_fburst_mh_z: True + plot_satquench_model: True plot_satquench: False \ No newline at end of file diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index ff7d914a..ffca208d 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -213,15 +213,7 @@ num_halos_fine_zbins=int(num_halos / 2), ) - feniks_zbins = np.array( - [ - [0.2, 0.5], - [0.5, 0.7], - [0.7, 1.0], - [1.0, 1.5], - [1.5, 2.0], - ] - ) + feniks_zbins = feniks.fine_zbins if cfg["plots"]["plot_smhm"]: print("Generating FENIKS SMHM plots...") @@ -505,14 +497,7 @@ num_halos_coarse_zbins=num_halos, num_halos_fine_zbins=int(num_halos / 2), ) - sdss_zbins = np.array( - [ - [0.02, 0.06], - [0.06, 0.1], - [0.1, 0.16], - [0.16, 0.20], - ] - ) + sdss_zbins = sdss.fine_zbins if cfg["plots"]["plot_smhm"]: print("Generating SDSS SMHM plots...") From ee91f28c9d465ff1fd8421ca61a731a64dde2977 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 13 Jun 2026 12:51:06 -0500 Subject: [PATCH 49/57] fix bug in compute_cat_weights: obs_mags was being used instead of obs_mags_weighted --- diffhtwo/experimental/conftest.py | 5 +- .../experimental/data_loaders/load_feniks.py | 24 +++++--- diffhtwo/experimental/kernels/N_phot.py | 61 +++++++------------ diffhtwo/experimental/kernels/cat_weights.py | 18 +++--- diffhtwo/experimental/kernels/lc_phot_kern.py | 3 +- diffhtwo/experimental/kernels/phot_kern.py | 7 +-- .../experimental/kernels/tests/test_N_phot.py | 25 ++++++++ .../kernels/tests/test_compute_cat_weights.py | 13 +++- diffhtwo/experimental/utils.py | 35 +++++++++++ scripts/config_diagnostics.yaml | 8 +-- 10 files changed, 126 insertions(+), 73 deletions(-) diff --git a/diffhtwo/experimental/conftest.py b/diffhtwo/experimental/conftest.py index 496af452..a38d4f82 100644 --- a/diffhtwo/experimental/conftest.py +++ b/diffhtwo/experimental/conftest.py @@ -1,7 +1,6 @@ from pathlib import Path import jax.numpy as jnp -import numpy as np import pytest from dsps.data_loaders import load_emline_info as lemi from dsps.data_loaders import retrieve_fake_fsps_data @@ -42,15 +41,13 @@ def fake_subset_ssp_data(): def feniks(ran_key, fake_subset_ssp_data): ssp_data, emline_wave_aa = fake_subset_ssp_data - mag_bin_edges = np.array([18, 25]) - feniks = load_feniks.get_feniks_data( FENIKS_DRN, ran_key, ssp_data, phot=PHOT, zout=ZOUT, - mag_bin_edges=mag_bin_edges, + add_random_rows_for_testing=True, ) return feniks diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 5ba6f28e..319846c6 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -20,7 +20,7 @@ ) from ..latin_hypercube import latin_hypercube as lh from ..lightcone_generators import generate_lc_data -from ..utils import load_feniks_tcurve +from ..utils import add_random_rows, load_feniks_tcurve from . import N_utils BASE_PATH = Path(__file__).resolve().parent.parent @@ -153,7 +153,7 @@ def get_feniks_data( lgmp_max=15.0, lc_sky_area_degsq=100, n_z_phot_table=30, - mag_bin_edges=None, + add_random_rows_for_testing=False, ): # Transmission curves and filter mag thresholds @@ -179,6 +179,10 @@ def get_feniks_data( phot = ascii.read(drn_path / phot) zout = ascii.read(drn_path / zout) + if add_random_rows_for_testing: + phot = add_random_rows(phot, N=10000) + zout = add_random_rows(zout, N=10000) + # get mags megacam_uS = get_mag_ab(phot, "fcol_MegaCam_uS") hsc_g = get_mag_ab(phot, "fcol_HSC_G") @@ -406,7 +410,7 @@ def get_feniks_data( # 2D (g - r, r - i) gr_ri = N_utils.get_colorcolor_space( - "Gr_ri", hsc_gr, hsc_ri, [1, 2, 3], z_sel, fit=True + "Gr_ri", hsc_gr, hsc_ri, [1, 2, 2, 3], z_sel, fit=True ) # 1D (u - g | K) @@ -501,7 +505,7 @@ def get_feniks_data( # 2D (r - z, z - J) rz_zJ = N_utils.get_colorcolor_space( - "Rz_zJ", hsc_rz, hsc_uds_zJ, [2, 4, 5], z_sel, fit=True + "Rz_zJ", hsc_rz, hsc_uds_zJ, [2, 4, 4, 5], z_sel, fit=True ) # 1D (u - g | K) @@ -616,7 +620,7 @@ def get_feniks_data( # 2D (r - z, z - J) rz_zJ = N_utils.get_colorcolor_space( - "Rz_zJ", hsc_rz, hsc_uds_zJ, [2, 4, 5], z_sel, fit=True + "Rz_zJ", hsc_rz, hsc_uds_zJ, [2, 4, 4, 5], z_sel, fit=True ) # 1D (u - g | K) @@ -697,27 +701,27 @@ def get_feniks_data( # 2D (z - J, J - H) zJ_JH = N_utils.get_colorcolor_space( - "ZJ_JH", hsc_uds_zJ, uds_JH, [4, 5, 6], z_sel, fit=True + "ZJ_JH", hsc_uds_zJ, uds_JH, [4, 5, 5, 6], z_sel, fit=True ) # 2D (u - g, g - r) ug_gr = N_utils.get_colorcolor_space( - "Ug_gr", megacam_hsc_uSg, hsc_gr, [0, 1, 2], z_sel, fit=True + "Ug_gr", megacam_hsc_uSg, hsc_gr, [0, 1, 1, 2], z_sel, fit=True ) # 1D (u - g | K) ug = N_utils.get_color_cond_space_list( - "Ug_condK", megacam_hsc_uSg, uds_K, [0, 1], 7, z_sel, cond_dmag=2, fit=True + "Ug_condK", megacam_hsc_uSg, uds_K, [0, 1], 7, z_sel, cond_dmag=4, fit=True ) # 1D (g - r | K) gr = N_utils.get_color_cond_space_list( - "Gr_condK", hsc_gr, uds_K, [1, 2], 7, z_sel, cond_dmag=2, fit=True + "Gr_condK", hsc_gr, uds_K, [1, 2], 7, z_sel, cond_dmag=4, fit=True ) # 1D (J − H | K) jh = N_utils.get_color_cond_space_list( - "JH_condK", uds_JH, uds_K, [5, 6], 7, z_sel, cond_dmag=2, fit=True + "JH_condK", uds_JH, uds_K, [5, 6], 7, z_sel, cond_dmag=4, fit=True ) # 2D (K, u - g) diff --git a/diffhtwo/experimental/kernels/N_phot.py b/diffhtwo/experimental/kernels/N_phot.py index 3429e357..b68604ad 100644 --- a/diffhtwo/experimental/kernels/N_phot.py +++ b/diffhtwo/experimental/kernels/N_phot.py @@ -17,7 +17,7 @@ def N_colors_mags( mag_thresh, frac_cat, ): - obs_mags, gal_weight, phot_kern_results = mag_kern( + obs_mags_weighted, gal_weight, phot_kern_results = mag_kern( ran_key, param_collection, z_data.lc_data, @@ -37,18 +37,15 @@ def N_colors_mags( col_idx = space_n.col_idx # get cond weight - obs_mags_cond = obs_mags[:, space_n.cond_idx] - cond = (obs_mags_cond > space_n.cond_min) & ( - obs_mags_cond <= space_n.cond_max + obs_mags_weighted_cond = obs_mags_weighted[:, space_n.cond_idx] + cond = (obs_mags_weighted_cond > space_n.cond_min) & ( + obs_mags_weighted_cond <= space_n.cond_max ) weight = jnp.where(cond, gal_weight, 0.0) - # get mag_sel weight - for c in range(0, len(col_idx)): - mag_sel = obs_mags[:, col_idx[c]] < mag_thresh[col_idx[c]] - weight *= jnp.where(mag_sel, 1.0, 0.0) - - obs_color = obs_mags[:, col_idx[0]] - obs_mags[:, col_idx[1]] + obs_color = ( + obs_mags_weighted[:, col_idx[0]] - obs_mags_weighted[:, col_idx[1]] + ) obs_color = obs_color.reshape(obs_color.size, 1) N_model = diffndhist_lomem.tw_ndhist_weighted( @@ -71,21 +68,16 @@ def N_colors_mags( col_idx = space.col_idx mag_idx = space.mag_idx - mag = obs_mags[:, mag_idx] - obs_color = obs_mags[:, col_idx[0]] - obs_mags[:, col_idx[1]] - obs_mag_color = jnp.vstack((mag, obs_color)).T - - mag_sel = ( - (obs_mags[:, mag_idx] < mag_thresh[mag_idx]) - & (obs_mags[:, col_idx[0]] < mag_thresh[col_idx[0]]) - & (obs_mags[:, col_idx[1]] < mag_thresh[col_idx[1]]) + mag = obs_mags_weighted[:, mag_idx] + obs_color = ( + obs_mags_weighted[:, col_idx[0]] - obs_mags_weighted[:, col_idx[1]] ) - weight = jnp.where(mag_sel, gal_weight, 0.0) + obs_mag_color = jnp.vstack((mag, obs_color)).T N_model = diffndhist_lomem.tw_ndhist_weighted( obs_mag_color, space.sig, - weight, + gal_weight, space.bin_lo, space.bin_hi, ) @@ -96,17 +88,13 @@ def N_colors_mags( else: # Apparent Magnitude space mag_idx = space.mag_idx - obs_mag = obs_mags[:, mag_idx] + obs_mag = obs_mags_weighted[:, mag_idx] obs_mag = obs_mag.reshape(obs_mag.size, 1) - # get mag_sel weight - mag_sel = obs_mags[:, mag_idx] < mag_thresh[mag_idx] - weight = jnp.where(mag_sel, gal_weight, 0.0) - N_model = diffndhist_lomem.tw_ndhist_weighted( obs_mag, space.sig, - weight, + gal_weight, space.bin_lo, space.bin_hi, ) @@ -119,21 +107,18 @@ def N_colors_mags( # Color-Color space col_idx = space.col_idx obs_colors = [] - for c in range(0, len(col_idx) - 1): - obs_color = obs_mags[:, col_idx[c]] - obs_mags[:, col_idx[c + 1]] + for c in range(0, len(col_idx) - 1, 2): + obs_color = ( + obs_mags_weighted[:, col_idx[c]] + - obs_mags_weighted[:, col_idx[c + 1]] + ) obs_colors.append(obs_color) obs_colors = jnp.array(obs_colors).T - # get mag_sel weight - weight = gal_weight.copy() - for c in range(0, len(col_idx)): - mag_sel = obs_mags[:, col_idx[c]] < mag_thresh[col_idx[c]] - weight *= jnp.where(mag_sel, 1.0, 0.0) - N_model = diffndhist_lomem.tw_ndhist_weighted( obs_colors, space.sig, - weight, + gal_weight, space.bin_lo, space.bin_hi, ) @@ -155,7 +140,7 @@ def N_mags_1d( frac_cat, sig_scale=0.5, ): - obs_mags, gal_weight, phot_kern_results = mag_kern( + obs_mags_weighted, gal_weight, phot_kern_results = mag_kern( ran_key, param_collection, lc_data, @@ -163,10 +148,10 @@ def N_mags_1d( frac_cat, ) - n_gals, n_bands = obs_mags.shape + n_gals, n_bands = obs_mags_weighted.shape N_bands = [] for band in range(0, n_bands): - mags = obs_mags[:, band].reshape(obs_mags[:, band].size, 1) + mags = obs_mags_weighted[:, band].reshape(obs_mags_weighted[:, band].size, 1) magbin_edges = magbin_bands[band] diff --git a/diffhtwo/experimental/kernels/cat_weights.py b/diffhtwo/experimental/kernels/cat_weights.py index 9d3dec81..bff120d9 100644 --- a/diffhtwo/experimental/kernels/cat_weights.py +++ b/diffhtwo/experimental/kernels/cat_weights.py @@ -3,16 +3,12 @@ @jjit -def compute_cat_weights(weights, phot_kern_results, mag_thresh, frac_cat): - obs_mags = phot_kern_results.obs_mags - n_gals, n_bands = obs_mags.shape - mag_thresh_mask = jnp.ones((n_gals,), dtype=bool) +def compute_cat_weights(weights, obs_mags_weighted, mag_thresh, frac_cat): + mag_thresh = jnp.array(mag_thresh) + mag_thresh_mask = obs_mags_weighted[:, 0] < mag_thresh[0] - for band in range(0, n_bands): - if mag_thresh[band] is not None: - band_mag_thresh_mask = obs_mags[:, band] < mag_thresh[band] - mag_thresh_mask *= band_mag_thresh_mask + n_gals, n_bands = obs_mags_weighted.shape + for band in range(1, n_bands): + mag_thresh_mask *= obs_mags_weighted[:, band] < mag_thresh[band] - weights = weights * jnp.where(mag_thresh_mask, frac_cat, 0.0) - - return weights + return weights * jnp.where(mag_thresh_mask, frac_cat, 0.0) diff --git a/diffhtwo/experimental/kernels/lc_phot_kern.py b/diffhtwo/experimental/kernels/lc_phot_kern.py index 603c6981..8a333026 100644 --- a/diffhtwo/experimental/kernels/lc_phot_kern.py +++ b/diffhtwo/experimental/kernels/lc_phot_kern.py @@ -44,11 +44,12 @@ def multiband_lc_phot_kern( param_collection, lc_data, ) + obs_mags_weighted = phot_kern_results.obs_mags_weighted gal_weight = lc_data.cen_weight * lc_data.sat_weight if mag_thresh is not None: gal_weight = compute_cat_weights( - gal_weight, phot_kern_results, mag_thresh, frac_cat + gal_weight, obs_mags_weighted, mag_thresh, frac_cat ) return lc_data, phot_kern_results, gal_weight diff --git a/diffhtwo/experimental/kernels/phot_kern.py b/diffhtwo/experimental/kernels/phot_kern.py index 0ce829f7..71a89655 100644 --- a/diffhtwo/experimental/kernels/phot_kern.py +++ b/diffhtwo/experimental/kernels/phot_kern.py @@ -55,16 +55,15 @@ def mag_kern( param_collection, lc_data, ) - obs_mags = phot_kern_results.obs_mags_weighted - + obs_mags_weighted = phot_kern_results.obs_mags_weighted gal_weight = lc_data.cen_weight * lc_data.sat_weight # update weights to incorporate mag thresh cuts and frac_cat gal_weight = compute_cat_weights( - gal_weight, phot_kern_results, mag_thresh, frac_cat + gal_weight, obs_mags_weighted, mag_thresh, frac_cat ) - return obs_mags, gal_weight, phot_kern_results + return obs_mags_weighted, gal_weight, phot_kern_results @partial(jjit, static_argnames=["redshift_as_last_dimension_in_lh"]) diff --git a/diffhtwo/experimental/kernels/tests/test_N_phot.py b/diffhtwo/experimental/kernels/tests/test_N_phot.py index efe4bfdd..f6b0ecb0 100644 --- a/diffhtwo/experimental/kernels/tests/test_N_phot.py +++ b/diffhtwo/experimental/kernels/tests/test_N_phot.py @@ -1,8 +1,10 @@ +import jax.numpy as jnp import numpy as np from diffsky.param_utils.diffsky_param_wrapper_merging import DEFAULT_PARAM_COLLECTION from jax import random as jran from ..N_phot import N_colors_mags_lh +from ..phot_kern import mag_kern def test_N_colors_mags_lh(feniks_single_z_data): @@ -19,3 +21,26 @@ def test_N_colors_mags_lh(feniks_single_z_data): assert np.isfinite(N).all() assert (N >= 0.0).all() + + +def test_mag_kern(feniks): + ran_key = jran.key(0) + + obs_mags_weighted, gal_weight, phot_kern_results = mag_kern( + ran_key, + DEFAULT_PARAM_COLLECTION, + feniks.colors[0].lc_data, + feniks.filter_info.mag_thresh, + feniks.frac_cat, + ) + + assert np.isfinite(obs_mags_weighted).all() + assert np.isfinite(gal_weight).all() + + # ensure that gal_weight for gals above mag_thresh is 0.0 in each band + mag_thresh = jnp.array(feniks.filter_info.mag_thresh) + n_gals, n_bands = obs_mags_weighted.shape + for i in range(0, n_bands): + mag_sel = obs_mags_weighted[:, i] < mag_thresh[i] + gal_weight_above_magthresh = jnp.where(mag_sel, 0, gal_weight) + assert gal_weight_above_magthresh.sum() == 0.0 diff --git a/diffhtwo/experimental/kernels/tests/test_compute_cat_weights.py b/diffhtwo/experimental/kernels/tests/test_compute_cat_weights.py index c05e2700..b58d5137 100644 --- a/diffhtwo/experimental/kernels/tests/test_compute_cat_weights.py +++ b/diffhtwo/experimental/kernels/tests/test_compute_cat_weights.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np from diffsky.param_utils.diffsky_param_wrapper_merging import DEFAULT_PARAM_COLLECTION from jax import random as jran @@ -14,15 +15,25 @@ def test_compute_cat_weights(feniks, feniks_lc_data): DEFAULT_PARAM_COLLECTION, feniks_lc_data, ) + obs_mags_weighted = phot_kern_results.obs_mags_weighted gal_weight = feniks_lc_data.cen_weight * feniks_lc_data.sat_weight assert np.isfinite(gal_weight).all() assert (gal_weight >= 0).all() + # apply mag_thresh cuts and frac_cat gal_cat_weight = compute_cat_weights( - gal_weight, phot_kern_results, feniks.filter_info.mag_thresh, feniks.frac_cat + gal_weight, obs_mags_weighted, feniks.filter_info.mag_thresh, feniks.frac_cat ) assert np.isfinite(gal_cat_weight).all() assert (gal_cat_weight >= 0).all() # ensure that gal_cat_weight does not upweight compared to gal_weight assert gal_cat_weight.sum() <= gal_weight.sum() + + # ensure that gal_cat_weight for gals above mag_thresh is 0.0 in each band + mag_thresh = jnp.array(feniks.filter_info.mag_thresh) + n_gals, n_bands = obs_mags_weighted.shape + for i in range(0, n_bands): + mag_sel = obs_mags_weighted[:, i] < mag_thresh[i] + gal_cat_weight_above_magthresh = jnp.where(mag_sel, 0, gal_cat_weight) + assert gal_cat_weight_above_magthresh.sum() == 0.0 diff --git a/diffhtwo/experimental/utils.py b/diffhtwo/experimental/utils.py index 890f67fa..11ceb033 100644 --- a/diffhtwo/experimental/utils.py +++ b/diffhtwo/experimental/utils.py @@ -2,6 +2,7 @@ import jax.numpy as jnp import numpy as np +from astropy.table import Table, vstack from jax import jit as jjit from jax.tree_util import tree_flatten_with_path @@ -73,3 +74,37 @@ def get_param_names(params): paths, leaves = tree_flatten_with_path(params) names = [p[0][0].name for p in paths] # each p is (GetAttrKey(...),) return names + + +def add_random_rows(tab, N): + new_rows = {} + for col in tab.colnames: + data = tab[col] + dtype = data.dtype + + if dtype.kind in ("f", "i") and data.ndim == 1: + valid = data[(data != -99) & np.isfinite(data)] + if len(valid) > 0: + lo, hi = valid.min(), valid.max() + new_rows[col] = np.random.uniform(lo, hi, N) + if dtype.kind == "i": + new_rows[col] = new_rows[col].astype(dtype) + else: + new_rows[col] = np.full(N, -99, dtype=dtype) + elif dtype.kind == "f" and data.ndim == 2: + ncols = data.shape[1] + new_col = np.empty((N, ncols), dtype=dtype) + for j in range(ncols): + sub = data[:, j] + valid = sub[(sub != -99) & np.isfinite(sub)] + if len(valid) > 0: + lo, hi = valid.min(), valid.max() + new_col[:, j] = np.random.uniform(lo, hi, N) + else: + new_col[:, j] = -99 + new_rows[col] = new_col + else: + new_rows[col] = data[:N] # fallback for non-numeric columns + + new_tab = Table(new_rows) + return vstack([tab, new_tab]) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index c822d62d..326141fb 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run182 -model_nickname: run182_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run182/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run183 +model_nickname: run183_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run183/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,7 +11,7 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: False +plot_sdss: True plot_feniks: True plot_hizels: False From f88e55fcca6ad74631cfff3b655bb5cbca1d8576 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 13 Jun 2026 12:52:31 -0500 Subject: [PATCH 50/57] remove NBs from fitting for now --- diffhtwo/experimental/data_loaders/load_feniks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 319846c6..79d2f1c8 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -535,12 +535,12 @@ def get_feniks_data( # 2D (i - NB816, g - r) i816_gr = N_utils.get_colorcolor_space( - "I816_gr", hsc_i816, hsc_gr, [3, 8, 1, 2], z_sel, fit=True + "I816_gr", hsc_i816, hsc_gr, [3, 8, 1, 2], z_sel, fit=False ) # 2D (i - NB816, r - K) i816_rK = N_utils.get_colorcolor_space( - "I816_rK", hsc_i816, hsc_uds_rK, [3, 8, 2, 7], z_sel, fit=True + "I816_rK", hsc_i816, hsc_uds_rK, [3, 8, 2, 7], z_sel, fit=False ) # 1D (i - NB816 | K) @@ -552,7 +552,7 @@ def get_feniks_data( 7, z_sel, cond_dmag=2, - fit=True, + fit=False, ) # 1D (z - NB921 | K) @@ -564,7 +564,7 @@ def get_feniks_data( 7, z_sel, cond_dmag=2, - fit=True, + fit=False, ) z2a = Z2a( From 0f59e76315ae08af086d15358ea9bed5edf6b3d8 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 13 Jun 2026 13:00:50 -0500 Subject: [PATCH 51/57] # noqa: E741 --- diffhtwo/experimental/data_loaders/load_feniks.py | 2 +- diffhtwo/experimental/data_loaders/load_sdss.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 79d2f1c8..2662fb33 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -763,7 +763,7 @@ def get_feniks_data( U = namedtuple("U", AppMagFunc._fields) G = namedtuple("G", AppMagFunc._fields) R = namedtuple("R", AppMagFunc._fields) - I = namedtuple("I", AppMagFunc._fields) + I = namedtuple("I", AppMagFunc._fields) # noqa: E741 Z = namedtuple("Z", AppMagFunc._fields) J = namedtuple("J", AppMagFunc._fields) H = namedtuple("H", AppMagFunc._fields) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index bfce178b..127800c1 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -324,7 +324,7 @@ def get_sdss_data( U = namedtuple("U", AppMagFunc._fields) G = namedtuple("G", AppMagFunc._fields) R = namedtuple("R", AppMagFunc._fields) - I = namedtuple("I", AppMagFunc._fields) + I = namedtuple("I", AppMagFunc._fields) # noqa: E741 Z = namedtuple("Z", AppMagFunc._fields) app_mag_funcs = [] From d0625977d9631121795c622cfe98f21b322ee01e Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 13 Jun 2026 13:03:37 -0500 Subject: [PATCH 52/57] Update load_sdss.py --- diffhtwo/experimental/data_loaders/load_sdss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 127800c1..e7ef64af 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -240,14 +240,14 @@ def get_sdss_data( N_ur_ri, sig_ur_ri, bin_lo_ur_ri, bin_hi_ur_ri = get_N_2d( sdss_ur[z_sel], sdss_ri[z_sel] ) - col_idx = [0, 2, 3] + col_idx = [0, 2, 2, 3] ur_ri = Ur_ri(col_idx, sig_ur_ri, bin_lo_ur_ri, bin_hi_ur_ri, N_ur_ri, True) # 2D (g - r, r - i) N_gr_ri, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri = get_N_2d( sdss_gr[z_sel], sdss_ri[z_sel] ) - col_idx = [1, 2, 3] + col_idx = [1, 2, 2, 3] gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri, True) # 1D (u - r | r) From b491525d15189126640dcb2ab38f90d74d15d2fd Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 13 Jun 2026 17:10:15 -0500 Subject: [PATCH 53/57] filter_name_to_idx --- diffhtwo/experimental/conftest.py | 14 + diffhtwo/experimental/data_loaders/N_utils.py | 55 ++- .../experimental/data_loaders/load_feniks.py | 371 ++++++++++++------ diffhtwo/experimental/defaults.py | 16 + .../loss_kernels/tests/test_phot_loss.py | 4 + .../optimizers/tests/test_Np_specphot_opt.py | 56 ++- scripts/config_diagnostics.yaml | 10 +- scripts/fit_feniks.py | 10 +- 8 files changed, 359 insertions(+), 177 deletions(-) diff --git a/diffhtwo/experimental/conftest.py b/diffhtwo/experimental/conftest.py index a38d4f82..7c5c0d15 100644 --- a/diffhtwo/experimental/conftest.py +++ b/diffhtwo/experimental/conftest.py @@ -73,6 +73,20 @@ def feniks_tcurves(): return tcurves +@pytest.fixture(scope="session") +def feniks_fitting_data(ran_key, fake_subset_ssp_data): + ssp_data, emline_wave_aa = fake_subset_ssp_data + feniks_fitting_data = load_feniks.get_feniks_fitting_data( + FENIKS_DRN, + ran_key, + ssp_data, + phot=PHOT, + zout=ZOUT, + add_random_rows_for_testing=True, + ) + return feniks_fitting_data + + @pytest.fixture(scope="session") def feniks_single_z_data(ran_key, fake_subset_ssp_data, feniks): ssp_data, emline_wave_aa = fake_subset_ssp_data diff --git a/diffhtwo/experimental/data_loaders/N_utils.py b/diffhtwo/experimental/data_loaders/N_utils.py index f7873b8b..2978ee04 100644 --- a/diffhtwo/experimental/data_loaders/N_utils.py +++ b/diffhtwo/experimental/data_loaders/N_utils.py @@ -4,11 +4,7 @@ import numpy as np from diffsky import diffndhist_lomem -from ..defaults import ( - ColorColor, - ColorCondMag, - MagColor, -) +from ..defaults import AppMagFunc, ColorColor, ColorCondMag, FeniksFilters, MagColor def get_N_1d(dim1, dim1_bin_edges=None, dmag=0.2, sig_scale=0.5): @@ -67,18 +63,48 @@ def get_N_2d(dim1, dim2, sig_scale=0.5): return N_2d, sig, bin_lo, bin_hi -def get_colorcolor_space(name, color1, color2, col_idx, z_sel, fit=True): - ColorColorSpace = namedtuple(name, ColorColor._fields) +def filter_name_to_idx(filter_name): + return FeniksFilters._fields.index(filter_name) + + +def get_mag_space(namedtuple_name, mag, filter_name, z_sel, fit=True): + AppMagFuncSpace = namedtuple(namedtuple_name, AppMagFunc._fields) + mag_idx = filter_name_to_idx(filter_name) + N_1d, sig, bin_lo, bin_hi = get_N_1d(mag[z_sel]) + return AppMagFuncSpace(mag_idx, sig, bin_lo, bin_hi, N_1d, True) + + +def get_colorcolor_space( + namedtuple_name, color1, color2, col_filter_names, z_sel, fit=True +): + ColorColorSpace = namedtuple(namedtuple_name, ColorColor._fields) N_2d, sig, bin_lo, bin_hi = get_N_2d(color1[z_sel], color2[z_sel]) + col_idx = [] + for n in col_filter_names: + col_idx.append(filter_name_to_idx(n)) + return ColorColorSpace(col_idx, sig, bin_lo, bin_hi, N_2d, fit) def get_color_cond_space_list( - name, color, cond_mag, col_idx, cond_idx, z_sel, cond_dmag=2, fit=True + namedtuple_name, + color, + cond_mag, + col_filter_names, + cond_filter_name, + z_sel, + cond_dmag=2, + fit=True, ): - ColorCondSpace = namedtuple(name, ColorCondMag._fields) + ColorCondSpace = namedtuple(namedtuple_name, ColorCondMag._fields) + + col_idx = [] + for n in col_filter_names: + col_idx.append(filter_name_to_idx(n)) + cond_idx = filter_name_to_idx(cond_filter_name) + cond_mag_bins = np.arange(cond_mag[z_sel].min(), cond_mag[z_sel].max(), cond_dmag) color_cond_list = [] @@ -104,8 +130,15 @@ def get_color_cond_space_list( return color_cond_list -def get_mag_color_space(name, mag, color, mag_idx, col_idx, z_sel, fit=True): - MagColorSpace = namedtuple(name, MagColor._fields) +def get_mag_color_space( + namedtuple_name, mag, color, mag_filter_name, col_filter_names, z_sel, fit=True +): + MagColorSpace = namedtuple(namedtuple_name, MagColor._fields) + + mag_idx = filter_name_to_idx(mag_filter_name) + col_idx = [] + for n in col_filter_names: + col_idx.append(filter_name_to_idx(n)) N_2d, sig, bin_lo, bin_hi = get_N_2d(mag[z_sel], color[z_sel]) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 2662fb33..17f175fb 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -16,6 +16,7 @@ FENIKS_Z_MIN, AppMagFunc, Dataset, + FeniksFilters, FilterInfo, ) from ..latin_hypercube import latin_hypercube as lh @@ -147,8 +148,8 @@ def get_feniks_data( lh_d_mag=0.6, phot=PHOT, zout=ZOUT, - num_halos_coarse_zbins=150, - num_halos_fine_zbins=250, + num_halos_coarse_zbins=250, + num_halos_fine_zbins=150, lgmp_min=10.0, lgmp_max=15.0, lc_sky_area_degsq=100, @@ -162,12 +163,12 @@ def get_feniks_data( HSC_G=False, HSC_R=False, HSC_I=False, + NB0816=False, HSC_Z=False, + NB0921=False, UDS_J=False, UDS_H=False, UDS_K=True, - NB0816=False, - NB0921=False, ) tcurves = [] for feniks_filter in FeniksFilters._fields: @@ -188,25 +189,24 @@ def get_feniks_data( hsc_g = get_mag_ab(phot, "fcol_HSC_G") hsc_r = get_mag_ab(phot, "fcol_HSC_R") hsc_i = get_mag_ab(phot, "fcol_HSC_I") + nb816 = get_mag_ab(phot, "fcol_NB0816") hsc_z = get_mag_ab(phot, "fcol_HSC_Z") - # video_Y = get_mag_ab(phot, "fcol_VIDEO_Y") + nb921 = get_mag_ab(phot, "fcol_NB0921") uds_J = get_mag_ab(phot, "fcol_UDS_J") uds_H = get_mag_ab(phot, "fcol_UDS_H") uds_K = get_mag_ab(phot, "fcol_UDS_K") - nb816 = get_mag_ab(phot, "fcol_NB0816") - nb921 = get_mag_ab(phot, "fcol_NB0921") feniks_mag_thresh = FeniksFilters( MegaCam_uS=24.9, HSC_G=25.1, HSC_R=25.3, HSC_I=25.1, + NB0816=25.3, HSC_Z=24.9, + NB0921=25.3, UDS_J=24.5, UDS_H=24.3, UDS_K=FENIKS_MAGK_THRESH, - NB0816=25.3, - NB0921=25.3, ) filter_info = FilterInfo(feniks_mag_thresh, feniks_in_lh, tcurves) @@ -217,12 +217,12 @@ def get_feniks_data( & (hsc_g < feniks_mag_thresh.HSC_G) & (hsc_r < feniks_mag_thresh.HSC_R) & (hsc_i < feniks_mag_thresh.HSC_I) + & (nb816 < feniks_mag_thresh.NB0816) & (hsc_z < feniks_mag_thresh.HSC_Z) + & (nb921 < feniks_mag_thresh.NB0921) & (uds_J < feniks_mag_thresh.UDS_J) & (uds_H < feniks_mag_thresh.UDS_H) & (uds_K < feniks_mag_thresh.UDS_K) - & (nb816 < feniks_mag_thresh.NB0816) - & (nb921 < feniks_mag_thresh.NB0921) ) # apply mag_thresh cuts and record n_gals. @@ -234,12 +234,12 @@ def get_feniks_data( hsc_g = hsc_g[mag_thresh] hsc_r = hsc_r[mag_thresh] hsc_i = hsc_i[mag_thresh] + nb816 = nb816[mag_thresh] hsc_z = hsc_z[mag_thresh] + nb921 = nb921[mag_thresh] uds_J = uds_J[mag_thresh] uds_H = uds_H[mag_thresh] uds_K = uds_K[mag_thresh] - nb816 = nb816[mag_thresh] - nb921 = nb921[mag_thresh] n_gals_pre_cuts = len(zout) @@ -249,12 +249,12 @@ def get_feniks_data( & (hsc_g != -99) & (hsc_r != -99) & (hsc_i != -99) + & (nb816 != -99) & (hsc_z != -99) + & (nb921 != -99) & (uds_J != -99) & (uds_H != -99) & (uds_K != -99) - & (nb816 != -99) - & (nb921 != -99) ) phot = phot[clean] @@ -263,12 +263,12 @@ def get_feniks_data( hsc_g = hsc_g[clean] hsc_r = hsc_r[clean] hsc_i = hsc_i[clean] + nb816 = nb816[clean] hsc_z = hsc_z[clean] + nb921 = nb921[clean] uds_J = uds_J[clean] uds_H = uds_H[clean] uds_K = uds_K[clean] - nb816 = nb816[clean] - nb921 = nb921[clean] n_gals_post_cuts = len(zout) frac_cat = n_gals_post_cuts / n_gals_pre_cuts @@ -279,12 +279,12 @@ def get_feniks_data( hsc_g, hsc_r, hsc_i, + nb816, hsc_z, + nb921, uds_J, uds_H, uds_K, - nb816, - nb921, zout["z_phot"], ) ).T @@ -294,12 +294,12 @@ def get_feniks_data( r"$g_{HSC}$", r"$r_{HSC}$", r"$i_{HSC}$", + r"$NB816_{HSC}$", r"$z_{HSC}$", + r"$NB921_{HSC}$", r"$J_{UDS}$", r"$H_{UDS}$", r"$K_{UDS}$", - r"$NB816$", - r"$NB921$", ] # derive colors from mags @@ -307,12 +307,12 @@ def get_feniks_data( hsc_gr = hsc_g - hsc_r hsc_rz = hsc_r - hsc_z hsc_ri = hsc_r - hsc_i + hsc_i816 = hsc_i - nb816 hsc_iz = hsc_i - hsc_z + hsc_z921 = hsc_z - nb921 hsc_uds_zJ = hsc_z - uds_J uds_JH = uds_J - uds_H uds_HK = uds_H - uds_K - hsc_i816 = hsc_i - nb816 - hsc_z921 = hsc_z - nb921 hsc_uds_rK = hsc_r - uds_K # stack colors_mag @@ -321,14 +321,14 @@ def get_feniks_data( megacam_hsc_uSg, hsc_gr, hsc_ri, + hsc_i816, hsc_iz, + hsc_z921, hsc_uds_zJ, uds_JH, uds_HK, megacam_uS, uds_K, - hsc_i816, - hsc_z921, zout["z_phot"], ) ).T @@ -337,19 +337,17 @@ def get_feniks_data( r"$uS_{MegaCam} - g_{HSC}$", r"$g_{HSC} - r_{HSC}$", r"$r_{HSC} - i_{HSC}$", + r"$i_{HSC} - NB816_{HSC}$", r"$i_{HSC} - z_{HSC}$", + r"$z_{HSC} - NB921_{HSC}$", r"$z_{HSC} - J_{UDS}$", - # r"$Y_{VIDEO} - J_{UDS}$", r"$J_{UDS} - H_{UDS}$", r"$H_{UDS} - K_{UDS}$", r"$uS_{MegaCam}$", r"$K_{UDS}$", - r"$i_{HSC} - NB816_{HSC}$", - r"$z_{HSC} - NB921_{HSC}$", r"$redshift$", ] - n_bins = 0 ############################################################################## # prepare 2D and 1D color spaces in coarse z-bins for fitting zbins = np.array( @@ -410,42 +408,70 @@ def get_feniks_data( # 2D (g - r, r - i) gr_ri = N_utils.get_colorcolor_space( - "Gr_ri", hsc_gr, hsc_ri, [1, 2, 2, 3], z_sel, fit=True + "Gr_ri", hsc_gr, hsc_ri, ["HSC_G", "HSC_R", "HSC_R", "HSC_I"], z_sel, fit=True ) # 1D (u - g | K) ug = N_utils.get_color_cond_space_list( - "Ug_condK", megacam_hsc_uSg, uds_K, [0, 1], 7, z_sel, cond_dmag=2, fit=True + "Ug_condK", + megacam_hsc_uSg, + uds_K, + ["MegaCam_uS", "HSC_G"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, ) # 1D (r − i | K) ri = N_utils.get_color_cond_space_list( - "Ri_condK", hsc_ri, uds_K, [2, 3], 7, z_sel, cond_dmag=2, fit=True + "Ri_condK", + hsc_ri, + uds_K, + ["HSC_R", "HSC_I"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, ) # 1D (i − z | K) iz = N_utils.get_color_cond_space_list( - "Iz_condK", hsc_iz, uds_K, [3, 4], 7, z_sel, cond_dmag=2, fit=True + "Iz_condK", + hsc_iz, + uds_K, + ["HSC_I", "HSC_Z"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, ) # 1D (J − H | K) jh = N_utils.get_color_cond_space_list( - "JH_condK", uds_JH, uds_K, [5, 6], 7, z_sel, cond_dmag=2, fit=True + "JH_condK", + uds_JH, + uds_K, + ["UDS_J", "UDS_H"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, ) # 2D (K, r - i) K_ri = N_utils.get_mag_color_space( - "K_ri", uds_K, hsc_ri, 7, [2, 3], z_sel, fit=True + "K_ri", uds_K, hsc_ri, "UDS_K", ["HSC_R", "HSC_I"], z_sel, fit=True ) # 2D (K, g - r) K_gr = N_utils.get_mag_color_space( - "K_gr", uds_K, hsc_gr, 7, [1, 2], z_sel, fit=True + "K_gr", uds_K, hsc_gr, "UDS_K", ["HSC_G", "HSC_R"], z_sel, fit=True ) # 2D (K, J - H) K_JH = N_utils.get_mag_color_space( - "K_JH", uds_K, uds_JH, 7, [5, 6], z_sel, fit=True + "K_JH", uds_K, uds_JH, "UDS_K", ["UDS_J", "UDS_H"], z_sel, fit=True ) z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, ri, iz, jh, K_ri, K_gr, K_JH) @@ -505,42 +531,84 @@ def get_feniks_data( # 2D (r - z, z - J) rz_zJ = N_utils.get_colorcolor_space( - "Rz_zJ", hsc_rz, hsc_uds_zJ, [2, 4, 4, 5], z_sel, fit=True + "Rz_zJ", + hsc_rz, + hsc_uds_zJ, + ["HSC_R", "HSC_Z", "HSC_Z", "UDS_J"], + z_sel, + fit=True, ) # 1D (u - g | K) ug = N_utils.get_color_cond_space_list( - "Ug_condK", megacam_hsc_uSg, uds_K, [0, 1], 7, z_sel, cond_dmag=2, fit=True + "Ug_condK", + megacam_hsc_uSg, + uds_K, + ["MegaCam_uS", "HSC_G"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, ) # 1D (r - z | K) rz = N_utils.get_color_cond_space_list( - "Rz_condK", hsc_rz, uds_K, [2, 4], 7, z_sel, cond_dmag=2, fit=True + "Rz_condK", + hsc_rz, + uds_K, + ["HSC_R", "HSC_Z"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, ) # 1D (J − H | K) jh = N_utils.get_color_cond_space_list( - "JH_condK", uds_JH, uds_K, [5, 6], 7, z_sel, cond_dmag=2, fit=True + "JH_condK", + uds_JH, + uds_K, + ["UDS_J", "UDS_H"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, ) # 2D (K, u - g) K_ug = N_utils.get_mag_color_space( - "K_ug", uds_K, megacam_hsc_uSg, 7, [0, 1], z_sel, fit=True + "K_ug", + uds_K, + megacam_hsc_uSg, + "UDS_K", + ["MegaCam_uS", "HSC_G"], + z_sel, + fit=True, ) # 2D (K, r - z) K_rz = N_utils.get_mag_color_space( - "K_rz", uds_K, hsc_rz, 7, [2, 4], z_sel, fit=True + "K_rz", uds_K, hsc_rz, "UDS_K", ["HSC_R", "HSC_Z"], z_sel, fit=True ) # 2D (i - NB816, g - r) i816_gr = N_utils.get_colorcolor_space( - "I816_gr", hsc_i816, hsc_gr, [3, 8, 1, 2], z_sel, fit=False + "I816_gr", + hsc_i816, + hsc_gr, + ["HSC_I", "NB0816", "HSC_G", "HSC_R"], + z_sel, + fit=False, ) # 2D (i - NB816, r - K) i816_rK = N_utils.get_colorcolor_space( - "I816_rK", hsc_i816, hsc_uds_rK, [3, 8, 2, 7], z_sel, fit=False + "I816_rK", + hsc_i816, + hsc_uds_rK, + ["HSC_I", "NB0816", "HSC_R", "UDS_K"], + z_sel, + fit=False, ) # 1D (i - NB816 | K) @@ -548,8 +616,8 @@ def get_feniks_data( "I816_condK", hsc_i816, uds_K, - [3, 8], - 7, + ["HSC_I", "NB0816"], + "UDS_K", z_sel, cond_dmag=2, fit=False, @@ -560,8 +628,8 @@ def get_feniks_data( "Z921_condK", hsc_z921, uds_K, - [4, 9], - 7, + ["HSC_Z", "NB0921"], + "UDS_K", z_sel, cond_dmag=2, fit=False, @@ -620,32 +688,64 @@ def get_feniks_data( # 2D (r - z, z - J) rz_zJ = N_utils.get_colorcolor_space( - "Rz_zJ", hsc_rz, hsc_uds_zJ, [2, 4, 4, 5], z_sel, fit=True + "Rz_zJ", + hsc_rz, + hsc_uds_zJ, + ["HSC_R", "HSC_Z", "HSC_Z", "UDS_J"], + z_sel, + fit=True, ) # 1D (u - g | K) ug = N_utils.get_color_cond_space_list( - "Ug_condK", megacam_hsc_uSg, uds_K, [0, 1], 7, z_sel, cond_dmag=2, fit=True + "Ug_condK", + megacam_hsc_uSg, + uds_K, + ["MegaCam_uS", "HSC_G"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, ) # 1D (r - z | K) rz = N_utils.get_color_cond_space_list( - "Rz_condK", hsc_rz, uds_K, [2, 4], 7, z_sel, cond_dmag=2, fit=True + "Rz_condK", + hsc_rz, + uds_K, + ["HSC_R", "HSC_Z"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, ) # 1D (J − H | K) jh = N_utils.get_color_cond_space_list( - "JH_condK", uds_JH, uds_K, [5, 6], 7, z_sel, cond_dmag=2, fit=True + "JH_condK", + uds_JH, + uds_K, + ["UDS_J", "UDS_H"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, ) # 2D (K, u - g) K_ug = N_utils.get_mag_color_space( - "K_ug", uds_K, megacam_hsc_uSg, 7, [0, 1], z_sel, fit=True + "K_ug", + uds_K, + megacam_hsc_uSg, + "UDS_K", + ["MegaCam_uS", "HSC_G"], + z_sel, + fit=True, ) # 2D (K, r - z) K_rz = N_utils.get_mag_color_space( - "K_rz", uds_K, hsc_rz, 7, [2, 4], z_sel, fit=True + "K_rz", uds_K, hsc_rz, "UDS_K", ["HSC_R", "HSC_Z"], z_sel, fit=True ) z2b = Z2b(z_min, z_max, lc_data, rz_zJ, ug, rz, jh, K_ug, K_rz) @@ -701,42 +801,79 @@ def get_feniks_data( # 2D (z - J, J - H) zJ_JH = N_utils.get_colorcolor_space( - "ZJ_JH", hsc_uds_zJ, uds_JH, [4, 5, 5, 6], z_sel, fit=True + "ZJ_JH", + hsc_uds_zJ, + uds_JH, + ["HSC_Z", "UDS_J", "UDS_J", "UDS_H"], + z_sel, + fit=True, ) # 2D (u - g, g - r) ug_gr = N_utils.get_colorcolor_space( - "Ug_gr", megacam_hsc_uSg, hsc_gr, [0, 1, 1, 2], z_sel, fit=True + "Ug_gr", + megacam_hsc_uSg, + hsc_gr, + ["MegaCam_uS", "HSC_G", "HSC_G", "HSC_R"], + z_sel, + fit=True, ) # 1D (u - g | K) ug = N_utils.get_color_cond_space_list( - "Ug_condK", megacam_hsc_uSg, uds_K, [0, 1], 7, z_sel, cond_dmag=4, fit=True + "Ug_condK", + megacam_hsc_uSg, + uds_K, + ["MegaCam_uS", "HSC_G"], + "UDS_K", + z_sel, + cond_dmag=4, + fit=True, ) # 1D (g - r | K) gr = N_utils.get_color_cond_space_list( - "Gr_condK", hsc_gr, uds_K, [1, 2], 7, z_sel, cond_dmag=4, fit=True + "Gr_condK", + hsc_gr, + uds_K, + ["HSC_G", "HSC_R"], + "UDS_K", + z_sel, + cond_dmag=4, + fit=True, ) # 1D (J − H | K) jh = N_utils.get_color_cond_space_list( - "JH_condK", uds_JH, uds_K, [5, 6], 7, z_sel, cond_dmag=4, fit=True + "JH_condK", + uds_JH, + uds_K, + ["UDS_J", "UDS_H"], + "UDS_K", + z_sel, + cond_dmag=4, + fit=True, ) # 2D (K, u - g) K_ug = N_utils.get_mag_color_space( - "K_ug", uds_K, megacam_hsc_uSg, 7, [0, 1], z_sel, fit=True + "K_ug", + uds_K, + megacam_hsc_uSg, + "UDS_K", + ["MegaCam_uS", "HSC_G"], + z_sel, + fit=True, ) # 2D (K, g - r) K_gr = N_utils.get_mag_color_space( - "K_gr", uds_K, hsc_gr, 7, [1, 2], z_sel, fit=True + "K_gr", uds_K, hsc_gr, "UDS_K", ["HSC_G", "HSC_R"], z_sel, fit=True ) # 2D (K, J - H) K_JH = N_utils.get_mag_color_space( - "K_JH", uds_K, uds_JH, 7, [5, 6], z_sel, fit=True + "K_JH", uds_K, uds_JH, "UDS_K", ["UDS_J", "UDS_H"], z_sel, fit=True ) z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr, jh, K_ug, K_gr, K_JH) @@ -760,15 +897,6 @@ def get_feniks_data( ["z_min", "z_max", "lc_data", "u", "g", "r", "i", "z", "J", "H", "K"], ) - U = namedtuple("U", AppMagFunc._fields) - G = namedtuple("G", AppMagFunc._fields) - R = namedtuple("R", AppMagFunc._fields) - I = namedtuple("I", AppMagFunc._fields) # noqa: E741 - Z = namedtuple("Z", AppMagFunc._fields) - J = namedtuple("J", AppMagFunc._fields) - H = namedtuple("H", AppMagFunc._fields) - K = namedtuple("K", AppMagFunc._fields) - app_mag_funcs = [] for zbin in range(0, len(fine_zbins)): z_min = fine_zbins[zbin][0] @@ -795,52 +923,28 @@ def get_feniks_data( z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) # 1D (u) - mag_idx_u = 0 - N_1d_u, sig_u, bin_lo_u, bin_hi_u = N_utils.get_N_1d(megacam_uS[z_sel]) - u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u, True) - n_bins += bin_lo_u.size + u = N_utils.get_mag_space("U", megacam_uS, "MegaCam_uS", z_sel, fit=True) # 1D (g) - mag_idx_g = 1 - N_1d_g, sig_g, bin_lo_g, bin_hi_g = N_utils.get_N_1d(hsc_g[z_sel]) - g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g, True) - n_bins += bin_lo_g.size + g = N_utils.get_mag_space("G", hsc_g, "HSC_G", z_sel, fit=True) # 1D (r) - mag_idx_r = 2 - N_1d_r, sig_r, bin_lo_r, bin_hi_r = N_utils.get_N_1d(hsc_r[z_sel]) - r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r, True) - n_bins += bin_lo_r.size + r = N_utils.get_mag_space("R", hsc_r, "HSC_R", z_sel, fit=True) # 1D (i) - mag_idx_i = 3 - N_1d_i, sig_i, bin_lo_i, bin_hi_i = N_utils.get_N_1d(hsc_i[z_sel]) - i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i, True) - n_bins += bin_lo_i.size + i = N_utils.get_mag_space("I", hsc_i, "HSC_I", z_sel, fit=True) # 1D (z) - mag_idx_z = 4 - N_1d_z, sig_z, bin_lo_z, bin_hi_z = N_utils.get_N_1d(hsc_z[z_sel]) - z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z, True) - n_bins += bin_lo_z.size + z = N_utils.get_mag_space("Z", hsc_z, "HSC_Z", z_sel, fit=True) # 1D (J) - mag_idx_j = 5 - N_1d_j, sig_j, bin_lo_j, bin_hi_j = N_utils.get_N_1d(uds_J[z_sel]) - j = J(mag_idx_j, sig_j, bin_lo_j, bin_hi_j, N_1d_j, True) - n_bins += bin_lo_j.size + j = N_utils.get_mag_space("J", uds_J, "UDS_J", z_sel, fit=True) # 1D (H) - mag_idx_h = 6 - N_1d_h, sig_h, bin_lo_h, bin_hi_h = N_utils.get_N_1d(uds_H[z_sel]) - h = H(mag_idx_h, sig_h, bin_lo_h, bin_hi_h, N_1d_h, True) - n_bins += bin_lo_h.size + h = N_utils.get_mag_space("H", uds_H, "UDS_H", z_sel, fit=True) # 1D (K) - mag_idx_k = 7 - N_1d_k, sig_k, bin_lo_k, bin_hi_k = N_utils.get_N_1d(uds_K[z_sel]) - k = K(mag_idx_k, sig_k, bin_lo_k, bin_hi_k, N_1d_k, True) - n_bins += bin_lo_k.size + k = N_utils.get_mag_space("K", uds_K, "UDS_K", z_sel, fit=True) app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, g, r, i, z, j, h, k)) @@ -879,18 +983,47 @@ def get_feniks_data( ) -FeniksFilters = namedtuple( - "FeniksFilters", - [ - "MegaCam_uS", - "HSC_G", - "HSC_R", - "HSC_I", - "HSC_Z", - "UDS_J", - "UDS_H", - "UDS_K", - "NB0816", - "NB0921", - ], -) +def get_feniks_fitting_data( + feniks_drn, + ran_key, + ssp_data, + phot=PHOT, + zout=ZOUT, + num_halos_coarse_zbins=250, + num_halos_fine_zbins=150, + add_random_rows_for_testing=False, +): + feniks = get_feniks_data( + feniks_drn, + ran_key, + ssp_data, + phot=phot, + zout=zout, + num_halos_coarse_zbins=num_halos_coarse_zbins, + num_halos_fine_zbins=num_halos_fine_zbins, + add_random_rows_for_testing=add_random_rows_for_testing, + ) + remove = {"dataset_dim_labels", "mags_labels"} + FeniksFitting = namedtuple("Feniks", [f for f in feniks._fields if f not in remove]) + feniks_fitting_data = FeniksFitting( + **{f: getattr(feniks, f) for f in FeniksFitting._fields} + ) + + return feniks_fitting_data + + +# FeniksFilters = namedtuple( +# "FeniksFilters", +# [ +# "MegaCam_uS", +# "HSC_G", +# "HSC_R", +# "HSC_I", +# "NB0816", +# "HSC_Z", +# "NB0921", +# "UDS_J", +# "UDS_H", +# "UDS_K", +# ], +# ) diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index 776a2529..e1306602 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -76,3 +76,19 @@ "AppMagFunc", ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data", "fit"], ) + +FeniksFilters = namedtuple( + "FeniksFilters", + [ + "MegaCam_uS", + "HSC_G", + "HSC_R", + "HSC_I", + "NB0816", + "HSC_Z", + "NB0921", + "UDS_J", + "UDS_H", + "UDS_K", + ], +) diff --git a/diffhtwo/experimental/loss_kernels/tests/test_phot_loss.py b/diffhtwo/experimental/loss_kernels/tests/test_phot_loss.py index 716cbff3..affcc0b7 100644 --- a/diffhtwo/experimental/loss_kernels/tests/test_phot_loss.py +++ b/diffhtwo/experimental/loss_kernels/tests/test_phot_loss.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from diffsky.param_utils.diffsky_param_wrapper_merging import DEFAULT_PARAM_COLLECTION from jax import random as jran @@ -6,6 +7,9 @@ from ..phot_loss import _loss_phot_kern, get_phot_loss +@pytest.mark.skip( + reason="LH dimensions need to be fixed in load_feniks before activating this test again" +) def test_phot_loss(feniks_single_z_data): feniks_meta_data, feniks_fitting_data = feniks_single_z_data diff --git a/diffhtwo/experimental/optimizers/tests/test_Np_specphot_opt.py b/diffhtwo/experimental/optimizers/tests/test_Np_specphot_opt.py index 40daef72..d55c338c 100644 --- a/diffhtwo/experimental/optimizers/tests/test_Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/tests/test_Np_specphot_opt.py @@ -8,16 +8,15 @@ from ... import param_utils as pu from ..Np_specphot_opt import ( - _loss_and_grad_phot_kern_multi_z, - _loss_and_grad_sdss_feniks_hizels, - fit_N_multi_z, - fit_sdss_feniks_hizels, + _loss_and_grad_emline_kern_multi_line_multi_z, + _loss_and_grad_phot_kern_2d_multiz, + fit_feniks_hizels, + fit_N_phot_2d, ) @pytest.fixture(scope="module") -def multistep_grads(ran_key, feniks_multi_z_data): - feniks_meta_data, feniks_fitting_data = feniks_multi_z_data +def multistep_grads(ran_key, feniks_fitting_data): n_steps = 10 step_size = 0.1 @@ -25,14 +24,13 @@ def multistep_grads(ran_key, feniks_multi_z_data): opt_init, opt_update, get_params = jax_opt.adam(step_size) other = ( ran_key, - feniks_meta_data, feniks_fitting_data, ) opt_state = opt_init(u_theta_init) multistep_grads = [] for i in range(n_steps): u_theta = get_params(opt_state) - loss, grads = _loss_and_grad_phot_kern_multi_z(u_theta, *other) + loss, grads = _loss_and_grad_phot_kern_2d_multiz(u_theta, *other) multistep_grads.append(grads) opt_state = opt_update(i, grads, opt_state) return multistep_grads, n_steps @@ -93,14 +91,11 @@ def test_all_diffsky_u_param_grads_are_nonzero( ), f"These {diffsky_param_names[param_idx]} have exactly zero grads: {zero_grad_params}" -def test_phot_opt(ran_key, feniks_multi_z_data): - feniks_meta_data, feniks_fitting_data = feniks_multi_z_data - +def test_phot_opt(ran_key, feniks_fitting_data): u_theta = pu.get_u_theta_from_param_collection(DEFAULT_PARAM_COLLECTION) - loss, grads = _loss_and_grad_phot_kern_multi_z( + loss, grads = _loss_and_grad_phot_kern_2d_multiz( u_theta, ran_key, - feniks_meta_data, feniks_fitting_data, ) assert np.isfinite(loss) @@ -108,11 +103,10 @@ def test_phot_opt(ran_key, feniks_multi_z_data): assert np.isfinite(grads[g]).all() trainable_params = pu.get_trainable_params(fit_type="all") - loss_hist, u_theta_fit = fit_N_multi_z( + loss_hist, u_theta_fit = fit_N_phot_2d( u_theta, trainable_params, ran_key, - feniks_meta_data, feniks_fitting_data, n_steps=2, step_size=0.1, @@ -126,38 +120,34 @@ def test_phot_opt(ran_key, feniks_multi_z_data): def test_specphot_opt( - ran_key, fake_subset_ssp_data, feniks_multi_z_data, hizels_fitting_data + ran_key, fake_subset_ssp_data, feniks_fitting_data, hizels_fitting_data ): ssp_data, emline_wave_aa = fake_subset_ssp_data - feniks_meta_data, feniks_fitting_data = feniks_multi_z_data - - # duplicate feniks data for sdss data - sdss_meta_data, sdss_fitting_data = feniks_meta_data, feniks_fitting_data - u_theta = pu.get_u_theta_from_param_collection(DEFAULT_PARAM_COLLECTION) - loss, grads = _loss_and_grad_sdss_feniks_hizels( + loss_phot, grad_phot = _loss_and_grad_phot_kern_2d_multiz( u_theta, ran_key, - sdss_meta_data, - sdss_fitting_data, - feniks_meta_data, feniks_fitting_data, - hizels_fitting_data, ) + assert np.isfinite(loss_phot) + for g in range(len(grad_phot)): + assert np.isfinite(grad_phot[g]).all() - assert np.isfinite(loss) - for g in range(len(grads)): - assert np.isfinite(grads[g]).all() + loss_emline, grad_emline = _loss_and_grad_emline_kern_multi_line_multi_z( + u_theta, + ran_key, + hizels_fitting_data, + ) + assert np.isfinite(loss_emline) + for g in range(len(grad_emline)): + assert np.isfinite(grad_emline[g]).all() trainable_params = pu.get_trainable_params(fit_type="all") - loss_hist, u_theta_fit = fit_sdss_feniks_hizels( + loss_hist, u_theta_fit = fit_feniks_hizels( u_theta, trainable_params, ran_key, - sdss_meta_data, - sdss_fitting_data, - feniks_meta_data, feniks_fitting_data, hizels_fitting_data, n_steps=2, diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 326141fb..ad1cc0ca 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run183 -model_nickname: run183_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run183/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run121 +model_nickname: run121_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run121/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -11,12 +11,12 @@ dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_sdss: True +plot_sdss: False plot_feniks: True plot_hizels: False plots: - num_halos : 3000 + num_halos : 1000 plot_color_contours: True plot_app_mag_funcs: True plot_color_pdfs: True diff --git a/scripts/fit_feniks.py b/scripts/fit_feniks.py index 9f096089..e05914a2 100644 --- a/scripts/fit_feniks.py +++ b/scripts/fit_feniks.py @@ -1,7 +1,6 @@ import argparse import os import time -from collections import namedtuple from datetime import datetime import jax @@ -94,20 +93,13 @@ for epoch in range(0, cfg["epoch"]["n_it"]): print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') - feniks = load_feniks.get_feniks_data( + feniks_fitting_data = load_feniks.get_feniks_fitting_data( feniks_drn, ran_key, ssp_data, num_halos_coarse_zbins=cfg["feniks"]["num_halos_coarse_zbins"], num_halos_fine_zbins=cfg["feniks"]["num_halos_fine_zbins"], ) - remove = {"dataset_dim_labels", "mags_labels"} - FeniksFitting = namedtuple( - "Feniks", [f for f in feniks._fields if f not in remove] - ) - feniks_fitting_data = FeniksFitting( - **{f: getattr(feniks, f) for f in FeniksFitting._fields} - ) loss_hist, u_theta_fit = Np_specphot_opt.fit_N_phot_2d( u_theta_fit, From 8c94ac39ea3cb49dccac4ab80dbd5157f0c4ce09 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 13 Jun 2026 19:00:42 -0500 Subject: [PATCH 54/57] CMD fit false for feniks --- .../experimental/data_loaders/load_feniks.py | 21 +++++++++---------- scripts/config_diagnostics.yaml | 6 +++--- scripts/generate_diagnostic_plots.py | 3 +++ 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 17f175fb..684804f3 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -14,7 +14,6 @@ FENIKS_MAGK_THRESH, FENIKS_Z_MAX, FENIKS_Z_MIN, - AppMagFunc, Dataset, FeniksFilters, FilterInfo, @@ -461,17 +460,17 @@ def get_feniks_data( # 2D (K, r - i) K_ri = N_utils.get_mag_color_space( - "K_ri", uds_K, hsc_ri, "UDS_K", ["HSC_R", "HSC_I"], z_sel, fit=True + "K_ri", uds_K, hsc_ri, "UDS_K", ["HSC_R", "HSC_I"], z_sel, fit=False ) # 2D (K, g - r) K_gr = N_utils.get_mag_color_space( - "K_gr", uds_K, hsc_gr, "UDS_K", ["HSC_G", "HSC_R"], z_sel, fit=True + "K_gr", uds_K, hsc_gr, "UDS_K", ["HSC_G", "HSC_R"], z_sel, fit=False ) # 2D (K, J - H) K_JH = N_utils.get_mag_color_space( - "K_JH", uds_K, uds_JH, "UDS_K", ["UDS_J", "UDS_H"], z_sel, fit=True + "K_JH", uds_K, uds_JH, "UDS_K", ["UDS_J", "UDS_H"], z_sel, fit=False ) z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, ri, iz, jh, K_ri, K_gr, K_JH) @@ -583,12 +582,12 @@ def get_feniks_data( "UDS_K", ["MegaCam_uS", "HSC_G"], z_sel, - fit=True, + fit=False, ) # 2D (K, r - z) K_rz = N_utils.get_mag_color_space( - "K_rz", uds_K, hsc_rz, "UDS_K", ["HSC_R", "HSC_Z"], z_sel, fit=True + "K_rz", uds_K, hsc_rz, "UDS_K", ["HSC_R", "HSC_Z"], z_sel, fit=False ) # 2D (i - NB816, g - r) @@ -740,12 +739,12 @@ def get_feniks_data( "UDS_K", ["MegaCam_uS", "HSC_G"], z_sel, - fit=True, + fit=False, ) # 2D (K, r - z) K_rz = N_utils.get_mag_color_space( - "K_rz", uds_K, hsc_rz, "UDS_K", ["HSC_R", "HSC_Z"], z_sel, fit=True + "K_rz", uds_K, hsc_rz, "UDS_K", ["HSC_R", "HSC_Z"], z_sel, fit=False ) z2b = Z2b(z_min, z_max, lc_data, rz_zJ, ug, rz, jh, K_ug, K_rz) @@ -863,17 +862,17 @@ def get_feniks_data( "UDS_K", ["MegaCam_uS", "HSC_G"], z_sel, - fit=True, + fit=False, ) # 2D (K, g - r) K_gr = N_utils.get_mag_color_space( - "K_gr", uds_K, hsc_gr, "UDS_K", ["HSC_G", "HSC_R"], z_sel, fit=True + "K_gr", uds_K, hsc_gr, "UDS_K", ["HSC_G", "HSC_R"], z_sel, fit=False ) # 2D (K, J - H) K_JH = N_utils.get_mag_color_space( - "K_JH", uds_K, uds_JH, "UDS_K", ["UDS_J", "UDS_H"], z_sel, fit=True + "K_JH", uds_K, uds_JH, "UDS_K", ["UDS_J", "UDS_H"], z_sel, fit=False ) z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr, jh, K_ug, K_gr, K_JH) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index ad1cc0ca..c9e33a64 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run121 -model_nickname: run121_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run121/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run187 +model_nickname: run187_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run187/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index ffca208d..41b3167a 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -3,6 +3,7 @@ from pathlib import Path import jax.numpy as jnp +import matplotlib.pyplot as plt import numpy as np import yaml from diffsky.data_loaders.hacc_utils import lc_mock @@ -296,6 +297,7 @@ + str(FENIKS_Z_MAX), drn_out=fit_diagnostics_save_drn, ) + plt.close() if cfg["plots"]["plot_fburst_mh_z"]: print("Generating FENIKS lgfburst plot...") @@ -568,6 +570,7 @@ + str(SDSS_Z_MAX), drn_out=fit_diagnostics_save_drn, ) + plt.close() if cfg["plots"]["plot_fburst_mh_z"]: print("Generating SDSS lgfburst plot...") From a46530a18af69fb85d8cc9234115e446cf5960b2 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 13 Jun 2026 19:11:39 -0500 Subject: [PATCH 55/57] diffsky.git@v0.3.6 --- .github/workflows/tests.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index a311b4d3..4438b307 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -1,4 +1,4 @@ -name: Test Pull Request into main +name: Test Pull Request into main with diffsky@v0.3.6 on: push: @@ -41,7 +41,7 @@ jobs: scipy \ python-build pip uninstall diffsky --yes - pip install --no-deps git+https://github.com/ArgonneCPAC/diffsky.git + pip install --no-deps git+https://github.com/ArgonneCPAC/diffsky.git@v0.3.6 python -m pip install --no-build-isolation --no-deps -e . From 613b2c6684f085ee8479e0716ee919a6dc0ccff3 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 13 Jun 2026 19:23:26 -0500 Subject: [PATCH 56/57] remove emline_luminosity_pop --- .../diagnostics/plot_mag_color_1d_hist.py | 1027 ----------------- .../experimental/emline_luminosity_pop.py | 246 ---- .../optimizers/emline_luminosity_opt.py | 166 --- .../tests/test_emline_luminosity_opt.py | 120 -- .../tests/test_emline_luminosity_pop.py | 69 -- .../experimental/tests/test_line_phot_kern.py | 89 -- 6 files changed, 1717 deletions(-) delete mode 100644 diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py delete mode 100644 diffhtwo/experimental/emline_luminosity_pop.py delete mode 100644 diffhtwo/experimental/optimizers/emline_luminosity_opt.py delete mode 100644 diffhtwo/experimental/optimizers/tests/test_emline_luminosity_opt.py delete mode 100644 diffhtwo/experimental/tests/test_emline_luminosity_pop.py diff --git a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py deleted file mode 100644 index af5c2fec..00000000 --- a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py +++ /dev/null @@ -1,1027 +0,0 @@ -import corner -import jax.numpy as jnp -import numpy as np - -# from diffsky import diffndhist -from diffsky.experimental import lc_phot_kern -from diffsky.experimental import mc_lightcone_halos as mclh -from diffsky.experimental import precompute_ssp_phot as psspp -from diffstar.defaults import FB, T_TABLE_MIN -from dsps.cosmology import flat_wcdm -from dsps.cosmology.defaults import DEFAULT_COSMOLOGY -from jax import random as jran -from lc_utils import zbin_area, zbin_volume - -blue = "#1E90FF" # DodgerBlue -orange = "#FF8C00" # DarkOrange -# blue = "#4169E1" # RoyalBlue -# orange = "#D2691E" # Chocolate -# blue = "#00BFFF" # DeepSkyBlue -# orange = "#FFA500" # Orange - -color1 = orange -color2 = "k" -color_data = blue - -alpha1 = 1.0 -alpha2 = 0.7 -alpha_data = 0.5 - - -lw = 1.5 -fontsize = 24 -labelsize = 20 -legend_fontsize = 30 - - -try: - import matplotlib as mpl - from matplotlib import pyplot as plt - from matplotlib.lines import Line2D - - plt.rc("font", family="serif", serif=["Times New Roman"]) - - HAS_MATPLOTLIB = True -except ImportError: - HAS_MATPLOTLIB = False - - -mpl.rcParams["axes.linewidth"] = 2 - - -def plot_n_mag( - diffstarpop_params1, - spspop_params1, - ssp_err_pop_params1, - label1, - tcurves, - mag_thresh_column, - mag_thresh, - frac_cat, - dimension_labels, - ran_key, - zmins, - zmaxs, - ssp_data, - mzr_params, - scatter_params, - suptitle, - zbin_titles, - savedir, - dataset_mags=None, - n_bands=None, - data_sky_area_degsq=None, - diffstarpop_params2=None, - spspop_params2=None, - ssp_err_pop_params2=None, - label2=None, - dmag=0.1, - lgmp_min=10.0, - lgmp_max=15.0, - lc_vol_mpc3=7e4, - cosmo_params=DEFAULT_COSMOLOGY, - fb=FB, -): - # Plot 1D histograms - n_bands = len(tcurves) - n_zbins = len(zmins) - - fig_width = 3.0 * n_bands - fig_height = 3.0 * n_zbins - fig, ax = plt.subplots( - n_zbins, - n_bands, - figsize=(fig_width, fig_height), - ) - - fig.subplots_adjust( - left=0.065, hspace=0, top=0.95, right=0.99, bottom=0.05, wspace=0.0 - ) - fig.suptitle(suptitle, fontsize=32) - - fig_offset, ax_offset = plt.subplots( - n_zbins, - n_bands, - figsize=(fig_width, fig_height), - ) - - fig_offset.subplots_adjust( - left=0.065, hspace=0, top=0.95, right=0.99, bottom=0.05, wspace=0.0 - ) - fig_offset.suptitle(suptitle, fontsize=32) - - dataset_mags_z1 = np.array(dataset_mags[0]) - - for z in range(0, n_zbins): - zmin = zmins[z] - zmax = zmaxs[z] - dataset_mags_z = np.array(dataset_mags[z]) - - t = int(n_bands / 2) - ax[z, t].set_title(zbin_titles[z], y=0.85, fontsize=labelsize) - - """mc lightcone""" - ran_key, lc_key = jran.split(ran_key, 2) - sky_area_degsq = zbin_area(lc_vol_mpc3, zlow=zmin, zhigh=zmax).value - lc_args = (lc_key, lgmp_min, zmin, zmax, sky_area_degsq) - lc_halopop = mclh.mc_lightcone_host_halo_diffmah( - *lc_args, cosmo_params=cosmo_params, lgmp_max=lgmp_max - ) - - if data_sky_area_degsq is not None: - data_vol_mpc3 = zbin_volume( - data_sky_area_degsq, zlow=zmin, zhigh=zmax - ).value - - n_z_phot_table = 33 - - if (zmin < 0.24) & (zmax > 0.24): - nb_z = jnp.array([0.2445706, 0.40185568]) - nb816_zspan = np.linspace(nb_z[0] - 0.02, nb_z[0] + 0.02, 11) - nb921_zspan = np.linspace(nb_z[1] - 0.02, nb_z[1] + 0.02, 11) - z1_zspan = np.linspace(0.2, 0.5, 11) - z_phot_table = np.concatenate((nb816_zspan, nb921_zspan, z1_zspan)) - z_phot_table.sort() - else: - z_phot_table = jnp.linspace(zmin, zmax, n_z_phot_table) - - t_0 = flat_wcdm.age_at_z0(*DEFAULT_COSMOLOGY) - lgt0 = jnp.log10(t_0) - t_table = jnp.linspace(T_TABLE_MIN, 10**lgt0, 100) - - precomputed_ssp_mag_table = psspp.get_precompute_ssp_mag_redshift_table( - tcurves, ssp_data, z_phot_table, DEFAULT_COSMOLOGY - ) - - wave_eff_table = lc_phot_kern.get_wave_eff_table(z_phot_table, tcurves) - - ran_key, phot_key1 = jran.split(ran_key, 2) - phot_args1 = ( - phot_key1, - lc_halopop["z_obs"], - lc_halopop["t_obs"], - lc_halopop["mah_params"], - lc_halopop["logmp0"], - t_table, - ssp_data, - precomputed_ssp_mag_table, - z_phot_table, - wave_eff_table, - diffstarpop_params1, - mzr_params, - spspop_params1, - scatter_params, - ssp_err_pop_params1, - cosmo_params, - fb, - ) - - lc_phot1 = lc_phot_kern.multiband_lc_phot_kern(*phot_args1) - if n_bands is None: - num_halos, n_bands = lc_phot1.obs_mags_q.shape - - # set weights=0 for mag > mag_thresh for the band indicated by mag_thresh_column - obs_mag_q1 = lc_phot1.obs_mags_q[:, mag_thresh_column] - obs_mag_smooth_ms1 = lc_phot1.obs_mags_smooth_ms[:, mag_thresh_column] - obs_mag_bursty_ms1 = lc_phot1.obs_mags_bursty_ms[:, mag_thresh_column] - - lc_phot_weights_q1 = jnp.where( - obs_mag_q1 < mag_thresh, - lc_phot1.weights_q, - jnp.zeros_like(lc_phot1.weights_q), - ) - lc_phot_weights_smooth_ms1 = jnp.where( - obs_mag_smooth_ms1 < mag_thresh, - lc_phot1.weights_smooth_ms, - jnp.zeros_like(lc_phot1.weights_smooth_ms), - ) - lc_phot_weights_bursty_ms1 = jnp.where( - obs_mag_bursty_ms1 < mag_thresh, - lc_phot1.weights_bursty_ms, - jnp.zeros_like(lc_phot1.weights_bursty_ms), - ) - N_weights1 = np.concatenate( - [ - lc_phot_weights_q1 * frac_cat, - lc_phot_weights_smooth_ms1 * frac_cat, - lc_phot_weights_bursty_ms1 * frac_cat, - ] - ) - - if diffstarpop_params2 is not None: - ran_key, phot_key2 = jran.split(ran_key, 2) - phot_args2 = ( - phot_key2, - lc_halopop["z_obs"], - lc_halopop["t_obs"], - lc_halopop["mah_params"], - lc_halopop["logmp0"], - t_table, - ssp_data, - precomputed_ssp_mag_table, - z_phot_table, - wave_eff_table, - diffstarpop_params2, - mzr_params, - spspop_params2, - scatter_params, - ssp_err_pop_params2, - cosmo_params, - fb, - ) - - lc_phot2 = lc_phot_kern.multiband_lc_phot_kern(*phot_args2) - - # set weights=0 for mag > mag_thresh for the band indicated by mag_thresh_column - obs_mag_q2 = lc_phot2.obs_mags_q[:, mag_thresh_column] - obs_mag_smooth_ms2 = lc_phot2.obs_mags_smooth_ms[:, mag_thresh_column] - obs_mag_bursty_ms2 = lc_phot2.obs_mags_bursty_ms[:, mag_thresh_column] - - lc_phot_weights_q2 = jnp.where( - obs_mag_q2 < mag_thresh, - lc_phot2.weights_q, - jnp.zeros_like(lc_phot2.weights_q), - ) - lc_phot_weights_smooth_ms2 = jnp.where( - obs_mag_smooth_ms2 < mag_thresh, - lc_phot2.weights_smooth_ms, - jnp.zeros_like(lc_phot2.weights_smooth_ms), - ) - lc_phot_weights_bursty_ms2 = jnp.where( - obs_mag_bursty_ms2 < mag_thresh, - lc_phot2.weights_bursty_ms, - jnp.zeros_like(lc_phot2.weights_bursty_ms), - ) - N_weights2 = np.concatenate( - [ - lc_phot_weights_q2 * frac_cat, - lc_phot_weights_smooth_ms2 * frac_cat, - lc_phot_weights_bursty_ms2 * frac_cat, - ] - ) - - for i in range(0, n_bands): - sigma = np.std(dataset_mags_z1[:, i]) - lower_limit = np.mean(dataset_mags_z1[:, i]) - (4 * sigma) - upper_limit = np.mean(dataset_mags_z1[:, i]) + (4 * sigma) - if i == n_bands - 1: - upper_limit = mag_thresh - mag_bin_edges = np.arange( - lower_limit, - upper_limit, - dmag, - ) - ax[z, i].set_xlim(lower_limit, upper_limit) - ax_offset[z, i].set_xlim(lower_limit, upper_limit) - - # model 1 - lc_phot1_obs_mags = np.concatenate( - [ - lc_phot1.obs_mags_q[:, i], - lc_phot1.obs_mags_smooth_ms[:, i], - lc_phot1.obs_mags_bursty_ms[:, i], - ] - ) - lc_phot1_hist = ax[z, i].hist( - lc_phot1_obs_mags, - weights=N_weights1 * (1 / lc_vol_mpc3), - bins=mag_bin_edges, - histtype="step", - color=color1, - alpha=alpha1, - label=label1, - lw=lw + 1, - ) - - # model 2 - if diffstarpop_params2 is not None: - lc_phot2_obs_mags = np.concatenate( - [ - lc_phot2.obs_mags_q[:, i], - lc_phot2.obs_mags_smooth_ms[:, i], - lc_phot2.obs_mags_bursty_ms[:, i], - ] - ) - - ax[z, i].hist( - lc_phot2_obs_mags, - weights=N_weights2 * (1 / lc_vol_mpc3), - bins=mag_bin_edges, - histtype="step", - color=color2, - alpha=alpha2, - lw=lw, - label=label2, - ) - - # data - data_hist = ax[z, i].hist( - dataset_mags_z[:, i], - weights=np.ones_like(dataset_mags_z[:, i]) * (1 / data_vol_mpc3), - bins=mag_bin_edges, - color=color_data, - lw=lw, - alpha=alpha_data, - label="FENIKS-UDS", - ) - - """ax_offset""" - mag_bin_centers = (mag_bin_edges[1:] + mag_bin_edges[:-1]) / 2 - offset = data_hist[0] / lc_phot1_hist[0] - ax_offset[z, i].plot(mag_bin_centers, offset, lw=2, color="k") - ax_offset[z, i].set_ylim(0.09, 10.1) - ax_offset[z, i].set_yticks([0.1, 0.2, 0.5, 1, 2, 5, 10]) - ax_offset[z, i].set_yscale("log") - ax_offset[z, i].set_xlabel(dimension_labels[i], fontsize=fontsize) - ax_offset[z, i].tick_params( - axis="both", direction="in", labelsize=labelsize - ) - - ax[z, i].set_yscale("log") - ax[z, i].set_xlabel(dimension_labels[i], fontsize=fontsize) - ax[z, i].set_ylim(1e-6, 5e-3) - ax[z, i].tick_params(axis="both", direction="in", labelsize=labelsize) - - ax_offset_yticks = np.array([0.2, 0.5, 1, 2, 5]) - ax_offset[z, i].set_yticks(ax_offset_yticks) - ax_offset[z, i].axhspan( - ax_offset_yticks[1], ax_offset_yticks[3], color="orange", alpha=0.5 - ) - ax_offset[z, i].axhspan( - ax_offset_yticks[0], ax_offset_yticks[1], color="r", alpha=0.5 - ) - ax_offset[z, i].axhspan( - ax_offset_yticks[3], ax_offset_yticks[4], color="r", alpha=0.5 - ) - ax_offset[z, i].axhspan(0, ax_offset_yticks[0], color="r", alpha=0.8) - ax_offset[z, i].axhspan(ax_offset_yticks[4], 10, color="r", alpha=0.8) - - if i != 0: - ax[z, i].set_yticklabels([]) - ax_offset[z, i].set_yticklabels([]) - if z != n_zbins - 1: - ax[z, i].set_xticklabels([]) - ax_offset[z, i].set_xticklabels([]) - if i == 0: - ax_offset[z, i].set_yticklabels(["5x", "2x", "1x", "2x", "5x"]) - - ax[0, -1].legend( - framealpha=0.5, - loc="upper left", - bbox_to_anchor=(-2, 1.4), - ncols=3, - fontsize=legend_fontsize, - ) - fig.supylabel("\u03d5 [Mpc$^{-3}$]", fontsize=fontsize) - fig.savefig(savedir + "/mags_" + savedir.split("/")[-1] + ".pdf") - - fig_offset.supylabel("n$_{FENIKS}$ / n$_{diffsky}$", fontsize=fontsize) - fig_offset.savefig(savedir + "/mags_offsets_" + savedir.split("/")[-1] + ".pdf") - - plt.show() - - -def plot_n( - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - label1, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - dimension_labels, - ran_key, - zmins, - zmaxs, - ssp_data, - mzr_params, - scatter_params, - suptitle, - zbin_titles, - savedir, - dataset_colors_mag=None, - data_sky_area_degsq=None, - diffstarpop_params2=None, - spspop_params2=None, - ssperrpop_params2=None, - label2=None, - lh_centroids=None, - lg_n_data_err_lh=None, - lg_n_thresh=None, - dmag=0.1, - lgmp_min=10.0, - lgmp_max=15.0, - lc_vol_mpc3=7e4, - cosmo_params=DEFAULT_COSMOLOGY, - fb=FB, - n_z_phot_table=15, -): - # Plot 1D histograms - n_bands = len(tcurves) - n_dims = n_bands - 1 + len(mag_columns) - n_zbins = len(zmins) - - fig_width = 3.00 * n_dims - fig_height = 3.25 * n_zbins - fig, ax = plt.subplots( - n_zbins, - n_dims, - figsize=(fig_width, fig_height), - ) - - fig.subplots_adjust( - left=0.065, hspace=0, top=0.95, right=0.99, bottom=0.05, wspace=0.0 - ) - fig.suptitle(suptitle, fontsize=32) - - fig_offset, ax_offset = plt.subplots( - n_zbins, - n_dims, - figsize=(fig_width, fig_height), - ) - - fig_offset.subplots_adjust( - left=0.065, hspace=0, top=0.95, right=0.99, bottom=0.05, wspace=0.0 - ) - fig_offset.suptitle(suptitle, fontsize=32) - - dataset_colors_mag_z1 = np.array(dataset_colors_mag[0]) - for z in range(0, n_zbins): - zmin = zmins[z] - zmax = zmaxs[z] - dataset_colors_mag_z = np.array(dataset_colors_mag[z]) - - t = int(n_dims / 2) - ax[z, t].set_title(zbin_titles[z], y=0.85, fontsize=labelsize) - - if data_sky_area_degsq is not None: - data_vol_mpc3 = zbin_volume( - data_sky_area_degsq, zlow=zmin, zhigh=zmax - ).value - - if diffstarpop_params2 is not None: - ( - obs_colors_mag1, - N_weights1, - obs_colors_mag2, - N_weights2, - ) = get_model_colors_mag( - ran_key, - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - lc_vol_mpc3, - zmin, - zmax, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - ssp_data, - lgmp_min, - lgmp_max, - mzr_params, - scatter_params, - cosmo_params, - fb, - diffstarpop_params2=diffstarpop_params2, - spspop_params2=spspop_params2, - ssperrpop_params2=ssperrpop_params2, - ) - - else: - obs_colors_mag1, N_weights1 = get_model_colors_mag( - ran_key, - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - lc_vol_mpc3, - zmin, - zmax, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - ssp_data, - lgmp_min, - lgmp_max, - mzr_params, - scatter_params, - cosmo_params, - fb, - ) - - for i in range(0, n_dims): - sigma = np.std(dataset_colors_mag_z1[:, i]) - lower_limit = np.mean(dataset_colors_mag_z1[:, i]) - (4 * sigma) - upper_limit = np.mean(dataset_colors_mag_z1[:, i]) + (4 * sigma) - if i == n_dims - 1: - upper_limit = mag_thresh - bins = np.arange( - lower_limit, - upper_limit, - dmag, - ) - ax[z, i].set_xlim(lower_limit, upper_limit) - ax_offset[z, i].set_xlim(lower_limit, upper_limit) - - obs_colors_mag1_hist = ax[z, i].hist( - obs_colors_mag1[:, i], - weights=N_weights1 * (1 / lc_vol_mpc3), - bins=bins, - histtype="step", - color=color1, - alpha=alpha1, - lw=lw + 2, - label=label1, - ) - if diffstarpop_params2 is not None: - ax[z, i].hist( - obs_colors_mag2[:, i], - weights=N_weights2 * (1 / lc_vol_mpc3), - bins=bins, - histtype="step", - color=color2, - alpha=alpha2, - lw=lw, - label=label2, - ) - - # data - if dataset_colors_mag_z is not None: - dataset_colors_mag_hist = ax[z, i].hist( - dataset_colors_mag_z[:, i], - weights=np.ones_like(dataset_colors_mag_z[:, i]) - * (1 / data_vol_mpc3), - bins=bins, - color=color_data, - alpha=alpha_data, - lw=lw, - label="FENIKS-UDS", - ) - """ax_offset""" - bin_centers = (bins[1:] + bins[:-1]) / 2 - offset = dataset_colors_mag_hist[0] / obs_colors_mag1_hist[0] - ax_offset[z, i].plot(bin_centers, offset, lw=2, color="k") - ax_offset[z, i].set_ylim(0.09, 10.1) - ax_offset[z, i].set_yticks([0.1, 0.2, 0.5, 1, 2, 5, 10]) - ax_offset[z, i].set_yscale("log") - ax_offset[z, i].set_xlabel(dimension_labels[i], fontsize=fontsize) - ax_offset[z, i].tick_params( - axis="both", direction="in", labelsize=labelsize - ) - - ax[z, i].set_yscale("log") - ax[z, i].set_xlabel(dimension_labels[i], fontsize=fontsize) - ax[z, i].set_ylim(1e-6, 3e-2) - ax[z, i].tick_params(axis="both", direction="in", labelsize=labelsize) - - ax_offset_yticks = np.array([0.2, 0.5, 1, 2, 5]) - ax_offset[z, i].set_yticks(ax_offset_yticks) - ax_offset[z, i].axhspan( - ax_offset_yticks[1], ax_offset_yticks[3], color="orange", alpha=0.5 - ) - ax_offset[z, i].axhspan( - ax_offset_yticks[0], ax_offset_yticks[1], color="r", alpha=0.5 - ) - ax_offset[z, i].axhspan( - ax_offset_yticks[3], ax_offset_yticks[4], color="r", alpha=0.5 - ) - ax_offset[z, i].axhspan(0, ax_offset_yticks[0], color="r", alpha=0.8) - ax_offset[z, i].axhspan(ax_offset_yticks[4], 10, color="r", alpha=0.8) - - if i != 0: - ax[z, i].set_yticklabels([]) - ax_offset[z, i].set_yticklabels([]) - if z != n_zbins - 1: - ax[z, i].set_xticklabels([]) - ax_offset[z, i].set_xticklabels([]) - if i == 0: - ax_offset[z, i].set_yticklabels(["5x", "2x", "1x", "2x", "5x"]) - - ax[0, -1].legend( - framealpha=0.5, - loc="upper left", - bbox_to_anchor=(-2, 1.4), - ncols=3, - fontsize=legend_fontsize, - ) - - fig.supylabel("\u03d5 [Mpc$^{-3}$]", fontsize=fontsize) - fig.savefig(savedir + "/phot_fit_" + savedir.split("/")[-1] + ".pdf") - - fig_offset.supylabel("n$_{FENIKS}$ / n$_{diffsky}$", fontsize=fontsize) - fig_offset.savefig(savedir + "/phot_offsets_" + savedir.split("/")[-1] + ".pdf") - - plt.show() - - -def plot_n_corner( - ran_key, - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - label1, - zmin, - zmax, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - dataset_colors_mag, - dimension_labels, - title, - savedir, - ssp_data, - mzr_params, - scatter_params, - dmag=0.1, - cosmo_params=DEFAULT_COSMOLOGY, - fb=FB, - lgmp_min=10.0, - lgmp_max=15.0, - lc_vol_mpc3=7e4, - diffstarpop_params2=None, - spspop_params2=None, - ssperrpop_params2=None, - label2=None, -): - if diffstarpop_params2 is not None: - ( - obs_colors_mag1, - N_weights1, - obs_colors_mag2, - N_weights2, - ) = get_model_colors_mag( - ran_key, - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - lc_vol_mpc3, - zmin, - zmax, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - ssp_data, - lgmp_min, - lgmp_max, - mzr_params, - scatter_params, - cosmo_params, - fb, - diffstarpop_params2=diffstarpop_params2, - spspop_params2=spspop_params2, - ssperrpop_params2=ssperrpop_params2, - ) - - else: - obs_colors_mag1, N_weights1 = get_model_colors_mag( - ran_key, - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - lc_vol_mpc3, - zmin, - zmax, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - ssp_data, - lgmp_min, - lgmp_max, - mzr_params, - scatter_params, - cosmo_params, - fb, - ) - color_bin_edges = np.arange(-0.5 - dmag / 2, 2.2, dmag) - mag_bin_edges = np.arange(18.0 - dmag / 2, mag_thresh, dmag) - ranges = [(color_bin_edges[0], color_bin_edges[-1])] * ( - len(dimension_labels) - len(mag_columns) - ) - for m in range(0, len(mag_columns)): - ranges.append((mag_bin_edges[0], mag_bin_edges[-1])) - - # data - fig_corner = corner.corner( - dataset_colors_mag, - # weights=dataset_colors_mag, - color=color_data, - labels=dimension_labels, - label_kwargs={"fontsize": 20}, - plot_datapoints=False, - smooth=1.0, - levels=[0.68, 0.95], - hist_kwargs={ - "histtype": "stepfilled", - "alpha": 0.5, - "lw": lw, - "density": True, - }, - fill_contours=False, - plot_density=False, - contour_kwargs={"linewidths": 3.5, "alpha": 0.75}, - range=ranges, - ) - - fig_corner.suptitle(title, fontsize=fontsize + 4) - - # model 1 - corner.corner( - obs_colors_mag1, - weights=N_weights1, - fig=fig_corner, - color=color1, - smooth=1.0, - plot_datapoints=False, - levels=[0.68, 0.95], - hist_kwargs={ - "histtype": "step", - "alpha": 0.5, - "lw": lw + 2, - "density": True, - }, - fill_contours=False, - plot_density=False, - contour_kwargs={"linewidths": 3.5, "alpha": 0.75}, - range=ranges, - ) - - # model 2 - if diffstarpop_params2 is not None: - corner.corner( - obs_colors_mag2, - weights=N_weights2, - fig=fig_corner, - color=color2, - smooth=1.0, - plot_datapoints=False, - levels=[0.68, 0.95], - hist_kwargs={ - "histtype": "step", - "alpha": 0.5, - "lw": lw + 2, - "density": True, - }, - plot_density=False, - fill_contours=False, - contour_kwargs={"linewidths": 3.5, "alpha": 0.75}, - range=ranges, - ) - - if label2 is not None: - handles = [ - Line2D([], [], color=color1, lw=lw + 1, label=label1), - Line2D([], [], color=color2, lw=lw + 1, label=label2), - Line2D([], [], color=color_data, lw=lw, label="FENIKS-UDS"), - ] - else: - handles = [ - Line2D([], [], color=color1, lw=lw + 1, label=label1), - Line2D([], [], color=color_data, lw=lw, label="FENIKS-UDS"), - ] - - fig_corner.axes[0].legend( - handles=handles, - loc="center left", - bbox_to_anchor=(1.0, 0.5), - frameon=False, - fontsize=fontsize, - ) - - for ax in fig_corner.get_axes(): - ax.tick_params(axis="both", direction="in", labelsize=labelsize / 2) - fig_corner.savefig( - savedir - + "/z" - + str(zmin) - + "-" - + str(zmax) - + "_corner_fit_" - + savedir.split("/")[-1] - + ".pdf" - ) - plt.show() - - -def get_model_colors_mag( - ran_key, - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - lc_vol_mpc3, - zmin, - zmax, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - ssp_data, - lgmp_min, - lgmp_max, - mzr_params, - scatter_params, - cosmo_params, - fb, - n_z_phot_table=15, - data_sky_area_degsq=None, - diffstarpop_params2=None, - spspop_params2=None, - ssperrpop_params2=None, -): - """mc lightcone""" - ran_key, lc_key = jran.split(ran_key, 2) - sky_area_degsq = zbin_area(lc_vol_mpc3, zlow=zmin, zhigh=zmax).value - lc_args = (lc_key, lgmp_min, zmin, zmax, sky_area_degsq) - lc_halopop = mclh.mc_lightcone_host_halo_diffmah( - *lc_args, cosmo_params=cosmo_params, lgmp_max=lgmp_max - ) - - z_phot_table = jnp.linspace(zmin, zmax, n_z_phot_table) - t_0 = flat_wcdm.age_at_z0(*cosmo_params) - lgt0 = jnp.log10(t_0) - t_table = jnp.linspace(T_TABLE_MIN, 10**lgt0, 100) - - precomputed_ssp_mag_table = psspp.get_precompute_ssp_mag_redshift_table( - tcurves, ssp_data, z_phot_table, cosmo_params - ) - - wave_eff_table = lc_phot_kern.get_wave_eff_table(z_phot_table, tcurves) - - ran_key, phot_key1 = jran.split(ran_key, 2) - phot_args1 = ( - phot_key1, - lc_halopop["z_obs"], - lc_halopop["t_obs"], - lc_halopop["mah_params"], - lc_halopop["logmp0"], - t_table, - ssp_data, - precomputed_ssp_mag_table, - z_phot_table, - wave_eff_table, - diffstarpop_params1, - mzr_params, - spspop_params1, - scatter_params, - ssperrpop_params1, - cosmo_params, - fb, - ) - - lc_phot1 = lc_phot_kern.multiband_lc_phot_kern(*phot_args1) - num_halos, n_bands = lc_phot1.obs_mags_q.shape - - ( - obs_colors_mag_q1, - obs_colors_mag_smooth_ms1, - obs_colors_mag_bursty_ms1, - ) = get_obs_colors_mag(lc_phot1, mag_columns) - obs_colors_mag1 = np.concatenate( - [obs_colors_mag_q1, obs_colors_mag_smooth_ms1, obs_colors_mag_bursty_ms1] - ) - - # set weights=0 for mag > mag_thresh for the band indicated by mag_thresh_column - obs_mag_q1 = lc_phot1.obs_mags_q[:, mag_thresh_column] - obs_mag_smooth_ms1 = lc_phot1.obs_mags_smooth_ms[:, mag_thresh_column] - obs_mag_bursty_ms1 = lc_phot1.obs_mags_bursty_ms[:, mag_thresh_column] - - lc_phot_weights_q1 = jnp.where( - obs_mag_q1 < mag_thresh, - lc_phot1.weights_q, - jnp.zeros_like(lc_phot1.weights_q), - ) - lc_phot_weights_smooth_ms1 = jnp.where( - obs_mag_smooth_ms1 < mag_thresh, - lc_phot1.weights_smooth_ms, - jnp.zeros_like(lc_phot1.weights_smooth_ms), - ) - lc_phot_weights_bursty_ms1 = jnp.where( - obs_mag_bursty_ms1 < mag_thresh, - lc_phot1.weights_bursty_ms, - jnp.zeros_like(lc_phot1.weights_bursty_ms), - ) - N_weights1 = np.concatenate( - [ - lc_phot_weights_q1 * frac_cat, - lc_phot_weights_smooth_ms1 * frac_cat, - lc_phot_weights_bursty_ms1 * frac_cat, - ] - ) - - if diffstarpop_params2 is not None: - ran_key, phot_key2 = jran.split(ran_key, 2) - phot_args2 = ( - phot_key2, - lc_halopop["z_obs"], - lc_halopop["t_obs"], - lc_halopop["mah_params"], - lc_halopop["logmp0"], - t_table, - ssp_data, - precomputed_ssp_mag_table, - z_phot_table, - wave_eff_table, - diffstarpop_params2, - mzr_params, - spspop_params2, - scatter_params, - ssperrpop_params2, - cosmo_params, - fb, - ) - - lc_phot2 = lc_phot_kern.multiband_lc_phot_kern(*phot_args2) - - ( - obs_colors_mag_q2, - obs_colors_mag_smooth_ms2, - obs_colors_mag_bursty_ms2, - ) = get_obs_colors_mag(lc_phot2, mag_columns) - obs_colors_mag2 = np.concatenate( - [ - obs_colors_mag_q2, - obs_colors_mag_smooth_ms2, - obs_colors_mag_bursty_ms2, - ] - ) - # set weights=0 for mag > mag_thresh for the band indicated by mag_thresh_column - obs_mag_q2 = lc_phot2.obs_mags_q[:, mag_thresh_column] - obs_mag_smooth_ms2 = lc_phot2.obs_mags_smooth_ms[:, mag_thresh_column] - obs_mag_bursty_ms2 = lc_phot2.obs_mags_bursty_ms[:, mag_thresh_column] - - lc_phot_weights_q2 = jnp.where( - obs_mag_q2 < mag_thresh, - lc_phot2.weights_q, - jnp.zeros_like(lc_phot2.weights_q), - ) - lc_phot_weights_smooth_ms2 = jnp.where( - obs_mag_smooth_ms2 < mag_thresh, - lc_phot2.weights_smooth_ms, - jnp.zeros_like(lc_phot2.weights_smooth_ms), - ) - lc_phot_weights_bursty_ms2 = jnp.where( - obs_mag_bursty_ms2 < mag_thresh, - lc_phot2.weights_bursty_ms, - jnp.zeros_like(lc_phot2.weights_bursty_ms), - ) - N_weights2 = np.concatenate( - [ - lc_phot_weights_q2 * frac_cat, - lc_phot_weights_smooth_ms2 * frac_cat, - lc_phot_weights_bursty_ms2 * frac_cat, - ] - ) - return obs_colors_mag1, N_weights1, obs_colors_mag2, N_weights2 - else: - return obs_colors_mag1, N_weights1 - - -def get_obs_colors_mag(lc_phot, mag_columns): - num_halos, n_bands = lc_phot.obs_mags_q.shape - - obs_colors_mag_q = [] - obs_colors_mag_smooth_ms = [] - obs_colors_mag_bursty_ms = [] - - for i in range(n_bands - 1): - obs_color_q = lc_phot.obs_mags_q[:, i] - lc_phot.obs_mags_q[:, i + 1] - obs_colors_mag_q.append(obs_color_q) - - obs_color_smooth_ms = ( - lc_phot.obs_mags_smooth_ms[:, i] - lc_phot.obs_mags_smooth_ms[:, i + 1] - ) - obs_colors_mag_smooth_ms.append(obs_color_smooth_ms) - - obs_color_bursty_ms = ( - lc_phot.obs_mags_bursty_ms[:, i] - lc_phot.obs_mags_bursty_ms[:, i + 1] - ) - obs_colors_mag_bursty_ms.append(obs_color_bursty_ms) - - """mag_column""" - for mag_column in mag_columns: - obs_mag_q = lc_phot.obs_mags_q[:, mag_column] - obs_colors_mag_q.append(obs_mag_q) - - obs_mag_smooth_ms = lc_phot.obs_mags_smooth_ms[:, mag_column] - obs_colors_mag_smooth_ms.append(obs_mag_smooth_ms) - - obs_mag_bursty_ms = lc_phot.obs_mags_bursty_ms[:, mag_column] - obs_colors_mag_bursty_ms.append(obs_mag_bursty_ms) - - obs_colors_mag_q = jnp.asarray(obs_colors_mag_q).T - obs_colors_mag_smooth_ms = jnp.asarray(obs_colors_mag_smooth_ms).T - obs_colors_mag_bursty_ms = jnp.asarray(obs_colors_mag_bursty_ms).T - - return obs_colors_mag_q, obs_colors_mag_smooth_ms, obs_colors_mag_bursty_ms diff --git a/diffhtwo/experimental/emline_luminosity_pop.py b/diffhtwo/experimental/emline_luminosity_pop.py deleted file mode 100644 index 4a232a97..00000000 --- a/diffhtwo/experimental/emline_luminosity_pop.py +++ /dev/null @@ -1,246 +0,0 @@ -# flake8: noqa: E402 -""" """ -from jax import config - -config.update("jax_enable_x64", True) - -from collections import namedtuple - -import jax.numpy as jnp -from diffsky.burstpop import diffqburstpop_mono, freqburst_mono -from diffsky.dustpop import tw_dustpop_mono_noise -from diffsky.experimental.lc_phot_kern import diffstarpop_lc_cen_wrapper -from dsps.metallicity import umzr -from dsps.sed import metallicity_weights as zmetw -from dsps.sed.stellar_age_weights import calc_age_weights_from_sfh_table -from jax import jit as jjit -from jax import random as jran -from jax import vmap - -from . import emline_luminosity -from .emline_utils import get_ssp_emline_luminosity - -LGMET_SCATTER = 0.2 - -# copied from astropy.constants.L_sun.cgs.value -L_SUN_CGS = jnp.array(3.828e33, dtype="float64") - - -_M = (0, None, None) -_calc_lgmet_weights_galpop = jjit( - vmap(zmetw.calc_lgmet_weights_from_lognormal_mdf, in_axes=_M) -) - -_B = (None, 0, 0, None, 0) -_calc_bursty_age_weights_vmap = jjit( - vmap( - diffqburstpop_mono.calc_bursty_age_weights_from_diffburstpop_params, in_axes=_B - ) -) - -_AGEPOP = (None, 0, None, 0) -calc_age_weights_from_sfh_table_vmap = jjit( - vmap(calc_age_weights_from_sfh_table, in_axes=_AGEPOP) -) - - -_D = (None, None, 0, 0, 0, None, 0, 0, 0, None) -calc_dust_ftrans_vmap = jjit( - vmap( - tw_dustpop_mono_noise.calc_ftrans_singlegal_singlewave_from_dustpop_params, - in_axes=_D, - ) -) - - -_LCLINE_RET_KEYS = ( - "emline_L_cgs_q", - "emline_L_cgs_smooth_ms", - "emline_L_cgs_bursty_ms", - "weights_q", - "weights_smooth_ms", - "weights_bursty_ms", -) -LCLine = namedtuple("LCLine", _LCLINE_RET_KEYS) -LCLINE_EMPTY = LCLine._make([None] * len(LCLine._fields)) - - -@jjit -def emline_luminosity_pop( - diffstarpop_params, - ran_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - cosmo_params, - fb, -): - ssp_emline_luminosity = get_ssp_emline_luminosity(emline_wave_aa, ssp_data) - n_met, n_age = ssp_emline_luminosity.shape - n_gals = logmp0.size - - ran_key, sfh_key = jran.split(ran_key, 2) - diffstar_galpop = diffstarpop_lc_cen_wrapper( - diffstarpop_params, - sfh_key, - mah_params, - logmp0, - t_table, - t_obs, - cosmo_params, - fb, - ) - - # get age weights - smooth_age_weights_ms = calc_age_weights_from_sfh_table_vmap( - t_table, diffstar_galpop.sfh_ms, ssp_data.ssp_lg_age_gyr, t_obs - ) - - smooth_age_weights_q = calc_age_weights_from_sfh_table_vmap( - t_table, diffstar_galpop.sfh_q, ssp_data.ssp_lg_age_gyr, t_obs - ) - - # get bursty age weights - _args = ( - spspop_params.burstpop_params, - diffstar_galpop.logsm_obs_ms, - diffstar_galpop.logssfr_obs_ms, - ssp_data.ssp_lg_age_gyr, - smooth_age_weights_ms, - ) - bursty_age_weights_ms, burst_params = _calc_bursty_age_weights_vmap(*_args) - - # get p_burst_ms - p_burst_ms = freqburst_mono.get_freqburst_from_freqburst_params( - spspop_params.burstpop_params.freqburst_params, - diffstar_galpop.logsm_obs_ms, - diffstar_galpop.logssfr_obs_ms, - ) - - # get metallicity weights - lgmet_med_ms = umzr.mzr_model(diffstar_galpop.logsm_obs_ms, t_obs, *mzr_params) - lgmet_med_q = umzr.mzr_model(diffstar_galpop.logsm_obs_q, t_obs, *mzr_params) - - lgmet_weights_ms = _calc_lgmet_weights_galpop( - lgmet_med_ms, LGMET_SCATTER, ssp_data.ssp_lgmet - ) - lgmet_weights_q = _calc_lgmet_weights_galpop( - lgmet_med_q, LGMET_SCATTER, ssp_data.ssp_lgmet - ) - - # age weights * metallicity weights - _w_age_q = smooth_age_weights_q.reshape((n_gals, 1, n_age)) - _w_lgmet_q = lgmet_weights_q.reshape((n_gals, n_met, 1)) - ssp_weights_q = _w_lgmet_q * _w_age_q - - _w_age_ms = smooth_age_weights_ms.reshape((n_gals, 1, n_age)) - _w_lgmet_ms = lgmet_weights_ms.reshape((n_gals, n_met, 1)) - ssp_weights_smooth_ms = _w_lgmet_ms * _w_age_ms - - _w_age_bursty_ms = bursty_age_weights_ms.reshape((n_gals, 1, n_age)) - ssp_weights_bursty_ms = _w_lgmet_ms * _w_age_bursty_ms - - # get ftrans due to dust - ran_key, dust_key = jran.split(ran_key, 2) - av_key, delta_key, funo_key = jran.split(dust_key, 3) - uran_av = jran.uniform(av_key, shape=(n_gals,)) - uran_delta = jran.uniform(delta_key, shape=(n_gals,)) - uran_funo = jran.uniform(funo_key, shape=(n_gals,)) - - ftrans_args_q = ( - spspop_params.dustpop_params, - emline_wave_aa, - diffstar_galpop.logsm_obs_q, - diffstar_galpop.logssfr_obs_q, - z_obs, - ssp_data.ssp_lg_age_gyr, - uran_av, - uran_delta, - uran_funo, - scatter_params, - ) - _res = calc_dust_ftrans_vmap(*ftrans_args_q) - ftrans_q = _res[1] - - ftrans_args_ms = ( - spspop_params.dustpop_params, - emline_wave_aa, - diffstar_galpop.logsm_obs_ms, - diffstar_galpop.logssfr_obs_ms, - z_obs, - ssp_data.ssp_lg_age_gyr, - uran_av, - uran_delta, - uran_funo, - scatter_params, - ) - _res = calc_dust_ftrans_vmap(*ftrans_args_ms) - ftrans_ms = _res[1] - - _ftrans_q = ftrans_q.reshape((n_gals, 1, n_age)) - _ftrans_ms = ftrans_ms.reshape((n_gals, 1, n_age)) - - _mstar_q = 10**diffstar_galpop.logsm_obs_q - _mstar_ms = 10**diffstar_galpop.logsm_obs_ms - - integrand_q = ssp_emline_luminosity * ssp_weights_q * _ftrans_q - emline_L_cgs_q = jnp.sum(integrand_q, axis=(1, 2)) * (L_SUN_CGS * _mstar_q) - - integrand_smooth_ms = ssp_emline_luminosity * ssp_weights_smooth_ms * _ftrans_ms - emline_L_cgs_smooth_ms = jnp.sum(integrand_smooth_ms, axis=(1, 2)) * ( - L_SUN_CGS * _mstar_ms - ) - - integrand_bursty_ms = ssp_emline_luminosity * ssp_weights_bursty_ms * _ftrans_ms - emline_L_cgs_bursty_ms = jnp.sum(integrand_bursty_ms, axis=(1, 2)) * ( - L_SUN_CGS * _mstar_ms - ) - - weights_q = diffstar_galpop.frac_q - weights_smooth_ms = (1 - diffstar_galpop.frac_q) * (1 - p_burst_ms) - weights_bursty_ms = (1 - diffstar_galpop.frac_q) * p_burst_ms - - emline_L = LCLINE_EMPTY._replace( - emline_L_cgs_q=emline_L_cgs_q, - emline_L_cgs_smooth_ms=emline_L_cgs_smooth_ms, - emline_L_cgs_bursty_ms=emline_L_cgs_bursty_ms, - weights_q=weights_q, - weights_smooth_ms=weights_smooth_ms, - weights_bursty_ms=weights_bursty_ms, - ) - - return emline_L - - -@jjit -def emline_luminosity_func_pop(emline_L_tuple, nhalos, sig=None, lgL_bin_edges=None): - # get q emline L_cgs histogram - emline_L_cgs_q = emline_L_tuple.emline_L_cgs_q - w_q = emline_L_tuple.weights_q * nhalos - lgL_bin_edges, tw_hist_q = emline_luminosity.get_emline_luminosity_func( - emline_L_cgs_q, w_q, sig=sig, lgL_bin_edges=lgL_bin_edges - ) - - # get smooth_ms emline L_cgs histogram - emline_L_cgs_smooth_ms = emline_L_tuple.emline_L_cgs_smooth_ms - w_smooth_ms = emline_L_tuple.weights_smooth_ms * nhalos - _, tw_hist_smooth_ms = emline_luminosity.get_emline_luminosity_func( - emline_L_cgs_smooth_ms, w_smooth_ms, sig=sig, lgL_bin_edges=lgL_bin_edges - ) - - # get bursty_ms emline L_cgs histogram - emline_L_cgs_bursty_ms = emline_L_tuple.emline_L_cgs_bursty_ms - w_bursty_ms = emline_L_tuple.weights_bursty_ms * nhalos - _, tw_hist_bursty_ms = emline_luminosity.get_emline_luminosity_func( - emline_L_cgs_bursty_ms, w_bursty_ms, sig=sig, lgL_bin_edges=lgL_bin_edges - ) - - return lgL_bin_edges, tw_hist_q, tw_hist_smooth_ms, tw_hist_bursty_ms diff --git a/diffhtwo/experimental/optimizers/emline_luminosity_opt.py b/diffhtwo/experimental/optimizers/emline_luminosity_opt.py deleted file mode 100644 index 3f2a0a85..00000000 --- a/diffhtwo/experimental/optimizers/emline_luminosity_opt.py +++ /dev/null @@ -1,166 +0,0 @@ -# flake8: noqa: E402 -""" """ -import jax - -jax.config.update("jax_enable_x64", True) -jax.config.update("jax_debug_nans", True) -jax.config.update("jax_debug_infs", True) - -from functools import partial - -import jax.numpy as jnp -from diffstar.diffstarpop import get_bounded_diffstarpop_params -from diffstar.diffstarpop.defaults import DEFAULT_DIFFSTARPOP_U_PARAMS -from jax import jit as jjit -from jax import lax, value_and_grad -from jax.example_libraries import optimizers as jax_opt -from jax.flatten_util import ravel_pytree - -from ..emline_luminosity_pop import emline_luminosity_func_pop, emline_luminosity_pop - -u_theta_default, u_unravel_fn = ravel_pytree(DEFAULT_DIFFSTARPOP_U_PARAMS) -IDX = jnp.arange(16, 22, 1) - - -@jjit -def _mse(emline_lf_weighted_composite_true, emline_lf_weighted_composite_pred): - diff = emline_lf_weighted_composite_true - emline_lf_weighted_composite_pred - return jnp.mean(jnp.square(diff)) - - -def make_subspace_loss(u_unravel_fn, u_theta_default, IDX): - """ - Build a loss that optimizes ONLY the parameters at flat indices `IDX`. - - u_unravel_fn: from ravel_pytree(template) - - u_theta_default: 1D base vector (others stay fixed to these values) - - IDX: 1D array/list of flat indices to vary (static for the compiled fn) - - Notes: The only thing you should not do is use namedtuple._replace() / ._asdict() - inside @jjit. Those are Python-side and will slow/break JIT. - """ - IDX = jnp.asarray(IDX, dtype=jnp.int64) # capture in closure - - @jjit - def _loss_kern_subspace( - u_theta_sub, # only the selected subset: shape (len(IDX),) - emline_lf_weighted_composite_true, - ran_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - cosmo_params, - fb, - ): - # scatter the subset into the full flat vector - u_theta_full = u_theta_default.at[IDX].set(u_theta_sub) - - # back to structured params and do the usual - u_diffstarpop_params = u_unravel_fn(u_theta_full) - - # convert to bounded params - diffstarpop_params = get_bounded_diffstarpop_params(u_diffstarpop_params) - - emline_lf_pred = emline_luminosity_pop( - diffstarpop_params, - ran_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - cosmo_params, - fb, - ) - - nhalos = jnp.ones_like(emline_lf_pred.emline_L_cgs_q) - ( - lgL_bin_edges, - emline_lf_weighted_q_pred, - emline_lf_weighted_smooth_ms_pred, - emline_lf_weighted_bursty_ms_pred, - ) = emline_luminosity_func_pop(emline_lf_pred, nhalos) - - emline_lf_weighted_composite_pred = ( - emline_lf_weighted_q_pred - + emline_lf_weighted_smooth_ms_pred - + emline_lf_weighted_bursty_ms_pred - ) - - return _mse( - emline_lf_weighted_composite_true, emline_lf_weighted_composite_pred - ) - - return _loss_kern_subspace - - -loss_kern = make_subspace_loss(u_unravel_fn, u_theta_default, IDX) -loss_and_grad_fn = jjit(value_and_grad(loss_kern)) - - -@partial(jjit, static_argnames=["n_steps", "step_size"]) -def fit_emline_luminosity( - u_theta_init_sub, # only the selected subset: shape (len(IDX),) - emline_lf_weighted_composite_true, - ran_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - cosmo_params, - fb, - n_steps=10, - step_size=1e-2, -): - opt_init, opt_update, get_params = jax_opt.adam(step_size) - opt_state = opt_init(u_theta_init_sub) - - other = ( - emline_lf_weighted_composite_true, - ran_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - cosmo_params, - fb, - ) - - def _opt_update(opt_state, i): - u_theta_sub = get_params(opt_state) - loss, grads = loss_and_grad_fn(u_theta_sub, *other) - opt_state = opt_update(i, grads, opt_state) - return opt_state, loss - - opt_state, loss_hist = lax.scan(_opt_update, opt_state, jnp.arange(n_steps)) - - u_theta_fit_sub = get_params(opt_state) - - return loss_hist, u_theta_fit_sub diff --git a/diffhtwo/experimental/optimizers/tests/test_emline_luminosity_opt.py b/diffhtwo/experimental/optimizers/tests/test_emline_luminosity_opt.py deleted file mode 100644 index ca209f53..00000000 --- a/diffhtwo/experimental/optimizers/tests/test_emline_luminosity_opt.py +++ /dev/null @@ -1,120 +0,0 @@ -import jax.numpy as jnp -import numpy as np -from diffsky.experimental import mc_lightcone_halos as mclh -from diffsky.experimental.scatter import DEFAULT_SCATTER_PARAMS -from diffsky.param_utils import spspop_param_utils as spspu -from diffstar.defaults import FB, T_TABLE_MIN -from diffstar.diffstarpop.defaults import ( - DEFAULT_DIFFSTARPOP_PARAMS, - DEFAULT_DIFFSTARPOP_U_PARAMS, -) -from dsps.cosmology import DEFAULT_COSMOLOGY, flat_wcdm -from dsps.data_loaders import retrieve_fake_fsps_data -from dsps.metallicity import umzr -from jax import random as jran -from jax.flatten_util import ravel_pytree - -from ...emline_luminosity_pop import emline_luminosity_func_pop, emline_luminosity_pop -from ..emline_luminosity_opt import IDX, fit_emline_luminosity - -u_theta_default, u_unravel_fn = ravel_pytree(DEFAULT_DIFFSTARPOP_U_PARAMS) -theta_default, unravel_fn = ravel_pytree(DEFAULT_DIFFSTARPOP_PARAMS) - -ssp_data = retrieve_fake_fsps_data.load_fake_ssp_data() -emline_wave_aa = 6000 - - -def test_emline_luminosity_opt(): - ran_key = jran.key(0) - - # generate lightcone - ran_key, lc_key = jran.split(ran_key, 2) - lgmp_min = 12.0 - z_min, z_max = 0.1, 0.5 - sky_area_degsq = 0.1 - - args = (lc_key, lgmp_min, z_min, z_max, sky_area_degsq) - - lc_halopop = mclh.mc_lightcone_host_halo_diffmah(*args) - - n_z_phot_table = 15 - z_phot_table = np.linspace(z_min, z_max, n_z_phot_table) - - z_obs = lc_halopop["z_obs"] - t_obs = lc_halopop["t_obs"] - mah_params = lc_halopop["mah_params"] - logmp0 = lc_halopop["logmp0"] - t_0 = flat_wcdm.age_at_z0(*DEFAULT_COSMOLOGY) - lgt0 = np.log10(t_0) - - t_table = np.linspace(T_TABLE_MIN, 10**lgt0, 100) - - mzr_params = umzr.DEFAULT_MZR_PARAMS - - spspop_params = spspu.DEFAULT_SPSPOP_PARAMS - scatter_params = DEFAULT_SCATTER_PARAMS - - ran_key, dpop_halpha_true_key = jran.split(ran_key, 2) - args = ( - DEFAULT_DIFFSTARPOP_PARAMS, - dpop_halpha_true_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - DEFAULT_COSMOLOGY, - FB, - ) - halpha_L_true = emline_luminosity_pop(*args) - nhalos = jnp.ones_like(halpha_L_true.emline_L_cgs_q) - ( - lgL_bin_edges, - halpha_lf_weighted_q_true, - halpha_lf_weighted_smooth_ms_true, - halpha_lf_weighted_bursty_ms_true, - ) = emline_luminosity_func_pop(halpha_L_true, nhalos) - - halpha_lf_weighted_composite_true = ( - halpha_lf_weighted_q_true - + halpha_lf_weighted_smooth_ms_true - + halpha_lf_weighted_bursty_ms_true - ) - - noise_scale = 0.1 - ran_key, perturb_key = jran.split(ran_key, 2) - u_theta_perturbed = u_theta_default + noise_scale * jran.normal( - perturb_key, shape=u_theta_default.shape - ) - - ran_key, dpop_halpha_perturbed_key = jran.split(ran_key, 2) - fit_args = ( - u_theta_perturbed[IDX], - halpha_lf_weighted_composite_true, - dpop_halpha_perturbed_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - DEFAULT_COSMOLOGY, - FB, - ) - - loss_hist, u_theta_fit_sub = fit_emline_luminosity( - *fit_args, n_steps=2, step_size=0.02 - ) - - assert np.isfinite(loss_hist).all() diff --git a/diffhtwo/experimental/tests/test_emline_luminosity_pop.py b/diffhtwo/experimental/tests/test_emline_luminosity_pop.py deleted file mode 100644 index bebd003a..00000000 --- a/diffhtwo/experimental/tests/test_emline_luminosity_pop.py +++ /dev/null @@ -1,69 +0,0 @@ -import numpy as np -from diffsky.experimental import mc_lightcone_halos as mclh -from diffsky.experimental.scatter import DEFAULT_SCATTER_PARAMS -from diffsky.param_utils import spspop_param_utils as spspu -from diffstar.defaults import FB, T_TABLE_MIN -from diffstar.diffstarpop.defaults import DEFAULT_DIFFSTARPOP_PARAMS -from dsps.cosmology import flat_wcdm -from dsps.cosmology.defaults import DEFAULT_COSMOLOGY -from dsps.data_loaders import retrieve_fake_fsps_data -from dsps.metallicity import umzr -from jax import random as jran - -from ..emline_luminosity_pop import emline_luminosity_pop - -ssp_data = retrieve_fake_fsps_data.load_fake_ssp_data() - - -def test_emline_luminosity_pop(): - ran_key = jran.key(0) - lgmp_min = 12.0 - z_min, z_max = 0.1, 0.5 - sky_area_degsq = 0.1 - - ran_key, lc_key = jran.split(ran_key, 2) - args = (lc_key, lgmp_min, z_min, z_max, sky_area_degsq) - - lc_halopop = mclh.mc_lightcone_host_halo_diffmah(*args) - - n_z_phot_table = 15 - z_phot_table = np.linspace(z_min, z_max, n_z_phot_table) - - z_obs = lc_halopop["z_obs"] - t_obs = lc_halopop["t_obs"] - mah_params = lc_halopop["mah_params"] - logmp0 = lc_halopop["logmp0"] - t_0 = flat_wcdm.age_at_z0(*DEFAULT_COSMOLOGY) - lgt0 = np.log10(t_0) - - t_table = np.linspace(T_TABLE_MIN, 10**lgt0, 100) - - mzr_params = umzr.DEFAULT_MZR_PARAMS - - spspop_params = spspu.DEFAULT_SPSPOP_PARAMS - scatter_params = DEFAULT_SCATTER_PARAMS - - emline_wave_aa = 6000 - - ran_key, diffstarpop_key = jran.split(ran_key, 2) - args = ( - DEFAULT_DIFFSTARPOP_PARAMS, - diffstarpop_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - DEFAULT_COSMOLOGY, - FB, - ) - halpha_L = emline_luminosity_pop(*args) - - for arr in halpha_L: - assert np.all(np.isfinite(arr)) diff --git a/diffhtwo/experimental/tests/test_line_phot_kern.py b/diffhtwo/experimental/tests/test_line_phot_kern.py index 835e68d1..afcfecb5 100644 --- a/diffhtwo/experimental/tests/test_line_phot_kern.py +++ b/diffhtwo/experimental/tests/test_line_phot_kern.py @@ -1,22 +1,11 @@ -import jax.numpy as jnp import numpy as np from astropy import units as u from astropy.cosmology import FlatLambdaCDM -from diffsky.experimental import mc_lightcone_halos as mclh -from diffsky.experimental.scatter import DEFAULT_SCATTER_PARAMS -from diffsky.param_utils import spspop_param_utils as spspu -from diffstar.defaults import FB, T_TABLE_MIN -from diffstar.diffstarpop.defaults import DEFAULT_DIFFSTARPOP_PARAMS -from dsps.cosmology import flat_wcdm -from dsps.cosmology.defaults import DEFAULT_COSMOLOGY from dsps.data_loaders import retrieve_fake_fsps_data -from dsps.metallicity import umzr -from jax import random as jran from .. import line_phot_kern from ..data_loaders import retrieve_tcurves from ..defaults import C_ANGSTROMS, HALPHA_CENTER_AA -from ..emline_luminosity_pop import emline_luminosity_func_pop, emline_luminosity_pop ssp_data = retrieve_fake_fsps_data.load_fake_ssp_data() @@ -118,81 +107,3 @@ def test_line_phot_kern(BB_tcurve=HSC_Z_tcurve, NB_tcurve=HSC_NB921_tcurve): BB_mag_ab_check = -2.5 * np.log10(numerator / denominator) - 48.6 assert np.isclose(BB_mag_ab, BB_mag_ab_check) - - -def test_line_mag_vmap(): - ran_key = jran.key(0) - ran_key, lc_key = jran.split(ran_key, 2) - - lgmp_min = 12.0 - z_min, z_max = 0.2, 0.5 - sky_area_degsq = 1 - - """weighted mc lightcone""" - num_halos = 500 - lgmp_max = 15.0 - args = (lc_key, num_halos, z_min, z_max, lgmp_min, lgmp_max, sky_area_degsq) - lc_halopop = mclh.mc_weighted_halo_lightcone(*args) - - n_z_phot_table = 15 - z_phot_table = np.linspace(z_min, z_max, n_z_phot_table) - - z_obs = jnp.array(lc_halopop["z_obs"]) - t_obs = lc_halopop["t_obs"] - mah_params = lc_halopop["mah_params"] - # logmp0 = lc_halopop["logmp0"] - logmp0 = lc_halopop["logmp0"] - nhalos = lc_halopop["nhalos"] - t_0 = flat_wcdm.age_at_z0(*DEFAULT_COSMOLOGY) - lgt0 = np.log10(t_0) - - t_table = np.linspace(T_TABLE_MIN, 10**lgt0, 100) - - mzr_params = umzr.DEFAULT_MZR_PARAMS - - spspop_params = spspu.DEFAULT_SPSPOP_PARAMS - scatter_params = DEFAULT_SCATTER_PARAMS - - ran_key, dpop_halpha_true_key = jran.split(ran_key, 2) - args = ( - DEFAULT_DIFFSTARPOP_PARAMS, - dpop_halpha_true_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - HALPHA_CENTER_AA, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - DEFAULT_COSMOLOGY, - FB, - ) - halpha_L_true = emline_luminosity_pop(*args) - - ( - lgL_bin_edges, - halpha_lf_weighted_q_true, - halpha_lf_weighted_smooth_ms_true, - halpha_lf_weighted_bursty_ms_true, - ) = emline_luminosity_func_pop(halpha_L_true, nhalos) - - halpha_obs_aa = HALPHA_CENTER_AA * (1 + z_obs) - d_L_Mpc = COSMO.luminosity_distance(z_obs) - d_L_cm = d_L_Mpc * MPC_TO_CM - - SXDS_z_mag_ab = line_phot_kern.get_band_mag_ab_from_luminosity( - halpha_obs_aa, - halpha_L_true, - z_obs, - d_L_cm, - HSC_Z_tcurve.wave, - HSC_Z_tcurve.transmission, - ) - - assert np.isfinite(SXDS_z_mag_ab.band_mag_ab_q).all() - assert np.isfinite(SXDS_z_mag_ab.band_mag_ab_smooth_ms).all() - assert np.isfinite(SXDS_z_mag_ab.band_mag_ab_bursty_ms).all() From 2707955c86affcee58379eee89de5be175e45827 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sat, 13 Jun 2026 19:51:30 -0500 Subject: [PATCH 57/57] limit fitting spaces to fix loss descent for feniks --- diffhtwo/experimental/data_loaders/N_utils.py | 2 +- .../experimental/data_loaders/load_feniks.py | 347 +++++++++--------- 2 files changed, 175 insertions(+), 174 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/N_utils.py b/diffhtwo/experimental/data_loaders/N_utils.py index 2978ee04..3dc8f869 100644 --- a/diffhtwo/experimental/data_loaders/N_utils.py +++ b/diffhtwo/experimental/data_loaders/N_utils.py @@ -71,7 +71,7 @@ def get_mag_space(namedtuple_name, mag, filter_name, z_sel, fit=True): AppMagFuncSpace = namedtuple(namedtuple_name, AppMagFunc._fields) mag_idx = filter_name_to_idx(filter_name) N_1d, sig, bin_lo, bin_hi = get_N_1d(mag[z_sel]) - return AppMagFuncSpace(mag_idx, sig, bin_lo, bin_hi, N_1d, True) + return AppMagFuncSpace(mag_idx, sig, bin_lo, bin_hi, N_1d, fit) def get_colorcolor_space( diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 684804f3..6d1b035a 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -299,6 +299,7 @@ def get_feniks_data( r"$J_{UDS}$", r"$H_{UDS}$", r"$K_{UDS}$", + r"$redshift$", ] # derive colors from mags @@ -352,8 +353,8 @@ def get_feniks_data( zbins = np.array( [ [0.2, 0.7], - [0.7, 1.0], - [1.0, 1.5], + [0.7, 1.5], + # [1.0, 1.5], [1.5, 2.5], ] ) @@ -486,170 +487,170 @@ def get_feniks_data( # 1D (i - NB816 | K) -- metallicity at fixed mass # 1D (z - NB921 | K) -- cross-check metallicity at fixed mass - Z2a = namedtuple( - "Z2a", - [ - "z_min", - "z_max", - "lc_data", - "rz_zJ", - "ug", - "rz", - "jh", - "K_ug", - "K_rz", - "iNB816_gr", - "iNB816_rK", - "iNB816_condK", - "zNB921_condK", - ], - ) - zbin = 1 - z_min = zbins[zbin][0] - z_max = zbins[zbin][1] - - z_phot_table = 10 ** jnp.linspace( - jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table - ) - lc_args = ( - ran_key, - num_halos_coarse_zbins, - z_min, - z_max, - lgmp_min, - lgmp_max, - lc_sky_area_degsq, - ssp_data, - tcurves, - z_phot_table, - ) - - lc_data = generate_lc_data(*lc_args) - - z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) - - # 2D (r - z, z - J) - rz_zJ = N_utils.get_colorcolor_space( - "Rz_zJ", - hsc_rz, - hsc_uds_zJ, - ["HSC_R", "HSC_Z", "HSC_Z", "UDS_J"], - z_sel, - fit=True, - ) - - # 1D (u - g | K) - ug = N_utils.get_color_cond_space_list( - "Ug_condK", - megacam_hsc_uSg, - uds_K, - ["MegaCam_uS", "HSC_G"], - "UDS_K", - z_sel, - cond_dmag=2, - fit=True, - ) - - # 1D (r - z | K) - rz = N_utils.get_color_cond_space_list( - "Rz_condK", - hsc_rz, - uds_K, - ["HSC_R", "HSC_Z"], - "UDS_K", - z_sel, - cond_dmag=2, - fit=True, - ) - - # 1D (J − H | K) - jh = N_utils.get_color_cond_space_list( - "JH_condK", - uds_JH, - uds_K, - ["UDS_J", "UDS_H"], - "UDS_K", - z_sel, - cond_dmag=2, - fit=True, - ) - - # 2D (K, u - g) - K_ug = N_utils.get_mag_color_space( - "K_ug", - uds_K, - megacam_hsc_uSg, - "UDS_K", - ["MegaCam_uS", "HSC_G"], - z_sel, - fit=False, - ) - - # 2D (K, r - z) - K_rz = N_utils.get_mag_color_space( - "K_rz", uds_K, hsc_rz, "UDS_K", ["HSC_R", "HSC_Z"], z_sel, fit=False - ) - - # 2D (i - NB816, g - r) - i816_gr = N_utils.get_colorcolor_space( - "I816_gr", - hsc_i816, - hsc_gr, - ["HSC_I", "NB0816", "HSC_G", "HSC_R"], - z_sel, - fit=False, - ) - - # 2D (i - NB816, r - K) - i816_rK = N_utils.get_colorcolor_space( - "I816_rK", - hsc_i816, - hsc_uds_rK, - ["HSC_I", "NB0816", "HSC_R", "UDS_K"], - z_sel, - fit=False, - ) - - # 1D (i - NB816 | K) - i816_condK = N_utils.get_color_cond_space_list( - "I816_condK", - hsc_i816, - uds_K, - ["HSC_I", "NB0816"], - "UDS_K", - z_sel, - cond_dmag=2, - fit=False, - ) - - # 1D (z - NB921 | K) - z921_condK = N_utils.get_color_cond_space_list( - "Z921_condK", - hsc_z921, - uds_K, - ["HSC_Z", "NB0921"], - "UDS_K", - z_sel, - cond_dmag=2, - fit=False, - ) - - z2a = Z2a( - z_min, - z_max, - lc_data, - rz_zJ, - ug, - rz, - jh, - K_ug, - K_rz, - i816_gr, - i816_rK, - i816_condK, - z921_condK, - ) - colors.append(z2a) + # Z2a = namedtuple( + # "Z2a", + # [ + # "z_min", + # "z_max", + # "lc_data", + # "rz_zJ", + # "ug", + # "rz", + # "jh", + # "K_ug", + # "K_rz", + # "iNB816_gr", + # "iNB816_rK", + # "iNB816_condK", + # "zNB921_condK", + # ], + # ) + # zbin = 1 + # z_min = zbins[zbin][0] + # z_max = zbins[zbin][1] + + # z_phot_table = 10 ** jnp.linspace( + # jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + # ) + # lc_args = ( + # ran_key, + # num_halos_coarse_zbins, + # z_min, + # z_max, + # lgmp_min, + # lgmp_max, + # lc_sky_area_degsq, + # ssp_data, + # tcurves, + # z_phot_table, + # ) + + # lc_data = generate_lc_data(*lc_args) + + # z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) + + # # 2D (r - z, z - J) + # rz_zJ = N_utils.get_colorcolor_space( + # "Rz_zJ", + # hsc_rz, + # hsc_uds_zJ, + # ["HSC_R", "HSC_Z", "HSC_Z", "UDS_J"], + # z_sel, + # fit=True, + # ) + + # # 1D (u - g | K) + # ug = N_utils.get_color_cond_space_list( + # "Ug_condK", + # megacam_hsc_uSg, + # uds_K, + # ["MegaCam_uS", "HSC_G"], + # "UDS_K", + # z_sel, + # cond_dmag=2, + # fit=True, + # ) + + # # 1D (r - z | K) + # rz = N_utils.get_color_cond_space_list( + # "Rz_condK", + # hsc_rz, + # uds_K, + # ["HSC_R", "HSC_Z"], + # "UDS_K", + # z_sel, + # cond_dmag=2, + # fit=True, + # ) + + # # 1D (J − H | K) + # jh = N_utils.get_color_cond_space_list( + # "JH_condK", + # uds_JH, + # uds_K, + # ["UDS_J", "UDS_H"], + # "UDS_K", + # z_sel, + # cond_dmag=2, + # fit=True, + # ) + + # # 2D (K, u - g) + # K_ug = N_utils.get_mag_color_space( + # "K_ug", + # uds_K, + # megacam_hsc_uSg, + # "UDS_K", + # ["MegaCam_uS", "HSC_G"], + # z_sel, + # fit=False, + # ) + + # # 2D (K, r - z) + # K_rz = N_utils.get_mag_color_space( + # "K_rz", uds_K, hsc_rz, "UDS_K", ["HSC_R", "HSC_Z"], z_sel, fit=False + # ) + + # # 2D (i - NB816, g - r) + # i816_gr = N_utils.get_colorcolor_space( + # "I816_gr", + # hsc_i816, + # hsc_gr, + # ["HSC_I", "NB0816", "HSC_G", "HSC_R"], + # z_sel, + # fit=False, + # ) + + # # 2D (i - NB816, r - K) + # i816_rK = N_utils.get_colorcolor_space( + # "I816_rK", + # hsc_i816, + # hsc_uds_rK, + # ["HSC_I", "NB0816", "HSC_R", "UDS_K"], + # z_sel, + # fit=False, + # ) + + # # 1D (i - NB816 | K) + # i816_condK = N_utils.get_color_cond_space_list( + # "I816_condK", + # hsc_i816, + # uds_K, + # ["HSC_I", "NB0816"], + # "UDS_K", + # z_sel, + # cond_dmag=2, + # fit=False, + # ) + + # # 1D (z - NB921 | K) + # z921_condK = N_utils.get_color_cond_space_list( + # "Z921_condK", + # hsc_z921, + # uds_K, + # ["HSC_Z", "NB0921"], + # "UDS_K", + # z_sel, + # cond_dmag=2, + # fit=False, + # ) + + # z2a = Z2a( + # z_min, + # z_max, + # lc_data, + # rz_zJ, + # ug, + # rz, + # jh, + # K_ug, + # K_rz, + # i816_gr, + # i816_rK, + # i816_condK, + # z921_condK, + # ) + # colors.append(z2a) ############################################################################## # Z2b spaces: @@ -661,7 +662,7 @@ def get_feniks_data( "Z2b", ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "rz", "jh", "K_ug", "K_rz"], ) - zbin = 2 + zbin = 1 z_min = zbins[zbin][0] z_max = zbins[zbin][1] @@ -774,7 +775,7 @@ def get_feniks_data( "K_JH", ], ) - zbin = 3 + zbin = 2 z_min = zbins[zbin][0] z_max = zbins[zbin][1] @@ -925,22 +926,22 @@ def get_feniks_data( u = N_utils.get_mag_space("U", megacam_uS, "MegaCam_uS", z_sel, fit=True) # 1D (g) - g = N_utils.get_mag_space("G", hsc_g, "HSC_G", z_sel, fit=True) + g = N_utils.get_mag_space("G", hsc_g, "HSC_G", z_sel, fit=False) # 1D (r) r = N_utils.get_mag_space("R", hsc_r, "HSC_R", z_sel, fit=True) # 1D (i) - i = N_utils.get_mag_space("I", hsc_i, "HSC_I", z_sel, fit=True) + i = N_utils.get_mag_space("I", hsc_i, "HSC_I", z_sel, fit=False) # 1D (z) - z = N_utils.get_mag_space("Z", hsc_z, "HSC_Z", z_sel, fit=True) + z = N_utils.get_mag_space("Z", hsc_z, "HSC_Z", z_sel, fit=False) # 1D (J) - j = N_utils.get_mag_space("J", uds_J, "UDS_J", z_sel, fit=True) + j = N_utils.get_mag_space("J", uds_J, "UDS_J", z_sel, fit=False) # 1D (H) - h = N_utils.get_mag_space("H", uds_H, "UDS_H", z_sel, fit=True) + h = N_utils.get_mag_space("H", uds_H, "UDS_H", z_sel, fit=False) # 1D (K) k = N_utils.get_mag_space("K", uds_K, "UDS_K", z_sel, fit=True)