diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index a311b4d30..4438b3071 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -1,4 +1,4 @@ -name: Test Pull Request into main +name: Test Pull Request into main with diffsky@v0.3.6 on: push: @@ -41,7 +41,7 @@ jobs: scipy \ python-build pip uninstall diffsky --yes - pip install --no-deps git+https://github.com/ArgonneCPAC/diffsky.git + pip install --no-deps git+https://github.com/ArgonneCPAC/diffsky.git@v0.3.6 python -m pip install --no-build-isolation --no-deps -e . diff --git a/diffhtwo/experimental/conftest.py b/diffhtwo/experimental/conftest.py index 328d94ef8..7c5c0d150 100644 --- a/diffhtwo/experimental/conftest.py +++ b/diffhtwo/experimental/conftest.py @@ -47,6 +47,7 @@ def feniks(ran_key, fake_subset_ssp_data): ssp_data, phot=PHOT, zout=ZOUT, + add_random_rows_for_testing=True, ) return feniks @@ -72,6 +73,20 @@ def feniks_tcurves(): return tcurves +@pytest.fixture(scope="session") +def feniks_fitting_data(ran_key, fake_subset_ssp_data): + ssp_data, emline_wave_aa = fake_subset_ssp_data + feniks_fitting_data = load_feniks.get_feniks_fitting_data( + FENIKS_DRN, + ran_key, + ssp_data, + phot=PHOT, + zout=ZOUT, + add_random_rows_for_testing=True, + ) + return feniks_fitting_data + + @pytest.fixture(scope="session") def feniks_single_z_data(ran_key, fake_subset_ssp_data, feniks): ssp_data, emline_wave_aa = fake_subset_ssp_data diff --git a/diffhtwo/experimental/data_loaders/N_utils.py b/diffhtwo/experimental/data_loaders/N_utils.py new file mode 100644 index 000000000..3dc8f8695 --- /dev/null +++ b/diffhtwo/experimental/data_loaders/N_utils.py @@ -0,0 +1,145 @@ +from collections import namedtuple + +import jax.numpy as jnp +import numpy as np +from diffsky import diffndhist_lomem + +from ..defaults import AppMagFunc, ColorColor, ColorCondMag, FeniksFilters, MagColor + + +def get_N_1d(dim1, dim1_bin_edges=None, dmag=0.2, sig_scale=0.5): + dataset = dim1.reshape(dim1.size, 1) + if dim1_bin_edges is None: + dim1_bin_edges = np.arange(dim1.min(), dim1.max(), dmag) + + bin_lo = dim1_bin_edges[:-1].reshape(dim1_bin_edges[:-1].size, 1) + bin_hi = dim1_bin_edges[1:].reshape(dim1_bin_edges[1:].size, 1) + + sig = jnp.zeros_like(bin_lo) + (dmag * sig_scale) + + N_1d = diffndhist_lomem.tw_ndhist( + dataset, + sig, + bin_lo, + bin_hi, + ) + + return ( + N_1d, + sig, + bin_lo, + bin_hi, + ) + + +def get_N_2d(dim1, dim2, sig_scale=0.5): + dataset = np.vstack((dim1, dim2)).T + + dim1_bin_edges = np.linspace(dim1.min(), dim1.max(), 11) + dim2_bin_edges = np.linspace(dim2.min(), dim2.max(), 11) + + dim1_lo = dim1_bin_edges[:-1] + dim2_lo = dim2_bin_edges[:-1] + bin_lo = np.meshgrid(dim1_lo, dim2_lo, indexing="ij") + bin_lo = np.array(bin_lo).T.reshape(-1, 2) + + dim1_hi = dim1_bin_edges[1:] + dim2_hi = dim2_bin_edges[1:] + bin_hi = np.meshgrid(dim1_hi, dim2_hi, indexing="ij") + bin_hi = np.array(bin_hi).T.reshape(-1, 2) + + sig1 = np.diff(dim1_bin_edges) * sig_scale + sig2 = np.diff(dim2_bin_edges) * sig_scale + sig = np.meshgrid(sig1, sig2, indexing="ij") + sig = np.array(sig).T.reshape(-1, 2) + + N_2d = diffndhist_lomem.tw_ndhist( + dataset, + sig, + bin_lo, + bin_hi, + ) + + return N_2d, sig, bin_lo, bin_hi + + +def filter_name_to_idx(filter_name): + return FeniksFilters._fields.index(filter_name) + + +def get_mag_space(namedtuple_name, mag, filter_name, z_sel, fit=True): + AppMagFuncSpace = namedtuple(namedtuple_name, AppMagFunc._fields) + mag_idx = filter_name_to_idx(filter_name) + N_1d, sig, bin_lo, bin_hi = get_N_1d(mag[z_sel]) + return AppMagFuncSpace(mag_idx, sig, bin_lo, bin_hi, N_1d, fit) + + +def get_colorcolor_space( + namedtuple_name, color1, color2, col_filter_names, z_sel, fit=True +): + ColorColorSpace = namedtuple(namedtuple_name, ColorColor._fields) + + N_2d, sig, bin_lo, bin_hi = get_N_2d(color1[z_sel], color2[z_sel]) + + col_idx = [] + for n in col_filter_names: + col_idx.append(filter_name_to_idx(n)) + + return ColorColorSpace(col_idx, sig, bin_lo, bin_hi, N_2d, fit) + + +def get_color_cond_space_list( + namedtuple_name, + color, + cond_mag, + col_filter_names, + cond_filter_name, + z_sel, + cond_dmag=2, + fit=True, +): + ColorCondSpace = namedtuple(namedtuple_name, ColorCondMag._fields) + + col_idx = [] + for n in col_filter_names: + col_idx.append(filter_name_to_idx(n)) + cond_idx = filter_name_to_idx(cond_filter_name) + + cond_mag_bins = np.arange(cond_mag[z_sel].min(), cond_mag[z_sel].max(), cond_dmag) + + color_cond_list = [] + for b in range(len(cond_mag_bins) - 1): + cond_sel = (cond_mag[z_sel] > cond_mag_bins[b]) & ( + cond_mag[z_sel] <= cond_mag_bins[b + 1] + ) + N_1d, sig, bin_lo, bin_hi = get_N_1d(color[z_sel][cond_sel]) + color_cond_list.append( + ColorCondSpace( + col_idx, + cond_idx, + cond_mag_bins[b], + cond_mag_bins[b + 1], + sig, + bin_lo, + bin_hi, + N_1d, + fit, + ) + ) + + return color_cond_list + + +def get_mag_color_space( + namedtuple_name, mag, color, mag_filter_name, col_filter_names, z_sel, fit=True +): + MagColorSpace = namedtuple(namedtuple_name, MagColor._fields) + + mag_idx = filter_name_to_idx(mag_filter_name) + col_idx = [] + for n in col_filter_names: + col_idx.append(filter_name_to_idx(n)) + + N_2d, sig, bin_lo, bin_hi = get_N_2d(mag[z_sel], color[z_sel]) + + return MagColorSpace(mag_idx, col_idx, sig, bin_lo, bin_hi, N_2d, fit) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index d632a6c4f..6d1b035a6 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -15,10 +15,13 @@ FENIKS_Z_MAX, FENIKS_Z_MIN, Dataset, + FeniksFilters, FilterInfo, ) from ..latin_hypercube import latin_hypercube as lh -from ..utils import load_feniks_tcurve +from ..lightcone_generators import generate_lc_data +from ..utils import add_random_rows, load_feniks_tcurve +from . import N_utils BASE_PATH = Path(__file__).resolve().parent.parent FENIKS_FILTERS_PATH = BASE_PATH / "data" / "feniks_filters" @@ -144,6 +147,13 @@ def get_feniks_data( lh_d_mag=0.6, phot=PHOT, zout=ZOUT, + num_halos_coarse_zbins=250, + num_halos_fine_zbins=150, + lgmp_min=10.0, + lgmp_max=15.0, + lc_sky_area_degsq=100, + n_z_phot_table=30, + add_random_rows_for_testing=False, ): # Transmission curves and filter mag thresholds @@ -152,8 +162,9 @@ def get_feniks_data( HSC_G=False, HSC_R=False, HSC_I=False, + NB0816=False, HSC_Z=False, - # VIDEO_Y=False, + NB0921=False, UDS_J=False, UDS_H=False, UDS_K=True, @@ -168,13 +179,18 @@ def get_feniks_data( phot = ascii.read(drn_path / phot) zout = ascii.read(drn_path / zout) + if add_random_rows_for_testing: + phot = add_random_rows(phot, N=10000) + zout = add_random_rows(zout, N=10000) + # get mags megacam_uS = get_mag_ab(phot, "fcol_MegaCam_uS") hsc_g = get_mag_ab(phot, "fcol_HSC_G") hsc_r = get_mag_ab(phot, "fcol_HSC_R") hsc_i = get_mag_ab(phot, "fcol_HSC_I") + nb816 = get_mag_ab(phot, "fcol_NB0816") hsc_z = get_mag_ab(phot, "fcol_HSC_Z") - # video_Y = get_mag_ab(phot, "fcol_VIDEO_Y") + nb921 = get_mag_ab(phot, "fcol_NB0921") uds_J = get_mag_ab(phot, "fcol_UDS_J") uds_H = get_mag_ab(phot, "fcol_UDS_H") uds_K = get_mag_ab(phot, "fcol_UDS_K") @@ -184,7 +200,9 @@ def get_feniks_data( HSC_G=25.1, HSC_R=25.3, HSC_I=25.1, + NB0816=25.3, HSC_Z=24.9, + NB0921=25.3, UDS_J=24.5, UDS_H=24.3, UDS_K=FENIKS_MAGK_THRESH, @@ -198,8 +216,9 @@ def get_feniks_data( & (hsc_g < feniks_mag_thresh.HSC_G) & (hsc_r < feniks_mag_thresh.HSC_R) & (hsc_i < feniks_mag_thresh.HSC_I) + & (nb816 < feniks_mag_thresh.NB0816) & (hsc_z < feniks_mag_thresh.HSC_Z) - # & (video_Y < feniks_mag_thresh.VIDEO_Y) + & (nb921 < feniks_mag_thresh.NB0921) & (uds_J < feniks_mag_thresh.UDS_J) & (uds_H < feniks_mag_thresh.UDS_H) & (uds_K < feniks_mag_thresh.UDS_K) @@ -214,13 +233,14 @@ def get_feniks_data( hsc_g = hsc_g[mag_thresh] hsc_r = hsc_r[mag_thresh] hsc_i = hsc_i[mag_thresh] + nb816 = nb816[mag_thresh] hsc_z = hsc_z[mag_thresh] - # video_Y = video_Y[mag_thresh] + nb921 = nb921[mag_thresh] uds_J = uds_J[mag_thresh] uds_H = uds_H[mag_thresh] uds_K = uds_K[mag_thresh] - N_obj_pre_cuts = len(zout) + n_gals_pre_cuts = len(zout) # remove mags with bad data in any of the bands clean = ( @@ -228,8 +248,9 @@ def get_feniks_data( & (hsc_g != -99) & (hsc_r != -99) & (hsc_i != -99) + & (nb816 != -99) & (hsc_z != -99) - # & (video_Y != -99) + & (nb921 != -99) & (uds_J != -99) & (uds_H != -99) & (uds_K != -99) @@ -241,39 +262,15 @@ def get_feniks_data( hsc_g = hsc_g[clean] hsc_r = hsc_r[clean] hsc_i = hsc_i[clean] + nb816 = nb816[clean] hsc_z = hsc_z[clean] - # video_Y = video_Y[clean] + nb921 = nb921[clean] uds_J = uds_J[clean] uds_H = uds_H[clean] uds_K = uds_K[clean] - # mask nans - # nans = ( - # (megacam_uS == -99.0) - # | (hsc_g == -99.0) - # | (hsc_r == -99.0) - # | (hsc_i == -99.0) - # | (hsc_z == -99.0) - # | (video_Y == -99) - # | (uds_J == -99.0) - # | (uds_H == -99.0) - # | (uds_K == -99.0) - # ) - - # megacam_uS = megacam_uS[~nans] - # hsc_g = hsc_g[~nans] - # hsc_r = hsc_r[~nans] - # hsc_i = hsc_i[~nans] - # hsc_z = hsc_z[~nans] - # video_Y = video_Y[~nans] - # uds_J = uds_J[~nans] - # uds_H = uds_H[~nans] - # uds_K = uds_K[~nans] - - # zout = zout[~nans] - - N_obj_post_cuts = len(zout) - frac_cat = N_obj_post_cuts / N_obj_pre_cuts + n_gals_post_cuts = len(zout) + frac_cat = n_gals_post_cuts / n_gals_pre_cuts mags = np.vstack( ( @@ -281,8 +278,9 @@ def get_feniks_data( hsc_g, hsc_r, hsc_i, + nb816, hsc_z, - # video_Y, + nb921, uds_J, uds_H, uds_K, @@ -290,15 +288,32 @@ def get_feniks_data( ) ).T + mags_labels = [ + r"$uS_{MegaCam}$", + r"$g_{HSC}$", + r"$r_{HSC}$", + r"$i_{HSC}$", + r"$NB816_{HSC}$", + r"$z_{HSC}$", + r"$NB921_{HSC}$", + r"$J_{UDS}$", + r"$H_{UDS}$", + r"$K_{UDS}$", + r"$redshift$", + ] + # derive colors from mags megacam_hsc_uSg = megacam_uS - hsc_g hsc_gr = hsc_g - hsc_r + hsc_rz = hsc_r - hsc_z hsc_ri = hsc_r - hsc_i + hsc_i816 = hsc_i - nb816 hsc_iz = hsc_i - hsc_z + hsc_z921 = hsc_z - nb921 hsc_uds_zJ = hsc_z - uds_J - # video_uds_YJ = video_Y - uds_J uds_JH = uds_J - uds_H uds_HK = uds_H - uds_K + hsc_uds_rK = hsc_r - uds_K # stack colors_mag dataset = np.vstack( @@ -306,7 +321,9 @@ def get_feniks_data( megacam_hsc_uSg, hsc_gr, hsc_ri, + hsc_i816, hsc_iz, + hsc_z921, hsc_uds_zJ, uds_JH, uds_HK, @@ -320,9 +337,10 @@ def get_feniks_data( r"$uS_{MegaCam} - g_{HSC}$", r"$g_{HSC} - r_{HSC}$", r"$r_{HSC} - i_{HSC}$", + r"$i_{HSC} - NB816_{HSC}$", r"$i_{HSC} - z_{HSC}$", + r"$z_{HSC} - NB921_{HSC}$", r"$z_{HSC} - J_{UDS}$", - # r"$Y_{VIDEO} - J_{UDS}$", r"$J_{UDS} - H_{UDS}$", r"$H_{UDS} - K_{UDS}$", r"$uS_{MegaCam}$", @@ -330,23 +348,607 @@ def get_feniks_data( r"$redshift$", ] - mags_labels = [ - r"$uS_{MegaCam}$", - r"$g_{HSC}$", - r"$r_{HSC}$", - r"$i_{HSC}$", - r"$z_{HSC}$", - # r"$Y_{VIDEO}$", - r"$J_{UDS}$", - r"$H_{UDS}$", - r"$K_{UDS}$", - ] + ############################################################################## + # prepare 2D and 1D color spaces in coarse z-bins for fitting + zbins = np.array( + [ + [0.2, 0.7], + [0.7, 1.5], + # [1.0, 1.5], + [1.5, 2.5], + ] + ) + ############################################################################## + # Z1 spaces: + # 2D (g - r, r - i) + # 2D (K, g - r) + # 2D (K, r - i) + # 2D (K, J - H) + + colors = [] + Z1 = namedtuple( + "Z1", + [ + "z_min", + "z_max", + "lc_data", + "gr_ri", + "ug", + "ri", + "iz", + "jh", + "K_ri", + "K_gr", + "K_JH", + ], + ) + zbin = 0 + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos_coarse_zbins, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data = generate_lc_data(*lc_args) + + z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) + + # 2D (g - r, r - i) + gr_ri = N_utils.get_colorcolor_space( + "Gr_ri", hsc_gr, hsc_ri, ["HSC_G", "HSC_R", "HSC_R", "HSC_I"], z_sel, fit=True + ) + + # 1D (u - g | K) + ug = N_utils.get_color_cond_space_list( + "Ug_condK", + megacam_hsc_uSg, + uds_K, + ["MegaCam_uS", "HSC_G"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, + ) + + # 1D (r − i | K) + ri = N_utils.get_color_cond_space_list( + "Ri_condK", + hsc_ri, + uds_K, + ["HSC_R", "HSC_I"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, + ) + + # 1D (i − z | K) + iz = N_utils.get_color_cond_space_list( + "Iz_condK", + hsc_iz, + uds_K, + ["HSC_I", "HSC_Z"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, + ) + + # 1D (J − H | K) + jh = N_utils.get_color_cond_space_list( + "JH_condK", + uds_JH, + uds_K, + ["UDS_J", "UDS_H"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, + ) + + # 2D (K, r - i) + K_ri = N_utils.get_mag_color_space( + "K_ri", uds_K, hsc_ri, "UDS_K", ["HSC_R", "HSC_I"], z_sel, fit=False + ) + + # 2D (K, g - r) + K_gr = N_utils.get_mag_color_space( + "K_gr", uds_K, hsc_gr, "UDS_K", ["HSC_G", "HSC_R"], z_sel, fit=False + ) + + # 2D (K, J - H) + K_JH = N_utils.get_mag_color_space( + "K_JH", uds_K, uds_JH, "UDS_K", ["UDS_J", "UDS_H"], z_sel, fit=False + ) + + z1 = Z1(z_min, z_max, lc_data, gr_ri, ug, ri, iz, jh, K_ri, K_gr, K_JH) + colors.append(z1) + + ############################################################################## + # Z2a spaces: + # 2D (r - z, z - J) + # 2D (K, u - g) + # 2D (K, r - z) + # 2D (i - NB816, g - r) -- metallicity vs age + # 2D (i - NB816, r - K) -- metallicity vs mass-to-light + # 1D (i - NB816 | K) -- metallicity at fixed mass + # 1D (z - NB921 | K) -- cross-check metallicity at fixed mass + + # Z2a = namedtuple( + # "Z2a", + # [ + # "z_min", + # "z_max", + # "lc_data", + # "rz_zJ", + # "ug", + # "rz", + # "jh", + # "K_ug", + # "K_rz", + # "iNB816_gr", + # "iNB816_rK", + # "iNB816_condK", + # "zNB921_condK", + # ], + # ) + # zbin = 1 + # z_min = zbins[zbin][0] + # z_max = zbins[zbin][1] + + # z_phot_table = 10 ** jnp.linspace( + # jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + # ) + # lc_args = ( + # ran_key, + # num_halos_coarse_zbins, + # z_min, + # z_max, + # lgmp_min, + # lgmp_max, + # lc_sky_area_degsq, + # ssp_data, + # tcurves, + # z_phot_table, + # ) + + # lc_data = generate_lc_data(*lc_args) + + # z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) + + # # 2D (r - z, z - J) + # rz_zJ = N_utils.get_colorcolor_space( + # "Rz_zJ", + # hsc_rz, + # hsc_uds_zJ, + # ["HSC_R", "HSC_Z", "HSC_Z", "UDS_J"], + # z_sel, + # fit=True, + # ) + + # # 1D (u - g | K) + # ug = N_utils.get_color_cond_space_list( + # "Ug_condK", + # megacam_hsc_uSg, + # uds_K, + # ["MegaCam_uS", "HSC_G"], + # "UDS_K", + # z_sel, + # cond_dmag=2, + # fit=True, + # ) + + # # 1D (r - z | K) + # rz = N_utils.get_color_cond_space_list( + # "Rz_condK", + # hsc_rz, + # uds_K, + # ["HSC_R", "HSC_Z"], + # "UDS_K", + # z_sel, + # cond_dmag=2, + # fit=True, + # ) + + # # 1D (J − H | K) + # jh = N_utils.get_color_cond_space_list( + # "JH_condK", + # uds_JH, + # uds_K, + # ["UDS_J", "UDS_H"], + # "UDS_K", + # z_sel, + # cond_dmag=2, + # fit=True, + # ) + + # # 2D (K, u - g) + # K_ug = N_utils.get_mag_color_space( + # "K_ug", + # uds_K, + # megacam_hsc_uSg, + # "UDS_K", + # ["MegaCam_uS", "HSC_G"], + # z_sel, + # fit=False, + # ) + + # # 2D (K, r - z) + # K_rz = N_utils.get_mag_color_space( + # "K_rz", uds_K, hsc_rz, "UDS_K", ["HSC_R", "HSC_Z"], z_sel, fit=False + # ) + + # # 2D (i - NB816, g - r) + # i816_gr = N_utils.get_colorcolor_space( + # "I816_gr", + # hsc_i816, + # hsc_gr, + # ["HSC_I", "NB0816", "HSC_G", "HSC_R"], + # z_sel, + # fit=False, + # ) + + # # 2D (i - NB816, r - K) + # i816_rK = N_utils.get_colorcolor_space( + # "I816_rK", + # hsc_i816, + # hsc_uds_rK, + # ["HSC_I", "NB0816", "HSC_R", "UDS_K"], + # z_sel, + # fit=False, + # ) + + # # 1D (i - NB816 | K) + # i816_condK = N_utils.get_color_cond_space_list( + # "I816_condK", + # hsc_i816, + # uds_K, + # ["HSC_I", "NB0816"], + # "UDS_K", + # z_sel, + # cond_dmag=2, + # fit=False, + # ) + + # # 1D (z - NB921 | K) + # z921_condK = N_utils.get_color_cond_space_list( + # "Z921_condK", + # hsc_z921, + # uds_K, + # ["HSC_Z", "NB0921"], + # "UDS_K", + # z_sel, + # cond_dmag=2, + # fit=False, + # ) - # mask redshift - z_mask = (zout["z_phot"] > FENIKS_Z_MIN) & (zout["z_phot"] <= FENIKS_Z_MAX) - dataset = dataset[z_mask] - mags = mags[z_mask] - zout = zout[z_mask] + # z2a = Z2a( + # z_min, + # z_max, + # lc_data, + # rz_zJ, + # ug, + # rz, + # jh, + # K_ug, + # K_rz, + # i816_gr, + # i816_rK, + # i816_condK, + # z921_condK, + # ) + # colors.append(z2a) + + ############################################################################## + # Z2b spaces: + # 2D (r - z, z - J) + # 2D (K, u - g) + # 2D (K, r - z) + + Z2b = namedtuple( + "Z2b", + ["z_min", "z_max", "lc_data", "rz_zJ", "ug", "rz", "jh", "K_ug", "K_rz"], + ) + zbin = 1 + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos_coarse_zbins, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data = generate_lc_data(*lc_args) + + z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) + + # 2D (r - z, z - J) + rz_zJ = N_utils.get_colorcolor_space( + "Rz_zJ", + hsc_rz, + hsc_uds_zJ, + ["HSC_R", "HSC_Z", "HSC_Z", "UDS_J"], + z_sel, + fit=True, + ) + + # 1D (u - g | K) + ug = N_utils.get_color_cond_space_list( + "Ug_condK", + megacam_hsc_uSg, + uds_K, + ["MegaCam_uS", "HSC_G"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, + ) + + # 1D (r - z | K) + rz = N_utils.get_color_cond_space_list( + "Rz_condK", + hsc_rz, + uds_K, + ["HSC_R", "HSC_Z"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, + ) + + # 1D (J − H | K) + jh = N_utils.get_color_cond_space_list( + "JH_condK", + uds_JH, + uds_K, + ["UDS_J", "UDS_H"], + "UDS_K", + z_sel, + cond_dmag=2, + fit=True, + ) + + # 2D (K, u - g) + K_ug = N_utils.get_mag_color_space( + "K_ug", + uds_K, + megacam_hsc_uSg, + "UDS_K", + ["MegaCam_uS", "HSC_G"], + z_sel, + fit=False, + ) + + # 2D (K, r - z) + K_rz = N_utils.get_mag_color_space( + "K_rz", uds_K, hsc_rz, "UDS_K", ["HSC_R", "HSC_Z"], z_sel, fit=False + ) + + z2b = Z2b(z_min, z_max, lc_data, rz_zJ, ug, rz, jh, K_ug, K_rz) + colors.append(z2b) + + ############################################################################## + # Z3 spaces: + # 2D (z - J, J - H) + # 2D (u - g, g - r) + # 2D (K, u - g) + # 2D (K, g - r) + # 2D (K, J − H): residual quenching scatter at fixed stellar mass + + Z3 = namedtuple( + "Z3", + [ + "z_min", + "z_max", + "lc_data", + "zJ_JH", + "ug_gr", + "ug", + "gr", + "jh", + "K_ug", + "K_gr", + "K_JH", + ], + ) + zbin = 2 + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos_coarse_zbins, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data = generate_lc_data(*lc_args) + + z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) + + # 2D (z - J, J - H) + zJ_JH = N_utils.get_colorcolor_space( + "ZJ_JH", + hsc_uds_zJ, + uds_JH, + ["HSC_Z", "UDS_J", "UDS_J", "UDS_H"], + z_sel, + fit=True, + ) + + # 2D (u - g, g - r) + ug_gr = N_utils.get_colorcolor_space( + "Ug_gr", + megacam_hsc_uSg, + hsc_gr, + ["MegaCam_uS", "HSC_G", "HSC_G", "HSC_R"], + z_sel, + fit=True, + ) + + # 1D (u - g | K) + ug = N_utils.get_color_cond_space_list( + "Ug_condK", + megacam_hsc_uSg, + uds_K, + ["MegaCam_uS", "HSC_G"], + "UDS_K", + z_sel, + cond_dmag=4, + fit=True, + ) + + # 1D (g - r | K) + gr = N_utils.get_color_cond_space_list( + "Gr_condK", + hsc_gr, + uds_K, + ["HSC_G", "HSC_R"], + "UDS_K", + z_sel, + cond_dmag=4, + fit=True, + ) + + # 1D (J − H | K) + jh = N_utils.get_color_cond_space_list( + "JH_condK", + uds_JH, + uds_K, + ["UDS_J", "UDS_H"], + "UDS_K", + z_sel, + cond_dmag=4, + fit=True, + ) + + # 2D (K, u - g) + K_ug = N_utils.get_mag_color_space( + "K_ug", + uds_K, + megacam_hsc_uSg, + "UDS_K", + ["MegaCam_uS", "HSC_G"], + z_sel, + fit=False, + ) + + # 2D (K, g - r) + K_gr = N_utils.get_mag_color_space( + "K_gr", uds_K, hsc_gr, "UDS_K", ["HSC_G", "HSC_R"], z_sel, fit=False + ) + + # 2D (K, J - H) + K_JH = N_utils.get_mag_color_space( + "K_JH", uds_K, uds_JH, "UDS_K", ["UDS_J", "UDS_H"], z_sel, fit=False + ) + + z3 = Z3(z_min, z_max, lc_data, zJ_JH, ug_gr, ug, gr, jh, K_ug, K_gr, K_JH) + colors.append(z3) + + ############################################################################## + # prepare 1D app mag funcs in finer z-bins for fitting + fine_zbins = np.array( + [ + [0.2, 0.5], + [0.5, 0.7], + [0.7, 1.0], + [1.0, 1.5], + [1.5, 2.0], + [2.0, 2.5], + ] + ) + ############################################################################## + AppMagFuncs = namedtuple( + "AppMagFuncs", + ["z_min", "z_max", "lc_data", "u", "g", "r", "i", "z", "J", "H", "K"], + ) + + app_mag_funcs = [] + for zbin in range(0, len(fine_zbins)): + z_min = fine_zbins[zbin][0] + z_max = fine_zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos_fine_zbins, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data = generate_lc_data(*lc_args) + + z_sel = (zout["z_phot"] > z_min) & (zout["z_phot"] <= z_max) + + # 1D (u) + u = N_utils.get_mag_space("U", megacam_uS, "MegaCam_uS", z_sel, fit=True) + + # 1D (g) + g = N_utils.get_mag_space("G", hsc_g, "HSC_G", z_sel, fit=False) + + # 1D (r) + r = N_utils.get_mag_space("R", hsc_r, "HSC_R", z_sel, fit=True) + + # 1D (i) + i = N_utils.get_mag_space("I", hsc_i, "HSC_I", z_sel, fit=False) + + # 1D (z) + z = N_utils.get_mag_space("Z", hsc_z, "HSC_Z", z_sel, fit=False) + + # 1D (J) + j = N_utils.get_mag_space("J", uds_J, "UDS_J", z_sel, fit=False) + + # 1D (H) + h = N_utils.get_mag_space("H", uds_H, "UDS_H", z_sel, fit=False) + + # 1D (K) + k = N_utils.get_mag_space("K", uds_K, "UDS_K", z_sel, fit=True) + + app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, g, r, i, z, j, h, k)) + + ############################################################################## lh_centroids, d_centroids = get_lh_centroids(dataset, lh_d_mag) @@ -367,6 +969,9 @@ def get_feniks_data( dataset_dim_labels, mags, mags_labels, + colors, + app_mag_funcs, + fine_zbins, filter_info, frac_cat, lh_centroids, @@ -378,17 +983,47 @@ def get_feniks_data( ) -FeniksFilters = namedtuple( - "FeniksFilters", - [ - "MegaCam_uS", - "HSC_G", - "HSC_R", - "HSC_I", - "HSC_Z", - # "VIDEO_Y", - "UDS_J", - "UDS_H", - "UDS_K", - ], -) +def get_feniks_fitting_data( + feniks_drn, + ran_key, + ssp_data, + phot=PHOT, + zout=ZOUT, + num_halos_coarse_zbins=250, + num_halos_fine_zbins=150, + add_random_rows_for_testing=False, +): + feniks = get_feniks_data( + feniks_drn, + ran_key, + ssp_data, + phot=phot, + zout=zout, + num_halos_coarse_zbins=num_halos_coarse_zbins, + num_halos_fine_zbins=num_halos_fine_zbins, + add_random_rows_for_testing=add_random_rows_for_testing, + ) + remove = {"dataset_dim_labels", "mags_labels"} + FeniksFitting = namedtuple("Feniks", [f for f in feniks._fields if f not in remove]) + feniks_fitting_data = FeniksFitting( + **{f: getattr(feniks, f) for f in FeniksFitting._fields} + ) + + return feniks_fitting_data + + +# FeniksFilters = namedtuple( +# "FeniksFilters", +# [ +# "MegaCam_uS", +# "HSC_G", +# "HSC_R", +# "HSC_I", +# "NB0816", +# "HSC_Z", +# "NB0921", +# "UDS_J", +# "UDS_H", +# "UDS_K", +# ], +# ) diff --git a/diffhtwo/experimental/data_loaders/load_hizels.py b/diffhtwo/experimental/data_loaders/load_hizels.py index de3c7bfc7..caa624d8d 100644 --- a/diffhtwo/experimental/data_loaders/load_hizels.py +++ b/diffhtwo/experimental/data_loaders/load_hizels.py @@ -18,6 +18,8 @@ "z", "dz", "lc_data", + "n_bins", + "n_gals", ], ) DELTA_L_HALPHA = -0.4 # uncorrect HiZELS h-alpha L for dust (A_halpha = 1 mag) @@ -42,6 +44,8 @@ def get_hizels_data( hizels_halpha_lg_phi_data, hizels_halpha_z_data, hizels_halpha_delta_z_data, + hizels_halpha_n_bins, + hizels_halpha_n_gals, ) = get_hizels_halpha(drn) line_wave_aa = [halpha_wave_aa] @@ -82,7 +86,16 @@ def get_hizels_data( lc_data.append(line_lc_data) return Hizels( - line_wave_aa, lg_Lbin_edges, N_data, vol_Mpc3_data, lg_phi_data, z, dz, lc_data + line_wave_aa, + lg_Lbin_edges, + N_data, + vol_Mpc3_data, + lg_phi_data, + z, + dz, + lc_data, + hizels_halpha_n_bins, + hizels_halpha_n_gals, ) @@ -213,6 +226,13 @@ def get_hizels_halpha(drn): ) ) + hizels_halpha_n_bins = ( + (lg_halpha_Lbin_edges_z0p4.size - 1) + + (lg_halpha_Lbin_edges_z0p84.size - 1) + + (lg_halpha_Lbin_edges_z1p47.size - 1) + + (lg_halpha_Lbin_edges_z2p23.size - 1) + ) + hizels_lg_halpha_Lbin_edges_data = [ lg_halpha_Lbin_edges_z0p4, lg_halpha_Lbin_edges_z0p84, @@ -227,6 +247,13 @@ def get_hizels_halpha(drn): halpha_N_data_z2p23, ] + hizels_halpha_n_gals = ( + (halpha_N_data_z0p4.sum()) + + (halpha_N_data_z0p84.sum()) + + (halpha_N_data_z1p47.sum()) + + (halpha_N_data_z2p23.sum()) + ) + hizels_halpha_vol_Mpc3 = [ halpha_vol_Mpc3_z0p4, halpha_vol_Mpc3_z0p84, @@ -262,4 +289,6 @@ def get_hizels_halpha(drn): hizels_halpha_lg_phi_data, hizels_halpha_z_data, hizels_halpha_delta_z_data, + hizels_halpha_n_bins, + hizels_halpha_n_gals, ) diff --git a/diffhtwo/experimental/data_loaders/load_sdss.py b/diffhtwo/experimental/data_loaders/load_sdss.py index 85db9efca..e7ef64afe 100644 --- a/diffhtwo/experimental/data_loaders/load_sdss.py +++ b/diffhtwo/experimental/data_loaders/load_sdss.py @@ -11,17 +11,23 @@ SDSS_MAGR_THRESH, SDSS_Z_MAX, SDSS_Z_MIN, + AppMagFunc, + ColorColor, + ColorCondMag, Dataset, FilterInfo, + MagColor, ) from ..latin_hypercube import latin_hypercube as lh +from ..lightcone_generators import generate_lc_data +from .N_utils import get_N_1d, get_N_2d Sdss = namedtuple("Sdss", Dataset._fields) LH_N_CENTROIDS = 20_000 LH_SIG = 3.5 LH_D_MAG = 0.1 -LH_D_Z = 0.01 +LH_D_Z = 0.05 def apply_ra_dec_cut(sdss, ra_min=120, ra_max=240, dec_min=0, dec_max=60): @@ -33,13 +39,19 @@ def apply_ra_dec_cut(sdss, ra_min=120, ra_max=240, dec_min=0, dec_max=60): ] -def load_sdss_cuts_applied(drn): +def load_sdss_cuts_applied(drn, sdss_mag_thresh): sdss = sdl.load_sdss_cat(drn) sdss = apply_ra_dec_cut(sdss) # implement r <= 17.6 - mag_thresh_mask = sdss["modelMag_r"] <= SDSS_MAGR_THRESH + mag_thresh_mask = ( + (sdss["modelMag_u"] < sdss_mag_thresh.sdss_u) + & (sdss["modelMag_g"] < sdss_mag_thresh.sdss_g) + & (sdss["modelMag_r"] < sdss_mag_thresh.sdss_r) + & (sdss["modelMag_i"] < sdss_mag_thresh.sdss_i) + & (sdss["modelMag_z"] < sdss_mag_thresh.sdss_z) + ) sdss = sdss[mag_thresh_mask] N_obj_pre_outlier_cut = len(sdss) @@ -107,16 +119,22 @@ def get_sdss_data( drn, ran_key, ssp_data, + num_halos_coarse_zbins=150, + num_halos_fine_zbins=250, + lgmp_min=10.0, + lgmp_max=15.0, + lc_sky_area_degsq=100, + n_z_phot_table=30, ): - sdss, frac_cat = load_sdss_cuts_applied(drn) - sdss_mag_thresh = SdssFilters( - sdss_u=None, - sdss_g=None, + sdss_u=19.7, + sdss_g=18.0, sdss_r=SDSS_MAGR_THRESH, - sdss_i=None, - sdss_z=None, + sdss_i=17.0, + sdss_z=17.0, ) + sdss, frac_cat = load_sdss_cuts_applied(drn, sdss_mag_thresh) + sdss_in_lh = SdssFilters( sdss_u=True, sdss_g=False, @@ -143,6 +161,7 @@ def get_sdss_data( # derive colors from mags sdss_ug = sdss_u - sdss_g sdss_gr = sdss_g - sdss_r + sdss_ur = sdss_u - sdss_r sdss_ri = sdss_r - sdss_i sdss_iz = sdss_i - sdss_z @@ -161,6 +180,207 @@ def get_sdss_data( ] mag_labels = [r"$u$", r"$g$", r"$r$", r"$i$", r"$z$"] + ############################################################################## + # prepare 2D and 1D color spaces in coarse z-bins for fitting + zbins = np.array( + [ + [0.02, 0.1], + [0.1, 0.2], + ] + ) + ############################################################################## + Colors = namedtuple( + "Colors", + ["z_min", "z_max", "lc_data", "ur_ri", "gr_ri", "ur", "ri", "r_ur", "r_ri"], + ) + # 2D (u - r, r - i) + Ur_ri = namedtuple("Ur_ri", ColorColor._fields) + + # 2D (g - r, r - i) + Gr_ri = namedtuple("Gr_ri", ColorColor._fields) + + # 2D (r, u - r) + R_ur = namedtuple("R_ur", MagColor._fields) + + # 2D (r, r - i) + R_ri = namedtuple("R_ri", MagColor._fields) + + # 1D (u - r | r) + Ur_condr = namedtuple("Ug_condK", ColorCondMag._fields) + + # 1D (r - i | r) + Ri_condr = namedtuple("Ri_condr", ColorCondMag._fields) + + colors = [] + for zbin in range(0, len(zbins)): + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos_coarse_zbins, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data = generate_lc_data(*lc_args) + + z_sel = (sdss_redshift > z_min) & (sdss_redshift <= z_max) + + # 2D (u - r, r - i) + N_ur_ri, sig_ur_ri, bin_lo_ur_ri, bin_hi_ur_ri = get_N_2d( + sdss_ur[z_sel], sdss_ri[z_sel] + ) + col_idx = [0, 2, 2, 3] + ur_ri = Ur_ri(col_idx, sig_ur_ri, bin_lo_ur_ri, bin_hi_ur_ri, N_ur_ri, True) + + # 2D (g - r, r - i) + N_gr_ri, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri = get_N_2d( + sdss_gr[z_sel], sdss_ri[z_sel] + ) + col_idx = [1, 2, 2, 3] + gr_ri = Gr_ri(col_idx, sig_gr_ri, bin_lo_gr_ri, bin_hi_gr_ri, N_gr_ri, True) + + # 1D (u - r | r) + rbins = np.arange(sdss_r[z_sel].min(), sdss_r[z_sel].max(), 2) + + col_idx = [0, 2] + cond_idx = 2 + ur = [] + for r in range(len(rbins) - 1): + r_sel = (sdss_r[z_sel] > rbins[r]) & (sdss_r[z_sel] <= rbins[r + 1]) + N_1d_ur, sig_ur, bin_lo_ur, bin_hi_ur = get_N_1d(sdss_ur[z_sel][r_sel]) + ur.append( + Ur_condr( + col_idx, + cond_idx, + rbins[r], + rbins[r + 1], + sig_ur, + bin_lo_ur, + bin_hi_ur, + N_1d_ur, + True, + ) + ) + + # 1D (r - i | r) + col_idx = [2, 3] + cond_idx = 2 + ri = [] + for r in range(len(rbins) - 1): + r_sel = (sdss_r[z_sel] > rbins[r]) & (sdss_r[z_sel] <= rbins[r + 1]) + N_1d_ri, sig_ri, bin_lo_ri, bin_hi_ri = get_N_1d(sdss_ri[z_sel][r_sel]) + ri.append( + Ri_condr( + col_idx, + cond_idx, + rbins[r], + rbins[r + 1], + sig_ri, + bin_lo_ri, + bin_hi_ri, + N_1d_ri, + True, + ) + ) + + # 2D (r, u - r) + N_r_ur, sig_r_ur, bin_lo_r_ur, bin_hi_r_ur = get_N_2d( + sdss_r[z_sel], sdss_ur[z_sel] + ) + mag_idx = 2 + col_idx = [0, 2] + r_ur = R_ur(mag_idx, col_idx, sig_r_ur, bin_lo_r_ur, bin_hi_r_ur, N_r_ur, False) + + # 2D (r, r - i) + N_r_ri, sig_r_ri, bin_lo_r_ri, bin_hi_r_ri = get_N_2d( + sdss_r[z_sel], sdss_ri[z_sel] + ) + mag_idx = 2 + col_idx = [2, 3] + r_ri = R_ri(mag_idx, col_idx, sig_r_ri, bin_lo_r_ri, bin_hi_r_ri, N_r_ri, False) + + colors.append(Colors(z_min, z_max, lc_data, ur_ri, gr_ri, ur, ri, r_ur, r_ri)) + + ############################################################################## + ############################################################################## + # prepare 1D app mag funcs in finer z-bins for fitting + fine_zbins = np.array([[0.02, 0.06], [0.06, 0.1], [0.1, 0.14], [0.14, 0.2]]) + ############################################################################## + AppMagFuncs = namedtuple( + "AppMagFuncs", + ["z_min", "z_max", "lc_data", "u", "g", "r", "i", "z"], + ) + U = namedtuple("U", AppMagFunc._fields) + G = namedtuple("G", AppMagFunc._fields) + R = namedtuple("R", AppMagFunc._fields) + I = namedtuple("I", AppMagFunc._fields) # noqa: E741 + Z = namedtuple("Z", AppMagFunc._fields) + + app_mag_funcs = [] + for zbin in range(0, len(fine_zbins)): + z_min = fine_zbins[zbin][0] + z_max = fine_zbins[zbin][1] + + z_phot_table = 10 ** jnp.linspace( + jnp.log10(z_min), jnp.log10(z_max), n_z_phot_table + ) + lc_args = ( + ran_key, + num_halos_fine_zbins, + z_min, + z_max, + lgmp_min, + lgmp_max, + lc_sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + + lc_data = generate_lc_data(*lc_args) + + z_sel = (sdss_redshift > z_min) & (sdss_redshift <= z_max) + + # 1D (u) + mag_idx_u = 0 + N_1d_u, sig_u, bin_lo_u, bin_hi_u = get_N_1d(sdss_u[z_sel]) + u = U(mag_idx_u, sig_u, bin_lo_u, bin_hi_u, N_1d_u, True) + + # 1D (g) + mag_idx_g = 1 + N_1d_g, sig_g, bin_lo_g, bin_hi_g = get_N_1d(sdss_g[z_sel]) + g = G(mag_idx_g, sig_g, bin_lo_g, bin_hi_g, N_1d_g, True) + + # 1D (r) + mag_idx_r = 2 + N_1d_r, sig_r, bin_lo_r, bin_hi_r = get_N_1d(sdss_r[z_sel]) + r = R(mag_idx_r, sig_r, bin_lo_r, bin_hi_r, N_1d_r, True) + + # 1D (i) + mag_idx_i = 3 + N_1d_i, sig_i, bin_lo_i, bin_hi_i = get_N_1d(sdss_i[z_sel]) + i = I(mag_idx_i, sig_i, bin_lo_i, bin_hi_i, N_1d_i, True) + + # 1D (z) + mag_idx_z = 4 + N_1d_z, sig_z, bin_lo_z, bin_hi_z = get_N_1d(sdss_z[z_sel]) + z = Z(mag_idx_z, sig_z, bin_lo_z, bin_hi_z, N_1d_z, True) + + app_mag_funcs.append(AppMagFuncs(z_min, z_max, lc_data, u, g, r, i, z)) + + ############################################################################## + lh_centroids, d_centroids = get_lh_centroids(dataset) # run initial diffndhist_lomem with fixed dmag @@ -179,6 +399,9 @@ def get_sdss_data( dataset_dim_labels, mags, mag_labels, + colors, + app_mag_funcs, + fine_zbins, filter_info, frac_cat, lh_centroids, diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index 1c83c0634..e1306602b 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -19,7 +19,7 @@ FENIKS_AREA_DEG2 = 2828.247933129912 / 3600 FENIKS_Z_MIN = 0.2 -FENIKS_Z_MAX = 3.0 +FENIKS_Z_MAX = 2.0 FENIKS_MAGK_THRESH = 24.3 # col mag SDSS_AREA_DEG2 = 7199 @@ -35,6 +35,9 @@ "dataset_dim_labels", "mags", "mags_labels", + "colors", + "app_mag_funcs", + "fine_zbins", "filter_info", "frac_cat", "lh_centroids", @@ -45,3 +48,47 @@ "data_sky_area_degsq", ], ) + +ColorColor = namedtuple( + "ColorColor", ["col_idx", "sig", "bin_lo", "bin_hi", "N_data", "fit"] +) + +ColorCondMag = namedtuple( + "ColorCondMag", + [ + "col_idx", + "cond_idx", + "cond_min", + "cond_max", + "sig", + "bin_lo", + "bin_hi", + "N_data", + "fit", + ], +) + +MagColor = namedtuple( + "MagColor", ["mag_idx", "col_idx", "sig", "bin_lo", "bin_hi", "N_data", "fit"] +) + +AppMagFunc = namedtuple( + "AppMagFunc", + ["mag_idx", "sig", "bin_lo", "bin_hi", "N_data", "fit"], +) + +FeniksFilters = namedtuple( + "FeniksFilters", + [ + "MegaCam_uS", + "HSC_G", + "HSC_R", + "HSC_I", + "NB0816", + "HSC_Z", + "NB0921", + "UDS_J", + "UDS_H", + "UDS_K", + ], +) diff --git a/diffhtwo/experimental/diagnostics/plot_burstpop.py b/diffhtwo/experimental/diagnostics/plot_burstpop.py index 795249c4b..0e0354241 100644 --- a/diffhtwo/experimental/diagnostics/plot_burstpop.py +++ b/diffhtwo/experimental/diagnostics/plot_burstpop.py @@ -143,7 +143,7 @@ def plot_lgfburst_mh_z( ] ) fig, ax = plt.subplots(1, 2, figsize=(10, 4), width_ratios=[1, 1.2]) - vmin, vmax = -6, -2.5 + vmin, vmax = -6, -1.5 """Plot fburst w/ halo mass and redshift""" ax[0].hexbin( @@ -160,7 +160,7 @@ def plot_lgfburst_mh_z( ) ax[0].set_xlabel("redshift") - ax[0].set_ylabel("log$_{10}$ (M$_{h, peak}$ [M\u2609])") + ax[0].set_ylabel(r"log$_{10}$ (M$_{h, peak}$ [M${_\odot}$])") ax[0].set_xlim(z_min, z_max) ax[0].set_ylim(10, 15) diff --git a/diffhtwo/experimental/diagnostics/plot_contour.py b/diffhtwo/experimental/diagnostics/plot_contour.py new file mode 100644 index 000000000..f6332c8ea --- /dev/null +++ b/diffhtwo/experimental/diagnostics/plot_contour.py @@ -0,0 +1,214 @@ +import matplotlib.lines as mlines +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import LinearSegmentedColormap +from scipy.ndimage import gaussian_filter + +from ..kernels.N_phot import N_colors_mags + +plt.rc("font", family="serif", serif=["Times New Roman"]) + +# Pantone: Dress Blues → Classic Blue → Aqua Sky → Minty Green → Illuminating +density_cmap = LinearSegmentedColormap.from_list( + "pantone_density", + [ + "#1B2A4A", # Dress Blues — empty/low + "#0F4C81", # Classic Blue + "#00A591", # Arcadia + "#84BD00", # Greenery + "#FEDF00", # Illuminating — peak density + ], +) +dusk = LinearSegmentedColormap.from_list( + "dusk", + [ + "#1B1F3B", # Evening Blue + "#7B4F9E", # Amethyst Orchid + "#E8A598", # Peach Pink + "#F5E6C8", # Almond Milk + ], +) + + +def plot_density( + bin_lo, + bin_hi, + N, + ax, + xlabel, + ylabel, + cmap, + data_label, + N_model=None, + sigma=0.5, + n_levels=10, +): + x_edges = np.unique(np.append(bin_lo[:, 0], bin_hi[-1, 0])) + y_edges = np.unique(np.append(bin_lo[:, 1], bin_hi[-1, 1])) + xc = 0.5 * (x_edges[:-1] + x_edges[1:]) + yc = 0.5 * (y_edges[:-1] + y_edges[1:]) + Z = np.log10( + gaussian_filter( + (N / N.sum()).reshape(len(y_edges) - 1, len(x_edges) - 1).astype(float), + sigma=sigma, + ).clip(min=np.finfo(float).tiny) + ) + Z_min = np.max((-10, Z.min())) + Z_max = Z.max() + levels = np.linspace(Z_min, Z_max, n_levels) + qm = ax.contourf(xc, yc, Z, levels=levels, cmap=cmap, alpha=0.5) + ax.get_figure().colorbar(qm, ax=ax, label=r"$\log_{10}(N / N_{\rm tot})$") + + legend_handles = [mpatches.Patch(color=cmap(0.7), alpha=0.5, label=data_label)] + + if N_model is not None: + Z_model = np.log10( + gaussian_filter( + (N_model / N_model.sum()) + .reshape(len(y_edges) - 1, len(x_edges) - 1) + .astype(float), + sigma=sigma, + ).clip(min=np.finfo(float).tiny) + ) + ax.contour( + xc, + yc, + Z_model, + levels=levels, + cmap=cmap, + linewidths=1.5, + alpha=0.9, + linestyles="dashed", + ) + legend_handles.append( + mlines.Line2D( + [], + [], + color=cmap(0.7), + linewidth=1.5, + linestyle="dashed", + alpha=0.9, + label="diffsky", + ) + ) + + ax.legend( + handles=legend_handles, + loc="upper center", + bbox_to_anchor=(0.5, 1.1), + ncol=len(legend_handles), + frameon=False, + fontsize=12, + ) + ax.set_xlabel(xlabel, fontsize=14) + ax.set_ylabel(ylabel, fontsize=14) + + +def plot_density_raw(bin_lo, bin_hi, N, ax, xlabel, ylabel, cmap, N_model=None): + x_edges = np.unique(np.append(bin_lo[:, 0], bin_hi[-1, 0])) + y_edges = np.unique(np.append(bin_lo[:, 1], bin_hi[-1, 1])) + xc = 0.5 * (x_edges[:-1] + x_edges[1:]) + yc = 0.5 * (y_edges[:-1] + y_edges[1:]) + Z = np.log10( + (N / N.sum()) + .reshape(len(y_edges) - 1, len(x_edges) - 1) + .astype(float) + .clip(min=np.finfo(float).tiny) + ) + qm = ax.pcolormesh(x_edges, y_edges, Z, cmap=cmap) + ax.get_figure().colorbar(qm, ax=ax, label=r"$\log_{10}(N / N_{\rm tot})$") + if N_model is not None: + Z_model = np.log10( + (N_model / N_model.sum()) + .reshape(len(y_edges) - 1, len(x_edges) - 1) + .astype(float) + .clip(min=np.finfo(float).tiny) + ) + levels = np.linspace(Z.min(), Z.max(), 8) + ax.contour(xc, yc, Z_model, levels=levels, cmap=cmap, linewidths=0.8, alpha=0.9) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + + +def plot_color_contours( + ran_key, + param_collection, + data, + mag_thresh, + frac_cat, + data_label, + savedir, + sigma=0.5, + n_levels=10, +): + for z in range(0, len(data)): + z_data = data[z] + + z_data_model = N_colors_mags( + ran_key, + param_collection, + z_data, + mag_thresh, + frac_cat, + ) + fields = z_data_model._fields[3:] + z_min = z_data_model.z_min + z_max = z_data_model.z_max + + for f in range(0, len(fields)): + space = getattr(z_data_model, fields[f]) + + if isinstance(space, list): + pass + + else: + fig, ax = plt.subplots(figsize=(6.4, 5.2), constrained_layout=True) + fig.suptitle(str(z_min) + " < z < " + str(z_max), fontsize=14, y=0.98) + fig.get_layout_engine().set(h_pad=0.0, hspace=0.0, rect=(0, 0, 1, 0.95)) + + name = type(space).__name__ + xlabel, ylabel = parse_color_labels(name) + plot_density( + space.bin_lo, + space.bin_hi, + space.N_data, + ax, + xlabel, + ylabel, + dusk, + data_label, + N_model=space.N_model, + sigma=sigma, + n_levels=n_levels, + ) + fig.savefig( + savedir + + "/" + + data_label + + "_" + + name + + "_" + + str(z_min) + + "-" + + str(z_max) + + ".png", + dpi=300, + ) + plt.close() + + +def parse_axis_label(s): + nir_bands = {"j", "h", "k"} + + def fmt(b): + return b.upper() if b in nir_bands else b + + if len(s) == 2: + return f"${fmt(s[0])}-{fmt(s[1])}$" + return f"${fmt(s)}$" + + +def parse_color_labels(name): + x_str, y_str = name.lower().split("_") + return parse_axis_label(x_str), parse_axis_label(y_str) diff --git a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py b/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py deleted file mode 100644 index af5c2fec6..000000000 --- a/diffhtwo/experimental/diagnostics/plot_mag_color_1d_hist.py +++ /dev/null @@ -1,1027 +0,0 @@ -import corner -import jax.numpy as jnp -import numpy as np - -# from diffsky import diffndhist -from diffsky.experimental import lc_phot_kern -from diffsky.experimental import mc_lightcone_halos as mclh -from diffsky.experimental import precompute_ssp_phot as psspp -from diffstar.defaults import FB, T_TABLE_MIN -from dsps.cosmology import flat_wcdm -from dsps.cosmology.defaults import DEFAULT_COSMOLOGY -from jax import random as jran -from lc_utils import zbin_area, zbin_volume - -blue = "#1E90FF" # DodgerBlue -orange = "#FF8C00" # DarkOrange -# blue = "#4169E1" # RoyalBlue -# orange = "#D2691E" # Chocolate -# blue = "#00BFFF" # DeepSkyBlue -# orange = "#FFA500" # Orange - -color1 = orange -color2 = "k" -color_data = blue - -alpha1 = 1.0 -alpha2 = 0.7 -alpha_data = 0.5 - - -lw = 1.5 -fontsize = 24 -labelsize = 20 -legend_fontsize = 30 - - -try: - import matplotlib as mpl - from matplotlib import pyplot as plt - from matplotlib.lines import Line2D - - plt.rc("font", family="serif", serif=["Times New Roman"]) - - HAS_MATPLOTLIB = True -except ImportError: - HAS_MATPLOTLIB = False - - -mpl.rcParams["axes.linewidth"] = 2 - - -def plot_n_mag( - diffstarpop_params1, - spspop_params1, - ssp_err_pop_params1, - label1, - tcurves, - mag_thresh_column, - mag_thresh, - frac_cat, - dimension_labels, - ran_key, - zmins, - zmaxs, - ssp_data, - mzr_params, - scatter_params, - suptitle, - zbin_titles, - savedir, - dataset_mags=None, - n_bands=None, - data_sky_area_degsq=None, - diffstarpop_params2=None, - spspop_params2=None, - ssp_err_pop_params2=None, - label2=None, - dmag=0.1, - lgmp_min=10.0, - lgmp_max=15.0, - lc_vol_mpc3=7e4, - cosmo_params=DEFAULT_COSMOLOGY, - fb=FB, -): - # Plot 1D histograms - n_bands = len(tcurves) - n_zbins = len(zmins) - - fig_width = 3.0 * n_bands - fig_height = 3.0 * n_zbins - fig, ax = plt.subplots( - n_zbins, - n_bands, - figsize=(fig_width, fig_height), - ) - - fig.subplots_adjust( - left=0.065, hspace=0, top=0.95, right=0.99, bottom=0.05, wspace=0.0 - ) - fig.suptitle(suptitle, fontsize=32) - - fig_offset, ax_offset = plt.subplots( - n_zbins, - n_bands, - figsize=(fig_width, fig_height), - ) - - fig_offset.subplots_adjust( - left=0.065, hspace=0, top=0.95, right=0.99, bottom=0.05, wspace=0.0 - ) - fig_offset.suptitle(suptitle, fontsize=32) - - dataset_mags_z1 = np.array(dataset_mags[0]) - - for z in range(0, n_zbins): - zmin = zmins[z] - zmax = zmaxs[z] - dataset_mags_z = np.array(dataset_mags[z]) - - t = int(n_bands / 2) - ax[z, t].set_title(zbin_titles[z], y=0.85, fontsize=labelsize) - - """mc lightcone""" - ran_key, lc_key = jran.split(ran_key, 2) - sky_area_degsq = zbin_area(lc_vol_mpc3, zlow=zmin, zhigh=zmax).value - lc_args = (lc_key, lgmp_min, zmin, zmax, sky_area_degsq) - lc_halopop = mclh.mc_lightcone_host_halo_diffmah( - *lc_args, cosmo_params=cosmo_params, lgmp_max=lgmp_max - ) - - if data_sky_area_degsq is not None: - data_vol_mpc3 = zbin_volume( - data_sky_area_degsq, zlow=zmin, zhigh=zmax - ).value - - n_z_phot_table = 33 - - if (zmin < 0.24) & (zmax > 0.24): - nb_z = jnp.array([0.2445706, 0.40185568]) - nb816_zspan = np.linspace(nb_z[0] - 0.02, nb_z[0] + 0.02, 11) - nb921_zspan = np.linspace(nb_z[1] - 0.02, nb_z[1] + 0.02, 11) - z1_zspan = np.linspace(0.2, 0.5, 11) - z_phot_table = np.concatenate((nb816_zspan, nb921_zspan, z1_zspan)) - z_phot_table.sort() - else: - z_phot_table = jnp.linspace(zmin, zmax, n_z_phot_table) - - t_0 = flat_wcdm.age_at_z0(*DEFAULT_COSMOLOGY) - lgt0 = jnp.log10(t_0) - t_table = jnp.linspace(T_TABLE_MIN, 10**lgt0, 100) - - precomputed_ssp_mag_table = psspp.get_precompute_ssp_mag_redshift_table( - tcurves, ssp_data, z_phot_table, DEFAULT_COSMOLOGY - ) - - wave_eff_table = lc_phot_kern.get_wave_eff_table(z_phot_table, tcurves) - - ran_key, phot_key1 = jran.split(ran_key, 2) - phot_args1 = ( - phot_key1, - lc_halopop["z_obs"], - lc_halopop["t_obs"], - lc_halopop["mah_params"], - lc_halopop["logmp0"], - t_table, - ssp_data, - precomputed_ssp_mag_table, - z_phot_table, - wave_eff_table, - diffstarpop_params1, - mzr_params, - spspop_params1, - scatter_params, - ssp_err_pop_params1, - cosmo_params, - fb, - ) - - lc_phot1 = lc_phot_kern.multiband_lc_phot_kern(*phot_args1) - if n_bands is None: - num_halos, n_bands = lc_phot1.obs_mags_q.shape - - # set weights=0 for mag > mag_thresh for the band indicated by mag_thresh_column - obs_mag_q1 = lc_phot1.obs_mags_q[:, mag_thresh_column] - obs_mag_smooth_ms1 = lc_phot1.obs_mags_smooth_ms[:, mag_thresh_column] - obs_mag_bursty_ms1 = lc_phot1.obs_mags_bursty_ms[:, mag_thresh_column] - - lc_phot_weights_q1 = jnp.where( - obs_mag_q1 < mag_thresh, - lc_phot1.weights_q, - jnp.zeros_like(lc_phot1.weights_q), - ) - lc_phot_weights_smooth_ms1 = jnp.where( - obs_mag_smooth_ms1 < mag_thresh, - lc_phot1.weights_smooth_ms, - jnp.zeros_like(lc_phot1.weights_smooth_ms), - ) - lc_phot_weights_bursty_ms1 = jnp.where( - obs_mag_bursty_ms1 < mag_thresh, - lc_phot1.weights_bursty_ms, - jnp.zeros_like(lc_phot1.weights_bursty_ms), - ) - N_weights1 = np.concatenate( - [ - lc_phot_weights_q1 * frac_cat, - lc_phot_weights_smooth_ms1 * frac_cat, - lc_phot_weights_bursty_ms1 * frac_cat, - ] - ) - - if diffstarpop_params2 is not None: - ran_key, phot_key2 = jran.split(ran_key, 2) - phot_args2 = ( - phot_key2, - lc_halopop["z_obs"], - lc_halopop["t_obs"], - lc_halopop["mah_params"], - lc_halopop["logmp0"], - t_table, - ssp_data, - precomputed_ssp_mag_table, - z_phot_table, - wave_eff_table, - diffstarpop_params2, - mzr_params, - spspop_params2, - scatter_params, - ssp_err_pop_params2, - cosmo_params, - fb, - ) - - lc_phot2 = lc_phot_kern.multiband_lc_phot_kern(*phot_args2) - - # set weights=0 for mag > mag_thresh for the band indicated by mag_thresh_column - obs_mag_q2 = lc_phot2.obs_mags_q[:, mag_thresh_column] - obs_mag_smooth_ms2 = lc_phot2.obs_mags_smooth_ms[:, mag_thresh_column] - obs_mag_bursty_ms2 = lc_phot2.obs_mags_bursty_ms[:, mag_thresh_column] - - lc_phot_weights_q2 = jnp.where( - obs_mag_q2 < mag_thresh, - lc_phot2.weights_q, - jnp.zeros_like(lc_phot2.weights_q), - ) - lc_phot_weights_smooth_ms2 = jnp.where( - obs_mag_smooth_ms2 < mag_thresh, - lc_phot2.weights_smooth_ms, - jnp.zeros_like(lc_phot2.weights_smooth_ms), - ) - lc_phot_weights_bursty_ms2 = jnp.where( - obs_mag_bursty_ms2 < mag_thresh, - lc_phot2.weights_bursty_ms, - jnp.zeros_like(lc_phot2.weights_bursty_ms), - ) - N_weights2 = np.concatenate( - [ - lc_phot_weights_q2 * frac_cat, - lc_phot_weights_smooth_ms2 * frac_cat, - lc_phot_weights_bursty_ms2 * frac_cat, - ] - ) - - for i in range(0, n_bands): - sigma = np.std(dataset_mags_z1[:, i]) - lower_limit = np.mean(dataset_mags_z1[:, i]) - (4 * sigma) - upper_limit = np.mean(dataset_mags_z1[:, i]) + (4 * sigma) - if i == n_bands - 1: - upper_limit = mag_thresh - mag_bin_edges = np.arange( - lower_limit, - upper_limit, - dmag, - ) - ax[z, i].set_xlim(lower_limit, upper_limit) - ax_offset[z, i].set_xlim(lower_limit, upper_limit) - - # model 1 - lc_phot1_obs_mags = np.concatenate( - [ - lc_phot1.obs_mags_q[:, i], - lc_phot1.obs_mags_smooth_ms[:, i], - lc_phot1.obs_mags_bursty_ms[:, i], - ] - ) - lc_phot1_hist = ax[z, i].hist( - lc_phot1_obs_mags, - weights=N_weights1 * (1 / lc_vol_mpc3), - bins=mag_bin_edges, - histtype="step", - color=color1, - alpha=alpha1, - label=label1, - lw=lw + 1, - ) - - # model 2 - if diffstarpop_params2 is not None: - lc_phot2_obs_mags = np.concatenate( - [ - lc_phot2.obs_mags_q[:, i], - lc_phot2.obs_mags_smooth_ms[:, i], - lc_phot2.obs_mags_bursty_ms[:, i], - ] - ) - - ax[z, i].hist( - lc_phot2_obs_mags, - weights=N_weights2 * (1 / lc_vol_mpc3), - bins=mag_bin_edges, - histtype="step", - color=color2, - alpha=alpha2, - lw=lw, - label=label2, - ) - - # data - data_hist = ax[z, i].hist( - dataset_mags_z[:, i], - weights=np.ones_like(dataset_mags_z[:, i]) * (1 / data_vol_mpc3), - bins=mag_bin_edges, - color=color_data, - lw=lw, - alpha=alpha_data, - label="FENIKS-UDS", - ) - - """ax_offset""" - mag_bin_centers = (mag_bin_edges[1:] + mag_bin_edges[:-1]) / 2 - offset = data_hist[0] / lc_phot1_hist[0] - ax_offset[z, i].plot(mag_bin_centers, offset, lw=2, color="k") - ax_offset[z, i].set_ylim(0.09, 10.1) - ax_offset[z, i].set_yticks([0.1, 0.2, 0.5, 1, 2, 5, 10]) - ax_offset[z, i].set_yscale("log") - ax_offset[z, i].set_xlabel(dimension_labels[i], fontsize=fontsize) - ax_offset[z, i].tick_params( - axis="both", direction="in", labelsize=labelsize - ) - - ax[z, i].set_yscale("log") - ax[z, i].set_xlabel(dimension_labels[i], fontsize=fontsize) - ax[z, i].set_ylim(1e-6, 5e-3) - ax[z, i].tick_params(axis="both", direction="in", labelsize=labelsize) - - ax_offset_yticks = np.array([0.2, 0.5, 1, 2, 5]) - ax_offset[z, i].set_yticks(ax_offset_yticks) - ax_offset[z, i].axhspan( - ax_offset_yticks[1], ax_offset_yticks[3], color="orange", alpha=0.5 - ) - ax_offset[z, i].axhspan( - ax_offset_yticks[0], ax_offset_yticks[1], color="r", alpha=0.5 - ) - ax_offset[z, i].axhspan( - ax_offset_yticks[3], ax_offset_yticks[4], color="r", alpha=0.5 - ) - ax_offset[z, i].axhspan(0, ax_offset_yticks[0], color="r", alpha=0.8) - ax_offset[z, i].axhspan(ax_offset_yticks[4], 10, color="r", alpha=0.8) - - if i != 0: - ax[z, i].set_yticklabels([]) - ax_offset[z, i].set_yticklabels([]) - if z != n_zbins - 1: - ax[z, i].set_xticklabels([]) - ax_offset[z, i].set_xticklabels([]) - if i == 0: - ax_offset[z, i].set_yticklabels(["5x", "2x", "1x", "2x", "5x"]) - - ax[0, -1].legend( - framealpha=0.5, - loc="upper left", - bbox_to_anchor=(-2, 1.4), - ncols=3, - fontsize=legend_fontsize, - ) - fig.supylabel("\u03d5 [Mpc$^{-3}$]", fontsize=fontsize) - fig.savefig(savedir + "/mags_" + savedir.split("/")[-1] + ".pdf") - - fig_offset.supylabel("n$_{FENIKS}$ / n$_{diffsky}$", fontsize=fontsize) - fig_offset.savefig(savedir + "/mags_offsets_" + savedir.split("/")[-1] + ".pdf") - - plt.show() - - -def plot_n( - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - label1, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - dimension_labels, - ran_key, - zmins, - zmaxs, - ssp_data, - mzr_params, - scatter_params, - suptitle, - zbin_titles, - savedir, - dataset_colors_mag=None, - data_sky_area_degsq=None, - diffstarpop_params2=None, - spspop_params2=None, - ssperrpop_params2=None, - label2=None, - lh_centroids=None, - lg_n_data_err_lh=None, - lg_n_thresh=None, - dmag=0.1, - lgmp_min=10.0, - lgmp_max=15.0, - lc_vol_mpc3=7e4, - cosmo_params=DEFAULT_COSMOLOGY, - fb=FB, - n_z_phot_table=15, -): - # Plot 1D histograms - n_bands = len(tcurves) - n_dims = n_bands - 1 + len(mag_columns) - n_zbins = len(zmins) - - fig_width = 3.00 * n_dims - fig_height = 3.25 * n_zbins - fig, ax = plt.subplots( - n_zbins, - n_dims, - figsize=(fig_width, fig_height), - ) - - fig.subplots_adjust( - left=0.065, hspace=0, top=0.95, right=0.99, bottom=0.05, wspace=0.0 - ) - fig.suptitle(suptitle, fontsize=32) - - fig_offset, ax_offset = plt.subplots( - n_zbins, - n_dims, - figsize=(fig_width, fig_height), - ) - - fig_offset.subplots_adjust( - left=0.065, hspace=0, top=0.95, right=0.99, bottom=0.05, wspace=0.0 - ) - fig_offset.suptitle(suptitle, fontsize=32) - - dataset_colors_mag_z1 = np.array(dataset_colors_mag[0]) - for z in range(0, n_zbins): - zmin = zmins[z] - zmax = zmaxs[z] - dataset_colors_mag_z = np.array(dataset_colors_mag[z]) - - t = int(n_dims / 2) - ax[z, t].set_title(zbin_titles[z], y=0.85, fontsize=labelsize) - - if data_sky_area_degsq is not None: - data_vol_mpc3 = zbin_volume( - data_sky_area_degsq, zlow=zmin, zhigh=zmax - ).value - - if diffstarpop_params2 is not None: - ( - obs_colors_mag1, - N_weights1, - obs_colors_mag2, - N_weights2, - ) = get_model_colors_mag( - ran_key, - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - lc_vol_mpc3, - zmin, - zmax, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - ssp_data, - lgmp_min, - lgmp_max, - mzr_params, - scatter_params, - cosmo_params, - fb, - diffstarpop_params2=diffstarpop_params2, - spspop_params2=spspop_params2, - ssperrpop_params2=ssperrpop_params2, - ) - - else: - obs_colors_mag1, N_weights1 = get_model_colors_mag( - ran_key, - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - lc_vol_mpc3, - zmin, - zmax, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - ssp_data, - lgmp_min, - lgmp_max, - mzr_params, - scatter_params, - cosmo_params, - fb, - ) - - for i in range(0, n_dims): - sigma = np.std(dataset_colors_mag_z1[:, i]) - lower_limit = np.mean(dataset_colors_mag_z1[:, i]) - (4 * sigma) - upper_limit = np.mean(dataset_colors_mag_z1[:, i]) + (4 * sigma) - if i == n_dims - 1: - upper_limit = mag_thresh - bins = np.arange( - lower_limit, - upper_limit, - dmag, - ) - ax[z, i].set_xlim(lower_limit, upper_limit) - ax_offset[z, i].set_xlim(lower_limit, upper_limit) - - obs_colors_mag1_hist = ax[z, i].hist( - obs_colors_mag1[:, i], - weights=N_weights1 * (1 / lc_vol_mpc3), - bins=bins, - histtype="step", - color=color1, - alpha=alpha1, - lw=lw + 2, - label=label1, - ) - if diffstarpop_params2 is not None: - ax[z, i].hist( - obs_colors_mag2[:, i], - weights=N_weights2 * (1 / lc_vol_mpc3), - bins=bins, - histtype="step", - color=color2, - alpha=alpha2, - lw=lw, - label=label2, - ) - - # data - if dataset_colors_mag_z is not None: - dataset_colors_mag_hist = ax[z, i].hist( - dataset_colors_mag_z[:, i], - weights=np.ones_like(dataset_colors_mag_z[:, i]) - * (1 / data_vol_mpc3), - bins=bins, - color=color_data, - alpha=alpha_data, - lw=lw, - label="FENIKS-UDS", - ) - """ax_offset""" - bin_centers = (bins[1:] + bins[:-1]) / 2 - offset = dataset_colors_mag_hist[0] / obs_colors_mag1_hist[0] - ax_offset[z, i].plot(bin_centers, offset, lw=2, color="k") - ax_offset[z, i].set_ylim(0.09, 10.1) - ax_offset[z, i].set_yticks([0.1, 0.2, 0.5, 1, 2, 5, 10]) - ax_offset[z, i].set_yscale("log") - ax_offset[z, i].set_xlabel(dimension_labels[i], fontsize=fontsize) - ax_offset[z, i].tick_params( - axis="both", direction="in", labelsize=labelsize - ) - - ax[z, i].set_yscale("log") - ax[z, i].set_xlabel(dimension_labels[i], fontsize=fontsize) - ax[z, i].set_ylim(1e-6, 3e-2) - ax[z, i].tick_params(axis="both", direction="in", labelsize=labelsize) - - ax_offset_yticks = np.array([0.2, 0.5, 1, 2, 5]) - ax_offset[z, i].set_yticks(ax_offset_yticks) - ax_offset[z, i].axhspan( - ax_offset_yticks[1], ax_offset_yticks[3], color="orange", alpha=0.5 - ) - ax_offset[z, i].axhspan( - ax_offset_yticks[0], ax_offset_yticks[1], color="r", alpha=0.5 - ) - ax_offset[z, i].axhspan( - ax_offset_yticks[3], ax_offset_yticks[4], color="r", alpha=0.5 - ) - ax_offset[z, i].axhspan(0, ax_offset_yticks[0], color="r", alpha=0.8) - ax_offset[z, i].axhspan(ax_offset_yticks[4], 10, color="r", alpha=0.8) - - if i != 0: - ax[z, i].set_yticklabels([]) - ax_offset[z, i].set_yticklabels([]) - if z != n_zbins - 1: - ax[z, i].set_xticklabels([]) - ax_offset[z, i].set_xticklabels([]) - if i == 0: - ax_offset[z, i].set_yticklabels(["5x", "2x", "1x", "2x", "5x"]) - - ax[0, -1].legend( - framealpha=0.5, - loc="upper left", - bbox_to_anchor=(-2, 1.4), - ncols=3, - fontsize=legend_fontsize, - ) - - fig.supylabel("\u03d5 [Mpc$^{-3}$]", fontsize=fontsize) - fig.savefig(savedir + "/phot_fit_" + savedir.split("/")[-1] + ".pdf") - - fig_offset.supylabel("n$_{FENIKS}$ / n$_{diffsky}$", fontsize=fontsize) - fig_offset.savefig(savedir + "/phot_offsets_" + savedir.split("/")[-1] + ".pdf") - - plt.show() - - -def plot_n_corner( - ran_key, - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - label1, - zmin, - zmax, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - dataset_colors_mag, - dimension_labels, - title, - savedir, - ssp_data, - mzr_params, - scatter_params, - dmag=0.1, - cosmo_params=DEFAULT_COSMOLOGY, - fb=FB, - lgmp_min=10.0, - lgmp_max=15.0, - lc_vol_mpc3=7e4, - diffstarpop_params2=None, - spspop_params2=None, - ssperrpop_params2=None, - label2=None, -): - if diffstarpop_params2 is not None: - ( - obs_colors_mag1, - N_weights1, - obs_colors_mag2, - N_weights2, - ) = get_model_colors_mag( - ran_key, - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - lc_vol_mpc3, - zmin, - zmax, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - ssp_data, - lgmp_min, - lgmp_max, - mzr_params, - scatter_params, - cosmo_params, - fb, - diffstarpop_params2=diffstarpop_params2, - spspop_params2=spspop_params2, - ssperrpop_params2=ssperrpop_params2, - ) - - else: - obs_colors_mag1, N_weights1 = get_model_colors_mag( - ran_key, - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - lc_vol_mpc3, - zmin, - zmax, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - ssp_data, - lgmp_min, - lgmp_max, - mzr_params, - scatter_params, - cosmo_params, - fb, - ) - color_bin_edges = np.arange(-0.5 - dmag / 2, 2.2, dmag) - mag_bin_edges = np.arange(18.0 - dmag / 2, mag_thresh, dmag) - ranges = [(color_bin_edges[0], color_bin_edges[-1])] * ( - len(dimension_labels) - len(mag_columns) - ) - for m in range(0, len(mag_columns)): - ranges.append((mag_bin_edges[0], mag_bin_edges[-1])) - - # data - fig_corner = corner.corner( - dataset_colors_mag, - # weights=dataset_colors_mag, - color=color_data, - labels=dimension_labels, - label_kwargs={"fontsize": 20}, - plot_datapoints=False, - smooth=1.0, - levels=[0.68, 0.95], - hist_kwargs={ - "histtype": "stepfilled", - "alpha": 0.5, - "lw": lw, - "density": True, - }, - fill_contours=False, - plot_density=False, - contour_kwargs={"linewidths": 3.5, "alpha": 0.75}, - range=ranges, - ) - - fig_corner.suptitle(title, fontsize=fontsize + 4) - - # model 1 - corner.corner( - obs_colors_mag1, - weights=N_weights1, - fig=fig_corner, - color=color1, - smooth=1.0, - plot_datapoints=False, - levels=[0.68, 0.95], - hist_kwargs={ - "histtype": "step", - "alpha": 0.5, - "lw": lw + 2, - "density": True, - }, - fill_contours=False, - plot_density=False, - contour_kwargs={"linewidths": 3.5, "alpha": 0.75}, - range=ranges, - ) - - # model 2 - if diffstarpop_params2 is not None: - corner.corner( - obs_colors_mag2, - weights=N_weights2, - fig=fig_corner, - color=color2, - smooth=1.0, - plot_datapoints=False, - levels=[0.68, 0.95], - hist_kwargs={ - "histtype": "step", - "alpha": 0.5, - "lw": lw + 2, - "density": True, - }, - plot_density=False, - fill_contours=False, - contour_kwargs={"linewidths": 3.5, "alpha": 0.75}, - range=ranges, - ) - - if label2 is not None: - handles = [ - Line2D([], [], color=color1, lw=lw + 1, label=label1), - Line2D([], [], color=color2, lw=lw + 1, label=label2), - Line2D([], [], color=color_data, lw=lw, label="FENIKS-UDS"), - ] - else: - handles = [ - Line2D([], [], color=color1, lw=lw + 1, label=label1), - Line2D([], [], color=color_data, lw=lw, label="FENIKS-UDS"), - ] - - fig_corner.axes[0].legend( - handles=handles, - loc="center left", - bbox_to_anchor=(1.0, 0.5), - frameon=False, - fontsize=fontsize, - ) - - for ax in fig_corner.get_axes(): - ax.tick_params(axis="both", direction="in", labelsize=labelsize / 2) - fig_corner.savefig( - savedir - + "/z" - + str(zmin) - + "-" - + str(zmax) - + "_corner_fit_" - + savedir.split("/")[-1] - + ".pdf" - ) - plt.show() - - -def get_model_colors_mag( - ran_key, - diffstarpop_params1, - spspop_params1, - ssperrpop_params1, - lc_vol_mpc3, - zmin, - zmax, - tcurves, - mag_columns, - mag_thresh_column, - mag_thresh, - frac_cat, - ssp_data, - lgmp_min, - lgmp_max, - mzr_params, - scatter_params, - cosmo_params, - fb, - n_z_phot_table=15, - data_sky_area_degsq=None, - diffstarpop_params2=None, - spspop_params2=None, - ssperrpop_params2=None, -): - """mc lightcone""" - ran_key, lc_key = jran.split(ran_key, 2) - sky_area_degsq = zbin_area(lc_vol_mpc3, zlow=zmin, zhigh=zmax).value - lc_args = (lc_key, lgmp_min, zmin, zmax, sky_area_degsq) - lc_halopop = mclh.mc_lightcone_host_halo_diffmah( - *lc_args, cosmo_params=cosmo_params, lgmp_max=lgmp_max - ) - - z_phot_table = jnp.linspace(zmin, zmax, n_z_phot_table) - t_0 = flat_wcdm.age_at_z0(*cosmo_params) - lgt0 = jnp.log10(t_0) - t_table = jnp.linspace(T_TABLE_MIN, 10**lgt0, 100) - - precomputed_ssp_mag_table = psspp.get_precompute_ssp_mag_redshift_table( - tcurves, ssp_data, z_phot_table, cosmo_params - ) - - wave_eff_table = lc_phot_kern.get_wave_eff_table(z_phot_table, tcurves) - - ran_key, phot_key1 = jran.split(ran_key, 2) - phot_args1 = ( - phot_key1, - lc_halopop["z_obs"], - lc_halopop["t_obs"], - lc_halopop["mah_params"], - lc_halopop["logmp0"], - t_table, - ssp_data, - precomputed_ssp_mag_table, - z_phot_table, - wave_eff_table, - diffstarpop_params1, - mzr_params, - spspop_params1, - scatter_params, - ssperrpop_params1, - cosmo_params, - fb, - ) - - lc_phot1 = lc_phot_kern.multiband_lc_phot_kern(*phot_args1) - num_halos, n_bands = lc_phot1.obs_mags_q.shape - - ( - obs_colors_mag_q1, - obs_colors_mag_smooth_ms1, - obs_colors_mag_bursty_ms1, - ) = get_obs_colors_mag(lc_phot1, mag_columns) - obs_colors_mag1 = np.concatenate( - [obs_colors_mag_q1, obs_colors_mag_smooth_ms1, obs_colors_mag_bursty_ms1] - ) - - # set weights=0 for mag > mag_thresh for the band indicated by mag_thresh_column - obs_mag_q1 = lc_phot1.obs_mags_q[:, mag_thresh_column] - obs_mag_smooth_ms1 = lc_phot1.obs_mags_smooth_ms[:, mag_thresh_column] - obs_mag_bursty_ms1 = lc_phot1.obs_mags_bursty_ms[:, mag_thresh_column] - - lc_phot_weights_q1 = jnp.where( - obs_mag_q1 < mag_thresh, - lc_phot1.weights_q, - jnp.zeros_like(lc_phot1.weights_q), - ) - lc_phot_weights_smooth_ms1 = jnp.where( - obs_mag_smooth_ms1 < mag_thresh, - lc_phot1.weights_smooth_ms, - jnp.zeros_like(lc_phot1.weights_smooth_ms), - ) - lc_phot_weights_bursty_ms1 = jnp.where( - obs_mag_bursty_ms1 < mag_thresh, - lc_phot1.weights_bursty_ms, - jnp.zeros_like(lc_phot1.weights_bursty_ms), - ) - N_weights1 = np.concatenate( - [ - lc_phot_weights_q1 * frac_cat, - lc_phot_weights_smooth_ms1 * frac_cat, - lc_phot_weights_bursty_ms1 * frac_cat, - ] - ) - - if diffstarpop_params2 is not None: - ran_key, phot_key2 = jran.split(ran_key, 2) - phot_args2 = ( - phot_key2, - lc_halopop["z_obs"], - lc_halopop["t_obs"], - lc_halopop["mah_params"], - lc_halopop["logmp0"], - t_table, - ssp_data, - precomputed_ssp_mag_table, - z_phot_table, - wave_eff_table, - diffstarpop_params2, - mzr_params, - spspop_params2, - scatter_params, - ssperrpop_params2, - cosmo_params, - fb, - ) - - lc_phot2 = lc_phot_kern.multiband_lc_phot_kern(*phot_args2) - - ( - obs_colors_mag_q2, - obs_colors_mag_smooth_ms2, - obs_colors_mag_bursty_ms2, - ) = get_obs_colors_mag(lc_phot2, mag_columns) - obs_colors_mag2 = np.concatenate( - [ - obs_colors_mag_q2, - obs_colors_mag_smooth_ms2, - obs_colors_mag_bursty_ms2, - ] - ) - # set weights=0 for mag > mag_thresh for the band indicated by mag_thresh_column - obs_mag_q2 = lc_phot2.obs_mags_q[:, mag_thresh_column] - obs_mag_smooth_ms2 = lc_phot2.obs_mags_smooth_ms[:, mag_thresh_column] - obs_mag_bursty_ms2 = lc_phot2.obs_mags_bursty_ms[:, mag_thresh_column] - - lc_phot_weights_q2 = jnp.where( - obs_mag_q2 < mag_thresh, - lc_phot2.weights_q, - jnp.zeros_like(lc_phot2.weights_q), - ) - lc_phot_weights_smooth_ms2 = jnp.where( - obs_mag_smooth_ms2 < mag_thresh, - lc_phot2.weights_smooth_ms, - jnp.zeros_like(lc_phot2.weights_smooth_ms), - ) - lc_phot_weights_bursty_ms2 = jnp.where( - obs_mag_bursty_ms2 < mag_thresh, - lc_phot2.weights_bursty_ms, - jnp.zeros_like(lc_phot2.weights_bursty_ms), - ) - N_weights2 = np.concatenate( - [ - lc_phot_weights_q2 * frac_cat, - lc_phot_weights_smooth_ms2 * frac_cat, - lc_phot_weights_bursty_ms2 * frac_cat, - ] - ) - return obs_colors_mag1, N_weights1, obs_colors_mag2, N_weights2 - else: - return obs_colors_mag1, N_weights1 - - -def get_obs_colors_mag(lc_phot, mag_columns): - num_halos, n_bands = lc_phot.obs_mags_q.shape - - obs_colors_mag_q = [] - obs_colors_mag_smooth_ms = [] - obs_colors_mag_bursty_ms = [] - - for i in range(n_bands - 1): - obs_color_q = lc_phot.obs_mags_q[:, i] - lc_phot.obs_mags_q[:, i + 1] - obs_colors_mag_q.append(obs_color_q) - - obs_color_smooth_ms = ( - lc_phot.obs_mags_smooth_ms[:, i] - lc_phot.obs_mags_smooth_ms[:, i + 1] - ) - obs_colors_mag_smooth_ms.append(obs_color_smooth_ms) - - obs_color_bursty_ms = ( - lc_phot.obs_mags_bursty_ms[:, i] - lc_phot.obs_mags_bursty_ms[:, i + 1] - ) - obs_colors_mag_bursty_ms.append(obs_color_bursty_ms) - - """mag_column""" - for mag_column in mag_columns: - obs_mag_q = lc_phot.obs_mags_q[:, mag_column] - obs_colors_mag_q.append(obs_mag_q) - - obs_mag_smooth_ms = lc_phot.obs_mags_smooth_ms[:, mag_column] - obs_colors_mag_smooth_ms.append(obs_mag_smooth_ms) - - obs_mag_bursty_ms = lc_phot.obs_mags_bursty_ms[:, mag_column] - obs_colors_mag_bursty_ms.append(obs_mag_bursty_ms) - - obs_colors_mag_q = jnp.asarray(obs_colors_mag_q).T - obs_colors_mag_smooth_ms = jnp.asarray(obs_colors_mag_smooth_ms).T - obs_colors_mag_bursty_ms = jnp.asarray(obs_colors_mag_bursty_ms).T - - return obs_colors_mag_q, obs_colors_mag_smooth_ms, obs_colors_mag_bursty_ms diff --git a/diffhtwo/experimental/diagnostics/plot_phot.py b/diffhtwo/experimental/diagnostics/plot_phot.py index 0baed1faa..1c82a78cf 100644 --- a/diffhtwo/experimental/diagnostics/plot_phot.py +++ b/diffhtwo/experimental/diagnostics/plot_phot.py @@ -573,6 +573,7 @@ def plot_app_mag_funcs( data_label, param_collection, ran_key, + zbins, ssp_data, savedir, lgmp_min=10.0, @@ -580,6 +581,7 @@ def plot_app_mag_funcs( num_halos=5000, lc_sky_area_degsq=1000, n_z_phot_table=30, + dmag=0.5, cosmo_params=DEFAULT_COSMOLOGY, fb=FB, plt_show=True, @@ -587,59 +589,82 @@ def plot_app_mag_funcs( dataset_mags = dataset.mags data_sky_area_degsq = dataset.data_sky_area_degsq - feniks_zbins = np.array( - [ - [0.2, 0.5], - [0.5, 0.8], - [0.8, 1.2], - [1.2, 1.6], - [1.6, 2.0], + zbins = np.array(zbins) + labels_z = [" z = " + str(np.round(np.median(z), 2)) for z in zbins] + + if len(labels_z) == 3: + colors_z = [ + "#001219", + "#0a7a80", + "#c87820", + ] + + elif len(labels_z) == 4: + colors_z = [ + "#001219", + "#0a7a80", + "#80cca8", + "#c87820", ] - ) - labels_z = [" z = " + str(np.round(np.median(z), 2)) for z in feniks_zbins] - - colors_z = [ - "#001219", # deep navy - "#0a7a80", # teal - "#80cca8", # mint - "#c8b44a", # warm gold - "#c87820", # amber - ] - fig_width = 7.1 - fig_height = 5 + + elif len(labels_z) == 5: + colors_z = [ + "#001219", + "#0a7a80", + "#80cca8", + "#c8b44a", + "#c87820", + ] + elif len(labels_z) == 6: + colors_z = ["#001219", "#0a7a80", "#80cca8", "#c8b44a", "#c87820", "#9b1d20"] + + n_bands = dataset_mags.shape[1] - 1 + if n_bands <= 5: + nrows, ncols = 1, n_bands + else: + ncols = 4 + nrows = int(np.ceil(n_bands / ncols)) + + fig_width = 7.1 * ncols / 4 + fig_height = 5 * nrows / 2 fontsize = 10 labelsize = 10 alpha = 0.75 s = 10 - fig, ax = plt.subplots(2, 4, figsize=(fig_width, fig_height)) - fig.subplots_adjust( - left=0.05, hspace=0.3, top=0.875, right=0.99, bottom=0.1, wspace=0.1 + fig, axes = plt.subplots( + nrows, ncols, figsize=(fig_width, fig_height), constrained_layout=True ) + if nrows == 1: + axes = axes[np.newaxis, :] + if ncols == 1: + axes = axes[:, np.newaxis] handles = [ mlines.Line2D([], [], color=c, linewidth=6, solid_capstyle="butt", label=label) for c, label in zip(colors_z, labels_z) ] + fig.get_layout_engine().set(rect=[0, 0, 1, 0.92]) + fig.legend( handles=handles, loc="upper center", - ncol=8, + ncol=len(zbins), frameon=False, handlelength=3, handleheight=0.5, columnspacing=0.8, handletextpad=0.1, - bbox_to_anchor=(0.5, 0.92), - fontsize=7, + bbox_to_anchor=(0.5, 1.0), + fontsize=10, ) xlim = [] - for zbin in range(0, len(feniks_zbins)): - z_min = feniks_zbins[zbin][0] - z_max = feniks_zbins[zbin][1] + for zbin in range(len(zbins)): + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] z_min, z_max = np.round(z_min, 2), np.round(z_max, 2) z_mask = (dataset_mags[:, -1] > z_min) & (dataset_mags[:, -1] < z_max) @@ -669,12 +694,9 @@ def plot_app_mag_funcs( dataset.frac_cat, ) - n_bands = obs_mags.shape[1] - row = 0 col = 0 - dmag = 0.5 - for i in range(0, n_bands): + for i in range(n_bands): bins = np.arange( dataset_mags_z[:, i].min(), dataset_mags_z[:, i].max() + dmag, @@ -684,8 +706,7 @@ def plot_app_mag_funcs( xlim.append([bins.min() - 0.5, bins.max() + 1]) bin_centers = (bins[1:] + bins[:-1]) / 2 - ax[row, col].set_xlim(bins[0], bins[-1] + 0.2) - # ax[0, i].set_xticks([]) + axes[row, col].set_xlim(bins[0], bins[-1] + 0.2) n_data, bin_edges = np.histogram( dataset_mags_z[:, i], @@ -694,62 +715,65 @@ def plot_app_mag_funcs( ) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) - ax[row, col].scatter( + axes[row, col].scatter( bin_centers, np.log10(n_data), c=colors_z[zbin], alpha=alpha, s=s ) - ( - n_diffsky, - _, - ) = np.histogram( + n_diffsky, _ = np.histogram( obs_mags[:, i], weights=weights * (1 / lc_data.lc_tot_vol_mpc3), bins=bins, ) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning) - ax[row, col].plot( + axes[row, col].plot( bin_centers, np.log10(n_diffsky), c=colors_z[zbin], alpha=alpha ) - ax[row, col].set_xticks(np.arange(15, 30, 2)) - ax[row, col].tick_params( + axes[row, col].set_xticks(np.arange(10, 30, 2)) + axes[row, col].minorticks_on() + axes[row, col].tick_params( which="major", - length=3, - width=1.5, direction="in", top=True, right=True, + length=6, + width=1, labelsize=labelsize, ) - ax[row, col].tick_params( + axes[row, col].tick_params( which="minor", - length=1.5, - width=1.5, direction="in", top=True, right=True, + length=3, + width=0.8, + labelsize=labelsize, ) - ax[row, col].set_ylim(-6.9, -2.5) - ax[row, col].set_xlim(xlim[i]) - ax[row, col].set_xlabel(dataset.mags_labels[i]) + axes[row, col].set_ylim(-6.9, -2.5) + axes[row, col].set_xlim(xlim[i]) + axes[row, col].set_xlabel(dataset.mags_labels[i]) if col != 0: - ax[row, col].set_yticklabels([]) + axes[row, col].set_yticklabels([]) - if col == 3: + if col == ncols - 1: row += 1 col = 0 else: col += 1 - ax[0, 0].set_ylabel("log$_{10}$ (n [Mpc$^{-3}$])", fontsize=fontsize) - ax[1, 0].set_ylabel("log$_{10}$ (n [Mpc$^{-3}$])", fontsize=fontsize) + for idx in range(n_bands, nrows * ncols): + r, c = divmod(idx, ncols) + axes[r, c].set_visible(False) + + for r in range(nrows): + axes[r, 0].set_ylabel("log$_{10}$ (n [Mpc$^{-3}$])", fontsize=fontsize) + fig.savefig( savedir + "/" + data_label + "_app_mag_funcs.png", - bbox_inches="tight", - dpi=200, + dpi=300, ) if plt_show: plt.show() diff --git a/diffhtwo/experimental/diagnostics/plot_insitu_sm.py b/diffhtwo/experimental/diagnostics/plot_sm.py similarity index 52% rename from diffhtwo/experimental/diagnostics/plot_insitu_sm.py rename to diffhtwo/experimental/diagnostics/plot_sm.py index fac609f82..caf228c5f 100644 --- a/diffhtwo/experimental/diagnostics/plot_insitu_sm.py +++ b/diffhtwo/experimental/diagnostics/plot_sm.py @@ -7,7 +7,7 @@ plt.rc("font", family="serif", serif=["Times New Roman"]) -def plot_insitu_sm( +def plot_insitu_sm_obs( ran_key, param_collection, z_min, @@ -105,7 +105,120 @@ def plot_insitu_sm( fig.savefig( savedir - + "/insitu_sm_" + + "/insitu_sm_obs_" + + model_nickname + + "_z" + + z_min_label + + "-" + + z_max_label + + ".png", + bbox_inches="tight", + dpi=200, + ) + if plt_show: + plt.show() + plt.close() + + +def plot_sm_obs( + ran_key, + param_collection, + z_min, + z_max, + dimension_labels, + ssp_data, + tcurves, + model_nickname, + savedir, + mag_thresh=None, + frac_cat=None, + num_halos=1000, + plt_show=True, +): + fig, ax = plt.subplots(1, figsize=(5, 5)) + + z_min_label = str(np.round(z_min, 2)) + z_max_label = str(np.round(z_max, 2)) + + fig.suptitle(z_min_label + " < z < " + z_max_label) + + """fit""" + lc_data, phot_kern_results, weights = multiband_lc_phot_kern( + ran_key, + param_collection, + z_min, + z_max, + num_halos, + ssp_data, + tcurves, + mag_thresh=mag_thresh, + frac_cat=frac_cat, + ) + + bins = np.linspace( + phot_kern_results.logsm_obs.min(), + phot_kern_results.logsm_obs.max(), + 50, + ) + + cen = lc_data.is_central == 1 + sat = lc_data.is_central != 1 + + ax.hist( + phot_kern_results.logsm_obs[sat], + bins=bins, + label="fit sat", + color="tab:orange", + histtype="step", + ) + ax.hist( + phot_kern_results.logsm_obs[cen], + bins=bins, + label="fit cen", + color="tab:blue", + histtype="step", + ) + + """default""" + lc_data, phot_kern_results, weights = multiband_lc_phot_kern( + ran_key, + DEFAULT_PARAM_COLLECTION, + z_min, + z_max, + num_halos, + ssp_data, + tcurves, + ) + + cen = lc_data.is_central == 1 + sat = lc_data.is_central != 1 + + ax.hist( + phot_kern_results.logsm_obs[sat], + bins=bins, + label="default sat", + color="tab:orange", + histtype="step", + ls="--", + ) + ax.hist( + phot_kern_results.logsm_obs[cen], + bins=bins, + label="default cen", + color="tab:blue", + histtype="step", + ls="--", + ) + + ax.set_xlabel("logsm_obs") + ax.set_ylabel("#") + ax.set_xlim(0, 14) + ax.set_yscale("log") + ax.legend() + + fig.savefig( + savedir + + "/sm_obs_" + model_nickname + "_z" + z_min_label diff --git a/diffhtwo/experimental/diagnostics/plot_smhm.py b/diffhtwo/experimental/diagnostics/plot_smhm.py new file mode 100644 index 000000000..2c69a1e9c --- /dev/null +++ b/diffhtwo/experimental/diagnostics/plot_smhm.py @@ -0,0 +1,182 @@ +import matplotlib.pyplot as plt +import numpy as np +from diffsky.param_utils.diffsky_param_wrapper_merging import DEFAULT_PARAM_COLLECTION +from scipy.stats import binned_statistic + +from ..kernels.lc_phot_kern import multiband_lc_phot_kern + +plt.rc("font", family="serif", serif=["Times New Roman"]) + + +def get_median_logsm_obs(logmp_obs, logsm_obs): + logmp_bins = np.arange(logmp_obs.min(), logmp_obs.max() + 0.25, 0.25) + logmp_bin_centers = (logmp_bins[:-1] + logmp_bins[1:]) / 2 + + logsm_16, __, __ = binned_statistic( + logmp_obs, logsm_obs, bins=logmp_bins, statistic=lambda x: np.percentile(x, 16) + ) + logsm_50, __, __ = binned_statistic( + logmp_obs, logsm_obs, bins=logmp_bins, statistic="median" + ) + logsm_84, __, __ = binned_statistic( + logmp_obs, logsm_obs, bins=logmp_bins, statistic=lambda x: np.percentile(x, 84) + ) + return logmp_bin_centers, logsm_16, logsm_50, logsm_84 + + +def plot_smhm( + ran_key, + param_collection, + zbins, + num_halos, + ssp_data, + tcurves, + data_label, + savedir, + mag_thresh=None, + frac_cat=None, + in_situ=False, + plt_show=True, +): + n_z_bins = len(zbins) + fig_width = 1.42 * n_z_bins + fig_height = 2 + fig, ax = plt.subplots( + 1, len(zbins), figsize=(fig_width, fig_height), constrained_layout=True + ) + + labelsize = 10 + fontsize = 10 + labelsize = 10 + alpha = 0.25 + + for zbin in range(n_z_bins): + z_min = zbins[zbin][0] + z_max = zbins[zbin][1] + z_min_label = str(np.round(z_min, 2)) + z_max_label = str(np.round(z_max, 2)) + ax[zbin].set_title(z_min_label + " < z < " + z_max_label) + + """default""" + lc_data, phot_kern_results, gal_weight = multiband_lc_phot_kern( + ran_key, + DEFAULT_PARAM_COLLECTION, + z_min, + z_max, + num_halos, + ssp_data, + tcurves, + mag_thresh=mag_thresh, + frac_cat=frac_cat, + ) + if in_situ: + logsm_obs = phot_kern_results.logsm_obs_in_situ + else: + logsm_obs = phot_kern_results.logsm_obs + ( + logmp_bin_centers_default, + logsm_16_default, + logsm_50_default, + logsm_84_default, + ) = get_median_logsm_obs(lc_data.logmp_obs, logsm_obs) + + ax[zbin].plot( + logmp_bin_centers_default, + logsm_50_default, + label="default", + color="#FFB689", + ) + ax[zbin].fill_between( + logmp_bin_centers_default, + logsm_16_default, + logsm_84_default, + alpha=alpha, + color="#FFB689", + ) + + """fit""" + lc_data, phot_kern_results, gal_weight = multiband_lc_phot_kern( + ran_key, + param_collection, + z_min, + z_max, + num_halos, + ssp_data, + tcurves, + mag_thresh=mag_thresh, + frac_cat=frac_cat, + ) + if in_situ: + logsm_obs = phot_kern_results.logsm_obs_in_situ + else: + logsm_obs = phot_kern_results.logsm_obs + + ( + logmp_bin_centers_fit, + logsm_16_fit, + logsm_50_fit, + logsm_84_fit, + ) = get_median_logsm_obs(lc_data.logmp_obs, logsm_obs) + + ax[zbin].plot(logmp_bin_centers_fit, logsm_50_fit, label="fit", color="#61C0BF") + ax[zbin].fill_between( + logmp_bin_centers_fit, + logsm_16_fit, + logsm_84_fit, + alpha=alpha, + color="#61C0BF", + ) + + # if in_situ: + # ax[zbin].set_xlim(10, lc_data.logmp_obs.max()) + # ax[zbin].set_ylim(5, 13) + # ax[zbin].set_xticks([10, 11, 12, 13, 14, 15]) + # ax[zbin].set_yticks([5, 6, 7, 8, 9, 10, 11, 12]) + # else: + ax[zbin].set_xlim(11, lc_data.logmp_obs.max()) + ax[zbin].set_ylim(8, 13) + ax[zbin].set_xticks([11, 12, 13, 14, 15]) + ax[zbin].set_yticks([8, 9, 10, 11, 12]) + + ax[zbin].set_xlabel("logmp_obs", fontsize=fontsize) + + ax[zbin].minorticks_on() + ax[zbin].tick_params( + which="major", + direction="in", + top=True, + right=True, + length=6, + width=1, + labelsize=labelsize, + ) + ax[zbin].tick_params( + which="minor", + direction="in", + top=True, + right=True, + length=3, + width=0.8, + labelsize=labelsize, + ) + + if in_situ: + ax[0].set_ylabel("logsm_obs_in_situ", fontsize=fontsize) + else: + ax[0].set_ylabel("logsm_obs", fontsize=fontsize) + ax[-1].legend(fontsize=7, loc="lower right") + + if in_situ: + fig.savefig( + savedir + "/" + data_label + "_insitu_smhm.png", + dpi=300, + ) + else: + fig.savefig( + savedir + "/" + data_label + "_smhm.png", + dpi=300, + ) + + if plt_show: + plt.show() + plt.close() diff --git a/diffhtwo/experimental/emline_luminosity_pop.py b/diffhtwo/experimental/emline_luminosity_pop.py deleted file mode 100644 index 4a232a970..000000000 --- a/diffhtwo/experimental/emline_luminosity_pop.py +++ /dev/null @@ -1,246 +0,0 @@ -# flake8: noqa: E402 -""" """ -from jax import config - -config.update("jax_enable_x64", True) - -from collections import namedtuple - -import jax.numpy as jnp -from diffsky.burstpop import diffqburstpop_mono, freqburst_mono -from diffsky.dustpop import tw_dustpop_mono_noise -from diffsky.experimental.lc_phot_kern import diffstarpop_lc_cen_wrapper -from dsps.metallicity import umzr -from dsps.sed import metallicity_weights as zmetw -from dsps.sed.stellar_age_weights import calc_age_weights_from_sfh_table -from jax import jit as jjit -from jax import random as jran -from jax import vmap - -from . import emline_luminosity -from .emline_utils import get_ssp_emline_luminosity - -LGMET_SCATTER = 0.2 - -# copied from astropy.constants.L_sun.cgs.value -L_SUN_CGS = jnp.array(3.828e33, dtype="float64") - - -_M = (0, None, None) -_calc_lgmet_weights_galpop = jjit( - vmap(zmetw.calc_lgmet_weights_from_lognormal_mdf, in_axes=_M) -) - -_B = (None, 0, 0, None, 0) -_calc_bursty_age_weights_vmap = jjit( - vmap( - diffqburstpop_mono.calc_bursty_age_weights_from_diffburstpop_params, in_axes=_B - ) -) - -_AGEPOP = (None, 0, None, 0) -calc_age_weights_from_sfh_table_vmap = jjit( - vmap(calc_age_weights_from_sfh_table, in_axes=_AGEPOP) -) - - -_D = (None, None, 0, 0, 0, None, 0, 0, 0, None) -calc_dust_ftrans_vmap = jjit( - vmap( - tw_dustpop_mono_noise.calc_ftrans_singlegal_singlewave_from_dustpop_params, - in_axes=_D, - ) -) - - -_LCLINE_RET_KEYS = ( - "emline_L_cgs_q", - "emline_L_cgs_smooth_ms", - "emline_L_cgs_bursty_ms", - "weights_q", - "weights_smooth_ms", - "weights_bursty_ms", -) -LCLine = namedtuple("LCLine", _LCLINE_RET_KEYS) -LCLINE_EMPTY = LCLine._make([None] * len(LCLine._fields)) - - -@jjit -def emline_luminosity_pop( - diffstarpop_params, - ran_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - cosmo_params, - fb, -): - ssp_emline_luminosity = get_ssp_emline_luminosity(emline_wave_aa, ssp_data) - n_met, n_age = ssp_emline_luminosity.shape - n_gals = logmp0.size - - ran_key, sfh_key = jran.split(ran_key, 2) - diffstar_galpop = diffstarpop_lc_cen_wrapper( - diffstarpop_params, - sfh_key, - mah_params, - logmp0, - t_table, - t_obs, - cosmo_params, - fb, - ) - - # get age weights - smooth_age_weights_ms = calc_age_weights_from_sfh_table_vmap( - t_table, diffstar_galpop.sfh_ms, ssp_data.ssp_lg_age_gyr, t_obs - ) - - smooth_age_weights_q = calc_age_weights_from_sfh_table_vmap( - t_table, diffstar_galpop.sfh_q, ssp_data.ssp_lg_age_gyr, t_obs - ) - - # get bursty age weights - _args = ( - spspop_params.burstpop_params, - diffstar_galpop.logsm_obs_ms, - diffstar_galpop.logssfr_obs_ms, - ssp_data.ssp_lg_age_gyr, - smooth_age_weights_ms, - ) - bursty_age_weights_ms, burst_params = _calc_bursty_age_weights_vmap(*_args) - - # get p_burst_ms - p_burst_ms = freqburst_mono.get_freqburst_from_freqburst_params( - spspop_params.burstpop_params.freqburst_params, - diffstar_galpop.logsm_obs_ms, - diffstar_galpop.logssfr_obs_ms, - ) - - # get metallicity weights - lgmet_med_ms = umzr.mzr_model(diffstar_galpop.logsm_obs_ms, t_obs, *mzr_params) - lgmet_med_q = umzr.mzr_model(diffstar_galpop.logsm_obs_q, t_obs, *mzr_params) - - lgmet_weights_ms = _calc_lgmet_weights_galpop( - lgmet_med_ms, LGMET_SCATTER, ssp_data.ssp_lgmet - ) - lgmet_weights_q = _calc_lgmet_weights_galpop( - lgmet_med_q, LGMET_SCATTER, ssp_data.ssp_lgmet - ) - - # age weights * metallicity weights - _w_age_q = smooth_age_weights_q.reshape((n_gals, 1, n_age)) - _w_lgmet_q = lgmet_weights_q.reshape((n_gals, n_met, 1)) - ssp_weights_q = _w_lgmet_q * _w_age_q - - _w_age_ms = smooth_age_weights_ms.reshape((n_gals, 1, n_age)) - _w_lgmet_ms = lgmet_weights_ms.reshape((n_gals, n_met, 1)) - ssp_weights_smooth_ms = _w_lgmet_ms * _w_age_ms - - _w_age_bursty_ms = bursty_age_weights_ms.reshape((n_gals, 1, n_age)) - ssp_weights_bursty_ms = _w_lgmet_ms * _w_age_bursty_ms - - # get ftrans due to dust - ran_key, dust_key = jran.split(ran_key, 2) - av_key, delta_key, funo_key = jran.split(dust_key, 3) - uran_av = jran.uniform(av_key, shape=(n_gals,)) - uran_delta = jran.uniform(delta_key, shape=(n_gals,)) - uran_funo = jran.uniform(funo_key, shape=(n_gals,)) - - ftrans_args_q = ( - spspop_params.dustpop_params, - emline_wave_aa, - diffstar_galpop.logsm_obs_q, - diffstar_galpop.logssfr_obs_q, - z_obs, - ssp_data.ssp_lg_age_gyr, - uran_av, - uran_delta, - uran_funo, - scatter_params, - ) - _res = calc_dust_ftrans_vmap(*ftrans_args_q) - ftrans_q = _res[1] - - ftrans_args_ms = ( - spspop_params.dustpop_params, - emline_wave_aa, - diffstar_galpop.logsm_obs_ms, - diffstar_galpop.logssfr_obs_ms, - z_obs, - ssp_data.ssp_lg_age_gyr, - uran_av, - uran_delta, - uran_funo, - scatter_params, - ) - _res = calc_dust_ftrans_vmap(*ftrans_args_ms) - ftrans_ms = _res[1] - - _ftrans_q = ftrans_q.reshape((n_gals, 1, n_age)) - _ftrans_ms = ftrans_ms.reshape((n_gals, 1, n_age)) - - _mstar_q = 10**diffstar_galpop.logsm_obs_q - _mstar_ms = 10**diffstar_galpop.logsm_obs_ms - - integrand_q = ssp_emline_luminosity * ssp_weights_q * _ftrans_q - emline_L_cgs_q = jnp.sum(integrand_q, axis=(1, 2)) * (L_SUN_CGS * _mstar_q) - - integrand_smooth_ms = ssp_emline_luminosity * ssp_weights_smooth_ms * _ftrans_ms - emline_L_cgs_smooth_ms = jnp.sum(integrand_smooth_ms, axis=(1, 2)) * ( - L_SUN_CGS * _mstar_ms - ) - - integrand_bursty_ms = ssp_emline_luminosity * ssp_weights_bursty_ms * _ftrans_ms - emline_L_cgs_bursty_ms = jnp.sum(integrand_bursty_ms, axis=(1, 2)) * ( - L_SUN_CGS * _mstar_ms - ) - - weights_q = diffstar_galpop.frac_q - weights_smooth_ms = (1 - diffstar_galpop.frac_q) * (1 - p_burst_ms) - weights_bursty_ms = (1 - diffstar_galpop.frac_q) * p_burst_ms - - emline_L = LCLINE_EMPTY._replace( - emline_L_cgs_q=emline_L_cgs_q, - emline_L_cgs_smooth_ms=emline_L_cgs_smooth_ms, - emline_L_cgs_bursty_ms=emline_L_cgs_bursty_ms, - weights_q=weights_q, - weights_smooth_ms=weights_smooth_ms, - weights_bursty_ms=weights_bursty_ms, - ) - - return emline_L - - -@jjit -def emline_luminosity_func_pop(emline_L_tuple, nhalos, sig=None, lgL_bin_edges=None): - # get q emline L_cgs histogram - emline_L_cgs_q = emline_L_tuple.emline_L_cgs_q - w_q = emline_L_tuple.weights_q * nhalos - lgL_bin_edges, tw_hist_q = emline_luminosity.get_emline_luminosity_func( - emline_L_cgs_q, w_q, sig=sig, lgL_bin_edges=lgL_bin_edges - ) - - # get smooth_ms emline L_cgs histogram - emline_L_cgs_smooth_ms = emline_L_tuple.emline_L_cgs_smooth_ms - w_smooth_ms = emline_L_tuple.weights_smooth_ms * nhalos - _, tw_hist_smooth_ms = emline_luminosity.get_emline_luminosity_func( - emline_L_cgs_smooth_ms, w_smooth_ms, sig=sig, lgL_bin_edges=lgL_bin_edges - ) - - # get bursty_ms emline L_cgs histogram - emline_L_cgs_bursty_ms = emline_L_tuple.emline_L_cgs_bursty_ms - w_bursty_ms = emline_L_tuple.weights_bursty_ms * nhalos - _, tw_hist_bursty_ms = emline_luminosity.get_emline_luminosity_func( - emline_L_cgs_bursty_ms, w_bursty_ms, sig=sig, lgL_bin_edges=lgL_bin_edges - ) - - return lgL_bin_edges, tw_hist_q, tw_hist_smooth_ms, tw_hist_bursty_ms diff --git a/diffhtwo/experimental/kernels/N_phot.py b/diffhtwo/experimental/kernels/N_phot.py index 684b7176b..b68604ada 100644 --- a/diffhtwo/experimental/kernels/N_phot.py +++ b/diffhtwo/experimental/kernels/N_phot.py @@ -1,3 +1,4 @@ +from collections import namedtuple from functools import partial import jax.numpy as jnp @@ -5,7 +6,171 @@ from dsps.cosmology import DEFAULT_COSMOLOGY from jax import jit as jjit -from .phot_kern import get_colors_mags +from .phot_kern import get_colors_mags, mag_kern + + +@jjit +def N_colors_mags( + ran_key, + param_collection, + z_data, + mag_thresh, + frac_cat, +): + obs_mags_weighted, gal_weight, phot_kern_results = mag_kern( + ran_key, + param_collection, + z_data.lc_data, + mag_thresh, + frac_cat, + ) + fields = z_data._fields[3:] + mag_thresh = jnp.array(mag_thresh) + for f in range(0, len(fields)): + space = getattr(z_data, fields[f]) + + if isinstance(space, list): + # Colors conditioned on mag space + new_list = [] + for s in range(0, len(space)): + space_n = space[s] + col_idx = space_n.col_idx + + # get cond weight + obs_mags_weighted_cond = obs_mags_weighted[:, space_n.cond_idx] + cond = (obs_mags_weighted_cond > space_n.cond_min) & ( + obs_mags_weighted_cond <= space_n.cond_max + ) + weight = jnp.where(cond, gal_weight, 0.0) + + obs_color = ( + obs_mags_weighted[:, col_idx[0]] - obs_mags_weighted[:, col_idx[1]] + ) + obs_color = obs_color.reshape(obs_color.size, 1) + + N_model = diffndhist_lomem.tw_ndhist_weighted( + obs_color, + space_n.sig, + weight, + space_n.bin_lo, + space_n.bin_hi, + ) + + NewTuple = namedtuple( + type(space_n).__name__, [*space_n._fields, "N_model"] + ) + new_list.append(NewTuple(*space_n, N_model)) + z_data = z_data._replace(**{fields[f]: new_list}) + + elif "mag_idx" in space._fields: + if "col_idx" in space._fields: + # Magnitude-Color space + col_idx = space.col_idx + mag_idx = space.mag_idx + + mag = obs_mags_weighted[:, mag_idx] + obs_color = ( + obs_mags_weighted[:, col_idx[0]] - obs_mags_weighted[:, col_idx[1]] + ) + obs_mag_color = jnp.vstack((mag, obs_color)).T + + N_model = diffndhist_lomem.tw_ndhist_weighted( + obs_mag_color, + space.sig, + gal_weight, + space.bin_lo, + space.bin_hi, + ) + + NewTuple = namedtuple(type(space).__name__, [*space._fields, "N_model"]) + new = NewTuple(*space, N_model) + z_data = z_data._replace(**{fields[f]: new}) + else: + # Apparent Magnitude space + mag_idx = space.mag_idx + obs_mag = obs_mags_weighted[:, mag_idx] + obs_mag = obs_mag.reshape(obs_mag.size, 1) + + N_model = diffndhist_lomem.tw_ndhist_weighted( + obs_mag, + space.sig, + gal_weight, + space.bin_lo, + space.bin_hi, + ) + + NewTuple = namedtuple(type(space).__name__, [*space._fields, "N_model"]) + new = NewTuple(*space, N_model) + z_data = z_data._replace(**{fields[f]: new}) + + else: + # Color-Color space + col_idx = space.col_idx + obs_colors = [] + for c in range(0, len(col_idx) - 1, 2): + obs_color = ( + obs_mags_weighted[:, col_idx[c]] + - obs_mags_weighted[:, col_idx[c + 1]] + ) + obs_colors.append(obs_color) + obs_colors = jnp.array(obs_colors).T + + N_model = diffndhist_lomem.tw_ndhist_weighted( + obs_colors, + space.sig, + gal_weight, + space.bin_lo, + space.bin_hi, + ) + + NewTuple = namedtuple(type(space).__name__, [*space._fields, "N_model"]) + new = NewTuple(*space, N_model) + z_data = z_data._replace(**{fields[f]: new}) + + return z_data + + +@jjit +def N_mags_1d( + ran_key, + param_collection, + magbin_bands, + lc_data, + mag_thresh, + frac_cat, + sig_scale=0.5, +): + obs_mags_weighted, gal_weight, phot_kern_results = mag_kern( + ran_key, + param_collection, + lc_data, + mag_thresh, + frac_cat, + ) + + n_gals, n_bands = obs_mags_weighted.shape + N_bands = [] + for band in range(0, n_bands): + mags = obs_mags_weighted[:, band].reshape(obs_mags_weighted[:, band].size, 1) + + magbin_edges = magbin_bands[band] + + sig = jnp.diff(magbin_edges) * sig_scale + sig = sig.reshape(sig.size, 1) + + mag_lo = magbin_edges[:-1].reshape(magbin_edges[:-1].size, 1) + mag_hi = magbin_edges[1:].reshape(magbin_edges[1:].size, 1) + + N_mags = diffndhist_lomem.tw_ndhist_weighted( + mags, + sig, + gal_weight, + mag_lo, + mag_hi, + ) + N_bands.append(N_mags) + + return N_bands @partial(jjit, static_argnames=["redshift_as_last_dimension_in_lh"]) diff --git a/diffhtwo/experimental/kernels/cat_weights.py b/diffhtwo/experimental/kernels/cat_weights.py index 9d3dec814..bff120d9c 100644 --- a/diffhtwo/experimental/kernels/cat_weights.py +++ b/diffhtwo/experimental/kernels/cat_weights.py @@ -3,16 +3,12 @@ @jjit -def compute_cat_weights(weights, phot_kern_results, mag_thresh, frac_cat): - obs_mags = phot_kern_results.obs_mags - n_gals, n_bands = obs_mags.shape - mag_thresh_mask = jnp.ones((n_gals,), dtype=bool) +def compute_cat_weights(weights, obs_mags_weighted, mag_thresh, frac_cat): + mag_thresh = jnp.array(mag_thresh) + mag_thresh_mask = obs_mags_weighted[:, 0] < mag_thresh[0] - for band in range(0, n_bands): - if mag_thresh[band] is not None: - band_mag_thresh_mask = obs_mags[:, band] < mag_thresh[band] - mag_thresh_mask *= band_mag_thresh_mask + n_gals, n_bands = obs_mags_weighted.shape + for band in range(1, n_bands): + mag_thresh_mask *= obs_mags_weighted[:, band] < mag_thresh[band] - weights = weights * jnp.where(mag_thresh_mask, frac_cat, 0.0) - - return weights + return weights * jnp.where(mag_thresh_mask, frac_cat, 0.0) diff --git a/diffhtwo/experimental/kernels/lc_phot_kern.py b/diffhtwo/experimental/kernels/lc_phot_kern.py index 603c69812..8a333026b 100644 --- a/diffhtwo/experimental/kernels/lc_phot_kern.py +++ b/diffhtwo/experimental/kernels/lc_phot_kern.py @@ -44,11 +44,12 @@ def multiband_lc_phot_kern( param_collection, lc_data, ) + obs_mags_weighted = phot_kern_results.obs_mags_weighted gal_weight = lc_data.cen_weight * lc_data.sat_weight if mag_thresh is not None: gal_weight = compute_cat_weights( - gal_weight, phot_kern_results, mag_thresh, frac_cat + gal_weight, obs_mags_weighted, mag_thresh, frac_cat ) return lc_data, phot_kern_results, gal_weight diff --git a/diffhtwo/experimental/kernels/phot_kern.py b/diffhtwo/experimental/kernels/phot_kern.py index 0ce829f77..71a89655a 100644 --- a/diffhtwo/experimental/kernels/phot_kern.py +++ b/diffhtwo/experimental/kernels/phot_kern.py @@ -55,16 +55,15 @@ def mag_kern( param_collection, lc_data, ) - obs_mags = phot_kern_results.obs_mags_weighted - + obs_mags_weighted = phot_kern_results.obs_mags_weighted gal_weight = lc_data.cen_weight * lc_data.sat_weight # update weights to incorporate mag thresh cuts and frac_cat gal_weight = compute_cat_weights( - gal_weight, phot_kern_results, mag_thresh, frac_cat + gal_weight, obs_mags_weighted, mag_thresh, frac_cat ) - return obs_mags, gal_weight, phot_kern_results + return obs_mags_weighted, gal_weight, phot_kern_results @partial(jjit, static_argnames=["redshift_as_last_dimension_in_lh"]) diff --git a/diffhtwo/experimental/kernels/tests/test_N_phot.py b/diffhtwo/experimental/kernels/tests/test_N_phot.py index efe4bfdda..f6b0ecb06 100644 --- a/diffhtwo/experimental/kernels/tests/test_N_phot.py +++ b/diffhtwo/experimental/kernels/tests/test_N_phot.py @@ -1,8 +1,10 @@ +import jax.numpy as jnp import numpy as np from diffsky.param_utils.diffsky_param_wrapper_merging import DEFAULT_PARAM_COLLECTION from jax import random as jran from ..N_phot import N_colors_mags_lh +from ..phot_kern import mag_kern def test_N_colors_mags_lh(feniks_single_z_data): @@ -19,3 +21,26 @@ def test_N_colors_mags_lh(feniks_single_z_data): assert np.isfinite(N).all() assert (N >= 0.0).all() + + +def test_mag_kern(feniks): + ran_key = jran.key(0) + + obs_mags_weighted, gal_weight, phot_kern_results = mag_kern( + ran_key, + DEFAULT_PARAM_COLLECTION, + feniks.colors[0].lc_data, + feniks.filter_info.mag_thresh, + feniks.frac_cat, + ) + + assert np.isfinite(obs_mags_weighted).all() + assert np.isfinite(gal_weight).all() + + # ensure that gal_weight for gals above mag_thresh is 0.0 in each band + mag_thresh = jnp.array(feniks.filter_info.mag_thresh) + n_gals, n_bands = obs_mags_weighted.shape + for i in range(0, n_bands): + mag_sel = obs_mags_weighted[:, i] < mag_thresh[i] + gal_weight_above_magthresh = jnp.where(mag_sel, 0, gal_weight) + assert gal_weight_above_magthresh.sum() == 0.0 diff --git a/diffhtwo/experimental/kernels/tests/test_compute_cat_weights.py b/diffhtwo/experimental/kernels/tests/test_compute_cat_weights.py index c05e2700f..b58d51379 100644 --- a/diffhtwo/experimental/kernels/tests/test_compute_cat_weights.py +++ b/diffhtwo/experimental/kernels/tests/test_compute_cat_weights.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np from diffsky.param_utils.diffsky_param_wrapper_merging import DEFAULT_PARAM_COLLECTION from jax import random as jran @@ -14,15 +15,25 @@ def test_compute_cat_weights(feniks, feniks_lc_data): DEFAULT_PARAM_COLLECTION, feniks_lc_data, ) + obs_mags_weighted = phot_kern_results.obs_mags_weighted gal_weight = feniks_lc_data.cen_weight * feniks_lc_data.sat_weight assert np.isfinite(gal_weight).all() assert (gal_weight >= 0).all() + # apply mag_thresh cuts and frac_cat gal_cat_weight = compute_cat_weights( - gal_weight, phot_kern_results, feniks.filter_info.mag_thresh, feniks.frac_cat + gal_weight, obs_mags_weighted, feniks.filter_info.mag_thresh, feniks.frac_cat ) assert np.isfinite(gal_cat_weight).all() assert (gal_cat_weight >= 0).all() # ensure that gal_cat_weight does not upweight compared to gal_weight assert gal_cat_weight.sum() <= gal_weight.sum() + + # ensure that gal_cat_weight for gals above mag_thresh is 0.0 in each band + mag_thresh = jnp.array(feniks.filter_info.mag_thresh) + n_gals, n_bands = obs_mags_weighted.shape + for i in range(0, n_bands): + mag_sel = obs_mags_weighted[:, i] < mag_thresh[i] + gal_cat_weight_above_magthresh = jnp.where(mag_sel, 0, gal_cat_weight) + assert gal_cat_weight_above_magthresh.sum() == 0.0 diff --git a/diffhtwo/experimental/loss_kernels/loss_functions.py b/diffhtwo/experimental/loss_kernels/loss_functions.py index 82ae48fa4..25aaca59b 100644 --- a/diffhtwo/experimental/loss_kernels/loss_functions.py +++ b/diffhtwo/experimental/loss_kernels/loss_functions.py @@ -14,14 +14,14 @@ def mse_w(lg_n_pred, lg_n_target, lg_n_target_err, lg_n_thresh=-10): return jnp.sum(chi2) / nbins -# @jjit -# def poisson_loss(N_pred, N_target, eps=1e-12): -# N_pred = jnp.clip(N_pred, eps, None) -# return jnp.sum(N_pred - N_target * jnp.log(N_pred)) - - @jjit def poisson_loss(N_pred, N_target, eps=1e-12): N_pred = jnp.clip(N_pred, eps, None) - N_eff = jnp.maximum(jnp.sum(N_target), eps) - return jnp.sum(N_pred - N_target * jnp.log(N_pred)) / N_eff + return jnp.sum(N_pred - N_target * jnp.log(N_pred)) + + +# @jjit +# def poisson_loss(N_pred, N_target, eps=1e-12): +# N_pred = jnp.clip(N_pred, eps, None) +# N_eff = jnp.maximum(jnp.sum(N_target), eps) +# return jnp.sum(N_pred - N_target * jnp.log(N_pred)) / N_eff diff --git a/diffhtwo/experimental/loss_kernels/phot_loss.py b/diffhtwo/experimental/loss_kernels/phot_loss.py index 0086f33d9..01b46d217 100644 --- a/diffhtwo/experimental/loss_kernels/phot_loss.py +++ b/diffhtwo/experimental/loss_kernels/phot_loss.py @@ -1,10 +1,184 @@ from jax import jit as jjit +from jax import lax -from ..kernels.N_phot import N_colors_mags_lh +from ..kernels.N_phot import N_colors_mags, N_colors_mags_lh, N_mags_1d from ..param_utils import get_param_collection_from_u_theta from .loss_functions import poisson_loss +@jjit +def get_phot_loss_2d_multiz( + ran_key, + param_collection, + data, + mag_thresh, + frac_cat, + data_sky_area_degsq, +): + phot_loss_2d = 0.0 + for z in range(0, len(data)): + z_data = data[z] + z_data_model = N_colors_mags( + ran_key, + param_collection, + z_data, + mag_thresh, + frac_cat, + ) + sky_rescale = data_sky_area_degsq / z_data_model.lc_data.sky_area_degsq + fields = z_data_model._fields[3:] + for f in range(0, len(fields)): + space = getattr(z_data_model, fields[f]) + if isinstance(space, list): + for s in range(0, len(space)): + space_n = space[s] + phot_loss_2d += lax.cond( + space_n.fit, + lambda sp=space_n: poisson_loss( + sp.N_model * sky_rescale, sp.N_data + ), + lambda: 0.0, + ) + else: + phot_loss_2d += lax.cond( + space.fit, + lambda sp=space: poisson_loss(sp.N_model * sky_rescale, sp.N_data), + lambda: 0.0, + ) + return phot_loss_2d + + +@jjit +def _loss_phot_kern_2d_multiz( + u_theta, + ran_key, + fitting_data, +): + param_collection = get_param_collection_from_u_theta(u_theta) + + phot_loss_2d = 0.0 + + # get color loss + phot_loss_2d += get_phot_loss_2d_multiz( + ran_key, + param_collection, + fitting_data.colors, + fitting_data.filter_info.mag_thresh, + fitting_data.frac_cat, + fitting_data.data_sky_area_degsq, + ) + # get app mag func loss + phot_loss_2d += get_phot_loss_2d_multiz( + ran_key, + param_collection, + fitting_data.app_mag_funcs, + fitting_data.filter_info.mag_thresh, + fitting_data.frac_cat, + fitting_data.data_sky_area_degsq, + ) + + return phot_loss_2d + + +# @jjit +# def _loss_phot_kern_2d( +# u_theta, +# ran_key, +# z_data, +# mag_thresh, +# frac_cat, +# data_sky_area_degsq, +# ): +# param_collection = get_param_collection_from_u_theta(u_theta) + +# phot_loss_2d = get_phot_loss_2d( +# ran_key, +# param_collection, +# z_data, +# mag_thresh, +# frac_cat, +# data_sky_area_degsq, +# ) + +# return phot_loss_2d + + +@jjit +def get_phot_loss_1d( + ran_key, + param_collection, + magbin_bands, + N_bands_data, + lc_data, + mag_thresh, + frac_cat, + data_sky_area_degsq, +): + N_bands_model = N_mags_1d( + ran_key, param_collection, magbin_bands, lc_data, mag_thresh, frac_cat + ) + + n_bands = len(N_bands_data) + phot_loss_1d = 0.0 + for band in range(0, n_bands): + N_model = N_bands_model[band] * (data_sky_area_degsq / lc_data.sky_area_degsq) + phot_loss_1d += poisson_loss(N_model, N_bands_data[band]) + + return phot_loss_1d + + +@jjit +def _loss_phot_kern_1d( + u_theta, + ran_key, + magbin_bands, + N_bands_data, + lc_data, + mag_thresh, + frac_cat, + data_sky_area_degsq, +): + param_collection = get_param_collection_from_u_theta(u_theta) + + phot_loss_1d_args = ( + ran_key, + param_collection, + magbin_bands, + N_bands_data, + lc_data, + mag_thresh, + frac_cat, + data_sky_area_degsq, + ) + phot_loss_1d = get_phot_loss_1d(*phot_loss_1d_args) + + return phot_loss_1d + + +@jjit +def _loss_phot_kern_multiband_multiz( + u_theta, + ran_key, + fitting_data, +): + zbins = fitting_data.zbins + n_z_bins = len(zbins) + phot_loss_multiband_multiz = 0.0 + for zbin in range(n_z_bins): + phot_loss_multiband_multiz += _loss_phot_kern_1d( + u_theta, + ran_key, + fitting_data.magbin_zbins_bands[zbin], + fitting_data.N_zbins_bands[zbin], + fitting_data.lc_data[zbin], + fitting_data.filter_info.mag_thresh, + fitting_data.frac_cat, + fitting_data.data_sky_area_degsq, + ) + + return phot_loss_multiband_multiz + + @jjit def get_phot_loss( ran_key, diff --git a/diffhtwo/experimental/loss_kernels/tests/test_phot_loss.py b/diffhtwo/experimental/loss_kernels/tests/test_phot_loss.py index 716cbff34..affcc0b78 100644 --- a/diffhtwo/experimental/loss_kernels/tests/test_phot_loss.py +++ b/diffhtwo/experimental/loss_kernels/tests/test_phot_loss.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from diffsky.param_utils.diffsky_param_wrapper_merging import DEFAULT_PARAM_COLLECTION from jax import random as jran @@ -6,6 +7,9 @@ from ..phot_loss import _loss_phot_kern, get_phot_loss +@pytest.mark.skip( + reason="LH dimensions need to be fixed in load_feniks before activating this test again" +) def test_phot_loss(feniks_single_z_data): feniks_meta_data, feniks_fitting_data = feniks_single_z_data diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 3652c594e..21997606a 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -11,11 +11,14 @@ import jax.numpy as jnp from jax import jit as jjit from jax import lax, value_and_grad, vmap -from jax.debug import print from jax.example_libraries import optimizers as jax_opt from ..loss_kernels.emline_loss import _loss_emline_kern_multi_line_multi_z -from ..loss_kernels.phot_loss import _loss_phot_kern +from ..loss_kernels.phot_loss import ( + _loss_phot_kern, + _loss_phot_kern_2d_multiz, + _loss_phot_kern_multiband_multiz, +) _L_pk = ( None, @@ -34,6 +37,12 @@ value_and_grad(_loss_emline_kern_multi_line_multi_z) ) +_loss_and_grad_phot_kern_multiband_multiz = jjit( + value_and_grad(_loss_phot_kern_multiband_multiz) +) + +_loss_and_grad_phot_kern_2d_multiz = jjit(value_and_grad(_loss_phot_kern_2d_multiz)) + @partial(jjit, static_argnames=["n_steps", "step_size"]) def fit_N_multi_z( @@ -63,10 +72,90 @@ def _opt_update(opt_state, i): ) # clip gradients - global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) - tau = 1.0 - scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) - grads = tuple(g * scale for g in grads) + # global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) + # tau = 1.0 + # scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + # grads = tuple(g * scale for g in grads) + + opt_state = opt_update(i, grads, opt_state) + return opt_state, loss + + opt_state, loss_hist = lax.scan(_opt_update, opt_state, jnp.arange(n_steps)) + u_theta_fit = get_params(opt_state) + + return loss_hist, u_theta_fit + + +@partial(jjit, static_argnames=["n_steps", "step_size"]) +def fit_N_phot_2d( + u_theta_init, + trainable, + ran_key, + fitting_data, + n_steps=2, + step_size=1e-2, +): + opt_init, opt_update, get_params = jax_opt.adam(step_size) + opt_state = opt_init(u_theta_init) + + other = ( + ran_key, + fitting_data, + ) + + def _opt_update(opt_state, i): + u_theta = get_params(opt_state) + loss, grads = _loss_and_grad_phot_kern_2d_multiz(u_theta, *other) + # set grads for untrainable params to 0.0 + grads = tuple( + jnp.where(train, grad, 0.0) for grad, train in zip(grads, trainable) + ) + + # clip gradients + # global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) + # tau = 1.0 + # scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + # grads = tuple(g * scale for g in grads) + + opt_state = opt_update(i, grads, opt_state) + return opt_state, loss + + opt_state, loss_hist = lax.scan(_opt_update, opt_state, jnp.arange(n_steps)) + u_theta_fit = get_params(opt_state) + + return loss_hist, u_theta_fit + + +@partial(jjit, static_argnames=["n_steps", "step_size"]) +def fit_N_phot_1d( + u_theta_init, + trainable, + ran_key, + fitting_data, + n_steps=2, + step_size=1e-2, +): + opt_init, opt_update, get_params = jax_opt.adam(step_size) + opt_state = opt_init(u_theta_init) + + other = ( + ran_key, + fitting_data, + ) + + def _opt_update(opt_state, i): + u_theta = get_params(opt_state) + loss, grads = _loss_and_grad_phot_kern_multiband_multiz(u_theta, *other) + # set grads for untrainable params to 0.0 + grads = tuple( + jnp.where(train, grad, 0.0) for grad, train in zip(grads, trainable) + ) + + # clip gradients + # global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) + # tau = 1.0 + # scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + # grads = tuple(g * scale for g in grads) opt_state = opt_update(i, grads, opt_state) return opt_state, loss @@ -83,12 +172,62 @@ def pytree_norm(grads): return jnp.sqrt(sum(jnp.sum(g**2) for g in leaves)) +@partial(jjit, static_argnames=["n_steps", "step_size"]) +def fit_sdss_feniks( + u_theta_init, + trainable, + ran_key, + sdss_fitting_data, + feniks_fitting_data, + n_steps=2, + step_size=1e-2, + w_sdss=1.0, + w_feniks=1.0, +): + opt_init, opt_update, get_params = jax_opt.adam(step_size) + opt_state = opt_init(u_theta_init) + + def _opt_update(opt_state, i): + u_theta = get_params(opt_state) + loss_sdss, grad_sdss = _loss_and_grad_phot_kern_2d_multiz( + u_theta, + ran_key, + sdss_fitting_data, + ) + + loss_feniks, grad_feniks = _loss_and_grad_phot_kern_2d_multiz( + u_theta, + ran_key, + feniks_fitting_data, + ) + + loss_sdss = w_sdss * loss_sdss + loss_feniks = w_feniks * loss_feniks + loss = loss_sdss + loss_feniks + + grads = tuple( + w_sdss * gs + w_feniks * gf for gs, gf in zip(grad_sdss, grad_feniks) + ) + # set grads for untrainable params to 0.0 + grads = tuple( + jnp.where(train, grad, 0.0) for grad, train in zip(grads, trainable) + ) + + opt_state = opt_update(i, grads, opt_state) + return opt_state, (loss, loss_sdss, loss_feniks) + + opt_state, (loss_hist, loss_sdss_hist, loss_feniks_hist) = lax.scan( + _opt_update, opt_state, jnp.arange(n_steps) + ) + u_theta_fit = get_params(opt_state) + return loss_hist, loss_sdss_hist, loss_feniks_hist, u_theta_fit + + @partial(jjit, static_argnames=["n_steps", "step_size"]) def fit_feniks_hizels( u_theta_init, trainable, ran_key, - feniks_meta_data, feniks_fitting_data, hizels_fitting_data, n_steps=2, @@ -99,10 +238,9 @@ def fit_feniks_hizels( def _opt_update(opt_state, i): u_theta = get_params(opt_state) - loss_phot, grad_phot = _loss_and_grad_phot_kern_multi_z( + loss_phot, grad_phot = _loss_and_grad_phot_kern_2d_multiz( u_theta, ran_key, - feniks_meta_data, feniks_fitting_data, ) loss_emline, grad_emline = _loss_and_grad_emline_kern_multi_line_multi_z( @@ -110,9 +248,13 @@ def _opt_update(opt_state, i): ran_key, hizels_fitting_data, ) - w_phot = 10.0 + w_phot = 1.0 / 5 w_emline = 1.0 - loss = w_phot * loss_phot + w_emline * loss_emline + + loss_phot = w_phot * loss_phot + loss_emline = w_emline * loss_emline + loss = loss_phot + loss_emline + grads = tuple( w_phot * gp + w_emline * ge for gp, ge in zip(grad_phot, grad_emline) ) @@ -122,10 +264,10 @@ def _opt_update(opt_state, i): ) # clip gradients - global_norm = pytree_norm(grads) - tau = 1.0 - scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) - grads = tuple(g * scale for g in grads) + # global_norm = pytree_norm(grads) + # tau = 1.0 + # scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + # grads = tuple(g * scale for g in grads) opt_state = opt_update(i, grads, opt_state) return opt_state, (loss, loss_phot, loss_emline) diff --git a/diffhtwo/experimental/optimizers/emline_luminosity_opt.py b/diffhtwo/experimental/optimizers/emline_luminosity_opt.py deleted file mode 100644 index 3f2a0a858..000000000 --- a/diffhtwo/experimental/optimizers/emline_luminosity_opt.py +++ /dev/null @@ -1,166 +0,0 @@ -# flake8: noqa: E402 -""" """ -import jax - -jax.config.update("jax_enable_x64", True) -jax.config.update("jax_debug_nans", True) -jax.config.update("jax_debug_infs", True) - -from functools import partial - -import jax.numpy as jnp -from diffstar.diffstarpop import get_bounded_diffstarpop_params -from diffstar.diffstarpop.defaults import DEFAULT_DIFFSTARPOP_U_PARAMS -from jax import jit as jjit -from jax import lax, value_and_grad -from jax.example_libraries import optimizers as jax_opt -from jax.flatten_util import ravel_pytree - -from ..emline_luminosity_pop import emline_luminosity_func_pop, emline_luminosity_pop - -u_theta_default, u_unravel_fn = ravel_pytree(DEFAULT_DIFFSTARPOP_U_PARAMS) -IDX = jnp.arange(16, 22, 1) - - -@jjit -def _mse(emline_lf_weighted_composite_true, emline_lf_weighted_composite_pred): - diff = emline_lf_weighted_composite_true - emline_lf_weighted_composite_pred - return jnp.mean(jnp.square(diff)) - - -def make_subspace_loss(u_unravel_fn, u_theta_default, IDX): - """ - Build a loss that optimizes ONLY the parameters at flat indices `IDX`. - - u_unravel_fn: from ravel_pytree(template) - - u_theta_default: 1D base vector (others stay fixed to these values) - - IDX: 1D array/list of flat indices to vary (static for the compiled fn) - - Notes: The only thing you should not do is use namedtuple._replace() / ._asdict() - inside @jjit. Those are Python-side and will slow/break JIT. - """ - IDX = jnp.asarray(IDX, dtype=jnp.int64) # capture in closure - - @jjit - def _loss_kern_subspace( - u_theta_sub, # only the selected subset: shape (len(IDX),) - emline_lf_weighted_composite_true, - ran_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - cosmo_params, - fb, - ): - # scatter the subset into the full flat vector - u_theta_full = u_theta_default.at[IDX].set(u_theta_sub) - - # back to structured params and do the usual - u_diffstarpop_params = u_unravel_fn(u_theta_full) - - # convert to bounded params - diffstarpop_params = get_bounded_diffstarpop_params(u_diffstarpop_params) - - emline_lf_pred = emline_luminosity_pop( - diffstarpop_params, - ran_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - cosmo_params, - fb, - ) - - nhalos = jnp.ones_like(emline_lf_pred.emline_L_cgs_q) - ( - lgL_bin_edges, - emline_lf_weighted_q_pred, - emline_lf_weighted_smooth_ms_pred, - emline_lf_weighted_bursty_ms_pred, - ) = emline_luminosity_func_pop(emline_lf_pred, nhalos) - - emline_lf_weighted_composite_pred = ( - emline_lf_weighted_q_pred - + emline_lf_weighted_smooth_ms_pred - + emline_lf_weighted_bursty_ms_pred - ) - - return _mse( - emline_lf_weighted_composite_true, emline_lf_weighted_composite_pred - ) - - return _loss_kern_subspace - - -loss_kern = make_subspace_loss(u_unravel_fn, u_theta_default, IDX) -loss_and_grad_fn = jjit(value_and_grad(loss_kern)) - - -@partial(jjit, static_argnames=["n_steps", "step_size"]) -def fit_emline_luminosity( - u_theta_init_sub, # only the selected subset: shape (len(IDX),) - emline_lf_weighted_composite_true, - ran_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - cosmo_params, - fb, - n_steps=10, - step_size=1e-2, -): - opt_init, opt_update, get_params = jax_opt.adam(step_size) - opt_state = opt_init(u_theta_init_sub) - - other = ( - emline_lf_weighted_composite_true, - ran_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - cosmo_params, - fb, - ) - - def _opt_update(opt_state, i): - u_theta_sub = get_params(opt_state) - loss, grads = loss_and_grad_fn(u_theta_sub, *other) - opt_state = opt_update(i, grads, opt_state) - return opt_state, loss - - opt_state, loss_hist = lax.scan(_opt_update, opt_state, jnp.arange(n_steps)) - - u_theta_fit_sub = get_params(opt_state) - - return loss_hist, u_theta_fit_sub diff --git a/diffhtwo/experimental/optimizers/tests/test_Np_specphot_opt.py b/diffhtwo/experimental/optimizers/tests/test_Np_specphot_opt.py index 40daef729..d55c338ca 100644 --- a/diffhtwo/experimental/optimizers/tests/test_Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/tests/test_Np_specphot_opt.py @@ -8,16 +8,15 @@ from ... import param_utils as pu from ..Np_specphot_opt import ( - _loss_and_grad_phot_kern_multi_z, - _loss_and_grad_sdss_feniks_hizels, - fit_N_multi_z, - fit_sdss_feniks_hizels, + _loss_and_grad_emline_kern_multi_line_multi_z, + _loss_and_grad_phot_kern_2d_multiz, + fit_feniks_hizels, + fit_N_phot_2d, ) @pytest.fixture(scope="module") -def multistep_grads(ran_key, feniks_multi_z_data): - feniks_meta_data, feniks_fitting_data = feniks_multi_z_data +def multistep_grads(ran_key, feniks_fitting_data): n_steps = 10 step_size = 0.1 @@ -25,14 +24,13 @@ def multistep_grads(ran_key, feniks_multi_z_data): opt_init, opt_update, get_params = jax_opt.adam(step_size) other = ( ran_key, - feniks_meta_data, feniks_fitting_data, ) opt_state = opt_init(u_theta_init) multistep_grads = [] for i in range(n_steps): u_theta = get_params(opt_state) - loss, grads = _loss_and_grad_phot_kern_multi_z(u_theta, *other) + loss, grads = _loss_and_grad_phot_kern_2d_multiz(u_theta, *other) multistep_grads.append(grads) opt_state = opt_update(i, grads, opt_state) return multistep_grads, n_steps @@ -93,14 +91,11 @@ def test_all_diffsky_u_param_grads_are_nonzero( ), f"These {diffsky_param_names[param_idx]} have exactly zero grads: {zero_grad_params}" -def test_phot_opt(ran_key, feniks_multi_z_data): - feniks_meta_data, feniks_fitting_data = feniks_multi_z_data - +def test_phot_opt(ran_key, feniks_fitting_data): u_theta = pu.get_u_theta_from_param_collection(DEFAULT_PARAM_COLLECTION) - loss, grads = _loss_and_grad_phot_kern_multi_z( + loss, grads = _loss_and_grad_phot_kern_2d_multiz( u_theta, ran_key, - feniks_meta_data, feniks_fitting_data, ) assert np.isfinite(loss) @@ -108,11 +103,10 @@ def test_phot_opt(ran_key, feniks_multi_z_data): assert np.isfinite(grads[g]).all() trainable_params = pu.get_trainable_params(fit_type="all") - loss_hist, u_theta_fit = fit_N_multi_z( + loss_hist, u_theta_fit = fit_N_phot_2d( u_theta, trainable_params, ran_key, - feniks_meta_data, feniks_fitting_data, n_steps=2, step_size=0.1, @@ -126,38 +120,34 @@ def test_phot_opt(ran_key, feniks_multi_z_data): def test_specphot_opt( - ran_key, fake_subset_ssp_data, feniks_multi_z_data, hizels_fitting_data + ran_key, fake_subset_ssp_data, feniks_fitting_data, hizels_fitting_data ): ssp_data, emline_wave_aa = fake_subset_ssp_data - feniks_meta_data, feniks_fitting_data = feniks_multi_z_data - - # duplicate feniks data for sdss data - sdss_meta_data, sdss_fitting_data = feniks_meta_data, feniks_fitting_data - u_theta = pu.get_u_theta_from_param_collection(DEFAULT_PARAM_COLLECTION) - loss, grads = _loss_and_grad_sdss_feniks_hizels( + loss_phot, grad_phot = _loss_and_grad_phot_kern_2d_multiz( u_theta, ran_key, - sdss_meta_data, - sdss_fitting_data, - feniks_meta_data, feniks_fitting_data, - hizels_fitting_data, ) + assert np.isfinite(loss_phot) + for g in range(len(grad_phot)): + assert np.isfinite(grad_phot[g]).all() - assert np.isfinite(loss) - for g in range(len(grads)): - assert np.isfinite(grads[g]).all() + loss_emline, grad_emline = _loss_and_grad_emline_kern_multi_line_multi_z( + u_theta, + ran_key, + hizels_fitting_data, + ) + assert np.isfinite(loss_emline) + for g in range(len(grad_emline)): + assert np.isfinite(grad_emline[g]).all() trainable_params = pu.get_trainable_params(fit_type="all") - loss_hist, u_theta_fit = fit_sdss_feniks_hizels( + loss_hist, u_theta_fit = fit_feniks_hizels( u_theta, trainable_params, ran_key, - sdss_meta_data, - sdss_fitting_data, - feniks_meta_data, feniks_fitting_data, hizels_fitting_data, n_steps=2, diff --git a/diffhtwo/experimental/optimizers/tests/test_emline_luminosity_opt.py b/diffhtwo/experimental/optimizers/tests/test_emline_luminosity_opt.py deleted file mode 100644 index ca209f530..000000000 --- a/diffhtwo/experimental/optimizers/tests/test_emline_luminosity_opt.py +++ /dev/null @@ -1,120 +0,0 @@ -import jax.numpy as jnp -import numpy as np -from diffsky.experimental import mc_lightcone_halos as mclh -from diffsky.experimental.scatter import DEFAULT_SCATTER_PARAMS -from diffsky.param_utils import spspop_param_utils as spspu -from diffstar.defaults import FB, T_TABLE_MIN -from diffstar.diffstarpop.defaults import ( - DEFAULT_DIFFSTARPOP_PARAMS, - DEFAULT_DIFFSTARPOP_U_PARAMS, -) -from dsps.cosmology import DEFAULT_COSMOLOGY, flat_wcdm -from dsps.data_loaders import retrieve_fake_fsps_data -from dsps.metallicity import umzr -from jax import random as jran -from jax.flatten_util import ravel_pytree - -from ...emline_luminosity_pop import emline_luminosity_func_pop, emline_luminosity_pop -from ..emline_luminosity_opt import IDX, fit_emline_luminosity - -u_theta_default, u_unravel_fn = ravel_pytree(DEFAULT_DIFFSTARPOP_U_PARAMS) -theta_default, unravel_fn = ravel_pytree(DEFAULT_DIFFSTARPOP_PARAMS) - -ssp_data = retrieve_fake_fsps_data.load_fake_ssp_data() -emline_wave_aa = 6000 - - -def test_emline_luminosity_opt(): - ran_key = jran.key(0) - - # generate lightcone - ran_key, lc_key = jran.split(ran_key, 2) - lgmp_min = 12.0 - z_min, z_max = 0.1, 0.5 - sky_area_degsq = 0.1 - - args = (lc_key, lgmp_min, z_min, z_max, sky_area_degsq) - - lc_halopop = mclh.mc_lightcone_host_halo_diffmah(*args) - - n_z_phot_table = 15 - z_phot_table = np.linspace(z_min, z_max, n_z_phot_table) - - z_obs = lc_halopop["z_obs"] - t_obs = lc_halopop["t_obs"] - mah_params = lc_halopop["mah_params"] - logmp0 = lc_halopop["logmp0"] - t_0 = flat_wcdm.age_at_z0(*DEFAULT_COSMOLOGY) - lgt0 = np.log10(t_0) - - t_table = np.linspace(T_TABLE_MIN, 10**lgt0, 100) - - mzr_params = umzr.DEFAULT_MZR_PARAMS - - spspop_params = spspu.DEFAULT_SPSPOP_PARAMS - scatter_params = DEFAULT_SCATTER_PARAMS - - ran_key, dpop_halpha_true_key = jran.split(ran_key, 2) - args = ( - DEFAULT_DIFFSTARPOP_PARAMS, - dpop_halpha_true_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - DEFAULT_COSMOLOGY, - FB, - ) - halpha_L_true = emline_luminosity_pop(*args) - nhalos = jnp.ones_like(halpha_L_true.emline_L_cgs_q) - ( - lgL_bin_edges, - halpha_lf_weighted_q_true, - halpha_lf_weighted_smooth_ms_true, - halpha_lf_weighted_bursty_ms_true, - ) = emline_luminosity_func_pop(halpha_L_true, nhalos) - - halpha_lf_weighted_composite_true = ( - halpha_lf_weighted_q_true - + halpha_lf_weighted_smooth_ms_true - + halpha_lf_weighted_bursty_ms_true - ) - - noise_scale = 0.1 - ran_key, perturb_key = jran.split(ran_key, 2) - u_theta_perturbed = u_theta_default + noise_scale * jran.normal( - perturb_key, shape=u_theta_default.shape - ) - - ran_key, dpop_halpha_perturbed_key = jran.split(ran_key, 2) - fit_args = ( - u_theta_perturbed[IDX], - halpha_lf_weighted_composite_true, - dpop_halpha_perturbed_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - DEFAULT_COSMOLOGY, - FB, - ) - - loss_hist, u_theta_fit_sub = fit_emline_luminosity( - *fit_args, n_steps=2, step_size=0.02 - ) - - assert np.isfinite(loss_hist).all() diff --git a/diffhtwo/experimental/tests/test_emline_luminosity_pop.py b/diffhtwo/experimental/tests/test_emline_luminosity_pop.py deleted file mode 100644 index bebd003af..000000000 --- a/diffhtwo/experimental/tests/test_emline_luminosity_pop.py +++ /dev/null @@ -1,69 +0,0 @@ -import numpy as np -from diffsky.experimental import mc_lightcone_halos as mclh -from diffsky.experimental.scatter import DEFAULT_SCATTER_PARAMS -from diffsky.param_utils import spspop_param_utils as spspu -from diffstar.defaults import FB, T_TABLE_MIN -from diffstar.diffstarpop.defaults import DEFAULT_DIFFSTARPOP_PARAMS -from dsps.cosmology import flat_wcdm -from dsps.cosmology.defaults import DEFAULT_COSMOLOGY -from dsps.data_loaders import retrieve_fake_fsps_data -from dsps.metallicity import umzr -from jax import random as jran - -from ..emline_luminosity_pop import emline_luminosity_pop - -ssp_data = retrieve_fake_fsps_data.load_fake_ssp_data() - - -def test_emline_luminosity_pop(): - ran_key = jran.key(0) - lgmp_min = 12.0 - z_min, z_max = 0.1, 0.5 - sky_area_degsq = 0.1 - - ran_key, lc_key = jran.split(ran_key, 2) - args = (lc_key, lgmp_min, z_min, z_max, sky_area_degsq) - - lc_halopop = mclh.mc_lightcone_host_halo_diffmah(*args) - - n_z_phot_table = 15 - z_phot_table = np.linspace(z_min, z_max, n_z_phot_table) - - z_obs = lc_halopop["z_obs"] - t_obs = lc_halopop["t_obs"] - mah_params = lc_halopop["mah_params"] - logmp0 = lc_halopop["logmp0"] - t_0 = flat_wcdm.age_at_z0(*DEFAULT_COSMOLOGY) - lgt0 = np.log10(t_0) - - t_table = np.linspace(T_TABLE_MIN, 10**lgt0, 100) - - mzr_params = umzr.DEFAULT_MZR_PARAMS - - spspop_params = spspu.DEFAULT_SPSPOP_PARAMS - scatter_params = DEFAULT_SCATTER_PARAMS - - emline_wave_aa = 6000 - - ran_key, diffstarpop_key = jran.split(ran_key, 2) - args = ( - DEFAULT_DIFFSTARPOP_PARAMS, - diffstarpop_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - emline_wave_aa, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - DEFAULT_COSMOLOGY, - FB, - ) - halpha_L = emline_luminosity_pop(*args) - - for arr in halpha_L: - assert np.all(np.isfinite(arr)) diff --git a/diffhtwo/experimental/tests/test_line_phot_kern.py b/diffhtwo/experimental/tests/test_line_phot_kern.py index 835e68d1a..afcfecb56 100644 --- a/diffhtwo/experimental/tests/test_line_phot_kern.py +++ b/diffhtwo/experimental/tests/test_line_phot_kern.py @@ -1,22 +1,11 @@ -import jax.numpy as jnp import numpy as np from astropy import units as u from astropy.cosmology import FlatLambdaCDM -from diffsky.experimental import mc_lightcone_halos as mclh -from diffsky.experimental.scatter import DEFAULT_SCATTER_PARAMS -from diffsky.param_utils import spspop_param_utils as spspu -from diffstar.defaults import FB, T_TABLE_MIN -from diffstar.diffstarpop.defaults import DEFAULT_DIFFSTARPOP_PARAMS -from dsps.cosmology import flat_wcdm -from dsps.cosmology.defaults import DEFAULT_COSMOLOGY from dsps.data_loaders import retrieve_fake_fsps_data -from dsps.metallicity import umzr -from jax import random as jran from .. import line_phot_kern from ..data_loaders import retrieve_tcurves from ..defaults import C_ANGSTROMS, HALPHA_CENTER_AA -from ..emline_luminosity_pop import emline_luminosity_func_pop, emline_luminosity_pop ssp_data = retrieve_fake_fsps_data.load_fake_ssp_data() @@ -118,81 +107,3 @@ def test_line_phot_kern(BB_tcurve=HSC_Z_tcurve, NB_tcurve=HSC_NB921_tcurve): BB_mag_ab_check = -2.5 * np.log10(numerator / denominator) - 48.6 assert np.isclose(BB_mag_ab, BB_mag_ab_check) - - -def test_line_mag_vmap(): - ran_key = jran.key(0) - ran_key, lc_key = jran.split(ran_key, 2) - - lgmp_min = 12.0 - z_min, z_max = 0.2, 0.5 - sky_area_degsq = 1 - - """weighted mc lightcone""" - num_halos = 500 - lgmp_max = 15.0 - args = (lc_key, num_halos, z_min, z_max, lgmp_min, lgmp_max, sky_area_degsq) - lc_halopop = mclh.mc_weighted_halo_lightcone(*args) - - n_z_phot_table = 15 - z_phot_table = np.linspace(z_min, z_max, n_z_phot_table) - - z_obs = jnp.array(lc_halopop["z_obs"]) - t_obs = lc_halopop["t_obs"] - mah_params = lc_halopop["mah_params"] - # logmp0 = lc_halopop["logmp0"] - logmp0 = lc_halopop["logmp0"] - nhalos = lc_halopop["nhalos"] - t_0 = flat_wcdm.age_at_z0(*DEFAULT_COSMOLOGY) - lgt0 = np.log10(t_0) - - t_table = np.linspace(T_TABLE_MIN, 10**lgt0, 100) - - mzr_params = umzr.DEFAULT_MZR_PARAMS - - spspop_params = spspu.DEFAULT_SPSPOP_PARAMS - scatter_params = DEFAULT_SCATTER_PARAMS - - ran_key, dpop_halpha_true_key = jran.split(ran_key, 2) - args = ( - DEFAULT_DIFFSTARPOP_PARAMS, - dpop_halpha_true_key, - z_obs, - t_obs, - mah_params, - logmp0, - t_table, - ssp_data, - HALPHA_CENTER_AA, - z_phot_table, - mzr_params, - spspop_params, - scatter_params, - DEFAULT_COSMOLOGY, - FB, - ) - halpha_L_true = emline_luminosity_pop(*args) - - ( - lgL_bin_edges, - halpha_lf_weighted_q_true, - halpha_lf_weighted_smooth_ms_true, - halpha_lf_weighted_bursty_ms_true, - ) = emline_luminosity_func_pop(halpha_L_true, nhalos) - - halpha_obs_aa = HALPHA_CENTER_AA * (1 + z_obs) - d_L_Mpc = COSMO.luminosity_distance(z_obs) - d_L_cm = d_L_Mpc * MPC_TO_CM - - SXDS_z_mag_ab = line_phot_kern.get_band_mag_ab_from_luminosity( - halpha_obs_aa, - halpha_L_true, - z_obs, - d_L_cm, - HSC_Z_tcurve.wave, - HSC_Z_tcurve.transmission, - ) - - assert np.isfinite(SXDS_z_mag_ab.band_mag_ab_q).all() - assert np.isfinite(SXDS_z_mag_ab.band_mag_ab_smooth_ms).all() - assert np.isfinite(SXDS_z_mag_ab.band_mag_ab_bursty_ms).all() diff --git a/diffhtwo/experimental/utils.py b/diffhtwo/experimental/utils.py index 890f67faa..11ceb0339 100644 --- a/diffhtwo/experimental/utils.py +++ b/diffhtwo/experimental/utils.py @@ -2,6 +2,7 @@ import jax.numpy as jnp import numpy as np +from astropy.table import Table, vstack from jax import jit as jjit from jax.tree_util import tree_flatten_with_path @@ -73,3 +74,37 @@ def get_param_names(params): paths, leaves = tree_flatten_with_path(params) names = [p[0][0].name for p in paths] # each p is (GetAttrKey(...),) return names + + +def add_random_rows(tab, N): + new_rows = {} + for col in tab.colnames: + data = tab[col] + dtype = data.dtype + + if dtype.kind in ("f", "i") and data.ndim == 1: + valid = data[(data != -99) & np.isfinite(data)] + if len(valid) > 0: + lo, hi = valid.min(), valid.max() + new_rows[col] = np.random.uniform(lo, hi, N) + if dtype.kind == "i": + new_rows[col] = new_rows[col].astype(dtype) + else: + new_rows[col] = np.full(N, -99, dtype=dtype) + elif dtype.kind == "f" and data.ndim == 2: + ncols = data.shape[1] + new_col = np.empty((N, ncols), dtype=dtype) + for j in range(ncols): + sub = data[:, j] + valid = sub[(sub != -99) & np.isfinite(sub)] + if len(valid) > 0: + lo, hi = valid.min(), valid.max() + new_col[:, j] = np.random.uniform(lo, hi, N) + else: + new_col[:, j] = -99 + new_rows[col] = new_col + else: + new_rows[col] = data[:N] # fallback for non-numeric columns + + new_tab = Table(new_rows) + return vstack([tab, new_tab]) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index e74b30811..c9e33a641 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run133 -model_nickname: run133_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run133/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run187 +model_nickname: run187_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run187/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks @@ -13,11 +13,11 @@ ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kr plot_sdss: False plot_feniks: True -plot_hizels: True +plot_hizels: False plots: - num_halos : 3000 - plot_insitu_sm: True + num_halos : 1000 + plot_color_contours: True plot_app_mag_funcs: True plot_color_pdfs: True plot_colors_mags: True @@ -25,11 +25,14 @@ plots: plot_ssperr: True plot_massive_cen_colors: True plot_merging_sat_colors: True - plot_satquench: False - plot_satquench_model: True + plot_smhm: True plot_insitu_smhm: True + plot_insitu_sm: True + plot_sm: True plot_uvj: True plot_exsitu_frac: True plot_avpop: True plot_burstpop: True - plot_fburst_mh_z: True \ No newline at end of file + plot_fburst_mh_z: True + plot_satquench_model: True + plot_satquench: False \ No newline at end of file diff --git a/scripts/config_diffsky.yaml b/scripts/config_diffsky.yaml index d367eca25..9b90f0dff 100644 --- a/scripts/config_diffsky.yaml +++ b/scripts/config_diffsky.yaml @@ -13,7 +13,8 @@ sdss: feniks: lh_d_mag: 0.4 N_centroids: 100 - num_halos: 100 + num_halos_coarse_zbins: 100 + num_halos_fine_zbins: 100 hizels: num_halos: 100 diff --git a/scripts/config_feniks.yaml b/scripts/config_feniks.yaml index 51c5966c3..68151f042 100644 --- a/scripts/config_feniks.yaml +++ b/scripts/config_feniks.yaml @@ -4,17 +4,16 @@ start_runid: "run90" start_fit_type: "all" fit_runid: "runtest" -fit_type: "all" +fit_type: "diffstarpop+spspop+merging" feniks: - lh_d_mag: 0.4 - N_centroids: 2000 + num_halos_coarse_zbins: 100 + num_halos_fine_zbins: 200 epoch: n_it: 1 n_steps: 2 step_size: 0.1 - num_halos: 200 defaults: diffstarpop: True diff --git a/scripts/config_feniks_lh.yaml b/scripts/config_feniks_lh.yaml new file mode 100644 index 000000000..bcc1cfeff --- /dev/null +++ b/scripts/config_feniks_lh.yaml @@ -0,0 +1,23 @@ +base_path: "/Users/kumail/diffdir" + +start_runid: "run90" +start_fit_type: "all" + +fit_runid: "runtest" +fit_type: "all" + +feniks: + lh_d_mag: 0.4 + N_centroids: 2000 + +epoch: + n_it: 1 + n_steps: 2 + step_size: 0.1 + num_halos: 100 + +defaults: + diffstarpop: True + spspop: True + ssperr: True + merging: True diff --git a/scripts/config_sdss.yaml b/scripts/config_sdss.yaml index 645461fce..a44cb376e 100644 --- a/scripts/config_sdss.yaml +++ b/scripts/config_sdss.yaml @@ -6,12 +6,15 @@ start_fit_type: "all" fit_runid: "runtest" fit_type: "all" +sdss: + num_halos_coarse_zbins: 100 + num_halos_fine_zbins: 200 + epoch: n_it: 1 - n_steps: 25 + n_steps: 2 step_size: 0.1 N_centroids: 2000 - num_halos: 300 defaults: diffstarpop: True diff --git a/scripts/config_sdss_feniks.py b/scripts/config_sdss_feniks.py new file mode 100644 index 000000000..9b90f0dff --- /dev/null +++ b/scripts/config_sdss_feniks.py @@ -0,0 +1,31 @@ +base_path: "/Users/kumail/diffdir" + +start_runid: "run90" +start_fit_type: "all" + +fit_runid: "runtest" +fit_type: "all" + +sdss: + N_centroids: 100 + num_halos: 100 + +feniks: + lh_d_mag: 0.4 + N_centroids: 100 + num_halos_coarse_zbins: 100 + num_halos_fine_zbins: 100 + +hizels: + num_halos: 100 + +epoch: + n_it: 1 + n_steps: 2 + step_size: 0.1 + +defaults: + diffstarpop: True + spspop: True + ssperr: True + merging: True diff --git a/scripts/config_sdss_feniks.yaml b/scripts/config_sdss_feniks.yaml index ab7f2b1db..64f3a54da 100644 --- a/scripts/config_sdss_feniks.yaml +++ b/scripts/config_sdss_feniks.yaml @@ -6,18 +6,21 @@ start_fit_type: "all" fit_runid: "runtest" fit_type: "all" -feniks: - lh_d_mag: 0.6 - N_centroids: 2000 - sdss: - N_centroids: 2000 + N_centroids: 100 + num_halos_coarse_zbins: 100 + num_halos_fine_zbins: 100 + +feniks: + lh_d_mag: 0.4 + N_centroids: 100 + num_halos_coarse_zbins: 100 + num_halos_fine_zbins: 100 epoch: n_it: 1 n_steps: 2 step_size: 0.1 - num_halos: 200 defaults: diffstarpop: True diff --git a/scripts/fit_diffsky.py b/scripts/fit_diffsky.py index 2e7563fd7..779e2dbde 100644 --- a/scripts/fit_diffsky.py +++ b/scripts/fit_diffsky.py @@ -1,6 +1,7 @@ import argparse import os import time +from collections import namedtuple from datetime import datetime from pathlib import Path @@ -22,8 +23,6 @@ from diffhtwo.experimental import param_utils as pu from diffhtwo.experimental.data_loaders import load_feniks, load_hizels -from diffhtwo.experimental.defaults import FENIKS_Z_MIN -from diffhtwo.experimental.latin_hypercube import lh_utils as lhu from diffhtwo.experimental.optimizers import Np_specphot_opt DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ @@ -51,14 +50,19 @@ ssp_data = lemi.get_subset_emline_data(ssp_data, ["Ba_alpha_6563"]) halpha_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) - # load sdss data - # ran_key = jran.key(0) - # SDSS = load_sdss.get_sdss_data(sdss_drn, ran_key, ssp_data) - # load feniks data ran_key = jran.key(0) - FENIKS = load_feniks.get_feniks_data( - feniks_drn, ran_key, ssp_data, lh_d_mag=cfg["feniks"]["lh_d_mag"] + feniks = load_feniks.get_feniks_data( + feniks_drn, + ran_key, + ssp_data, + num_halos_coarse_zbins=cfg["feniks"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["feniks"]["num_halos_fine_zbins"], + ) + remove = {"dataset_dim_labels", "mags_labels"} + FeniksFitting = namedtuple("Feniks", [f for f in feniks._fields if f not in remove]) + feniks_fitting_data = FeniksFitting( + **{f: getattr(feniks, f) for f in FeniksFitting._fields} ) # load hizels data @@ -66,7 +70,7 @@ hizels_drn, ran_key, ssp_data, - FENIKS.filter_info.tcurves, + feniks.filter_info.tcurves, halpha_wave_aa, num_halos=cfg["hizels"]["num_halos"], ) @@ -111,45 +115,11 @@ os.system(f"cp {args.config} {fit_diagnostics_save_drn}") - # SDSS - # sdss_z_min = [SDSS_Z_MIN, 0.08, 0.14] - # sdss_z_max = [0.08, 0.14, SDSS_Z_MAX] - - # FENIKS - feniks_z_min = [FENIKS_Z_MIN, 1] - feniks_z_max = [1, 2] - initial_pts = [] start = time.time() for epoch in range(0, cfg["epoch"]["n_it"]): print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') - # SDSS - # sdss = load_sdss.refresh_lh_centroids(SDSS) - # sdss_meta_data, sdss_fitting_data = lhu.get_zbins_lh_lc( - # ran_key, - # SDSS, - # sdss_z_min, - # sdss_z_max, - # ssp_data, - # cfg["sdss"]["N_centroids"], - # lh_N_z_savedir=fit_diagnostics_save_drn + "/lh_N_z", - # num_halos=cfg["sdss"]["num_halos"], - # ) - - # FENIKS - FENIKS = load_feniks.refresh_lh_centroids(FENIKS, cfg["feniks"]["lh_d_mag"]) - feniks_meta_data, feniks_fitting_data = lhu.get_zbins_lh_lc( - ran_key, - FENIKS, - feniks_z_min, - feniks_z_max, - ssp_data, - cfg["feniks"]["N_centroids"], - lh_N_z_savedir=fit_diagnostics_save_drn + "/lh_N_z", - num_halos=cfg["feniks"]["num_halos"], - ) - ( loss_hist, loss_phot_hist, @@ -159,7 +129,6 @@ u_theta_fit, trainable_params, ran_key, - feniks_meta_data, feniks_fitting_data, hizels_fitting_data, n_steps=cfg["epoch"]["n_steps"], diff --git a/scripts/fit_feniks.py b/scripts/fit_feniks.py index 0e3e7b645..e05914a21 100644 --- a/scripts/fit_feniks.py +++ b/scripts/fit_feniks.py @@ -21,8 +21,6 @@ from diffhtwo.experimental import param_utils as pu from diffhtwo.experimental.data_loaders import load_feniks -from diffhtwo.experimental.defaults import FENIKS_Z_MAX, FENIKS_Z_MIN -from diffhtwo.experimental.latin_hypercube import lh_utils as lhu from diffhtwo.experimental.optimizers import Np_specphot_opt DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ @@ -49,12 +47,6 @@ emline_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) emline_wave_table = jnp.array([emline_wave_aa]) - # load feniks data - ran_key = jran.key(0) - FENIKS = load_feniks.get_feniks_data( - feniks_drn, ran_key, ssp_data, lh_d_mag=cfg["feniks"]["lh_d_mag"] - ) - # start fit dirs fit_start_drn = cfg["base_path"] + "/fits/" + cfg["start_runid"] + "/" param_collection_fit = lc_mock.load_diffsky_param_collection_merging( @@ -95,36 +87,29 @@ os.system(f"cp {args.config} {fit_diagnostics_save_drn}") - feniks_z_min = [FENIKS_Z_MIN, 1] - feniks_z_max = [1, 2] - initial_pts = [] start = time.time() + ran_key = jran.key(0) for epoch in range(0, cfg["epoch"]["n_it"]): print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') - FENIKS = load_feniks.refresh_lh_centroids(FENIKS, cfg["feniks"]["lh_d_mag"]) - # FENIKS - feniks_meta_data, feniks_fitting_data = lhu.get_zbins_lh_lc( + feniks_fitting_data = load_feniks.get_feniks_fitting_data( + feniks_drn, ran_key, - FENIKS, - feniks_z_min, - feniks_z_max, ssp_data, - cfg["feniks"]["N_centroids"], - lh_N_z_savedir=fit_diagnostics_save_drn + "/lh_N_z", - num_halos=cfg["epoch"]["num_halos"], + num_halos_coarse_zbins=cfg["feniks"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["feniks"]["num_halos_fine_zbins"], ) - loss_hist, u_theta_fit = Np_specphot_opt.fit_N_multi_z( + loss_hist, u_theta_fit = Np_specphot_opt.fit_N_phot_2d( u_theta_fit, trainable_params, ran_key, - feniks_meta_data, feniks_fitting_data, n_steps=cfg["epoch"]["n_steps"], step_size=cfg["epoch"]["step_size"], ) + jax.clear_caches() param_collection_fit = pu.get_param_collection_from_u_theta(u_theta_fit) diff --git a/scripts/fit_feniks_lh.py b/scripts/fit_feniks_lh.py new file mode 100644 index 000000000..0e3e7b645 --- /dev/null +++ b/scripts/fit_feniks_lh.py @@ -0,0 +1,169 @@ +import argparse +import os +import time +from datetime import datetime + +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import yaml +from diffsky.data_loaders.hacc_utils import lc_mock +from diffsky.merging.merging_model import DEFAULT_MERGE_PARAMS +from diffsky.param_utils.spspop_param_utils import DEFAULT_SPSPOP_PARAMS +from diffsky.ssp_err_model.defaults import ZERO_SSPERR_PARAMS +from diffstar.diffstarpop.kernels.params.params_diffstarpopfits_mgash import ( + DiffstarPop_Params_Diffstarpopfits_mgash, +) +from dsps import load_ssp_templates +from dsps.data_loaders import load_emline_info as lemi +from jax import random as jran + +from diffhtwo.experimental import param_utils as pu +from diffhtwo.experimental.data_loaders import load_feniks +from diffhtwo.experimental.defaults import FENIKS_Z_MAX, FENIKS_Z_MIN +from diffhtwo.experimental.latin_hypercube import lh_utils as lhu +from diffhtwo.experimental.optimizers import Np_specphot_opt + +DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ + "galacticus_in_plus_ex_situ" +] + +if __name__ == "__main__": + p = argparse.ArgumentParser() + p.add_argument("--config", default="config_feniks.yaml") + args = p.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + feniks_drn = cfg["base_path"] + "/feniks" + ssp_filename = ( + cfg["base_path"] + + "/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5" + ) + + # get ssp data + ssp_data = load_ssp_templates(fn=ssp_filename) + ssp_data = lemi.get_subset_emline_data(ssp_data, ["Ba_alpha_6563"]) + emline_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) + emline_wave_table = jnp.array([emline_wave_aa]) + + # load feniks data + ran_key = jran.key(0) + FENIKS = load_feniks.get_feniks_data( + feniks_drn, ran_key, ssp_data, lh_d_mag=cfg["feniks"]["lh_d_mag"] + ) + + # start fit dirs + fit_start_drn = cfg["base_path"] + "/fits/" + cfg["start_runid"] + "/" + param_collection_fit = lc_mock.load_diffsky_param_collection_merging( + fit_start_drn, + cfg["start_runid"] + "_" + cfg["start_fit_type"], + ) + if cfg["defaults"]["diffstarpop"]: + param_collection_fit = param_collection_fit._replace( + diffstarpop_params=DIFFSTARPOP_GALACTICUS_exsitu + ) + if cfg["defaults"]["spspop"]: + param_collection_fit = param_collection_fit._replace( + spspop_params=DEFAULT_SPSPOP_PARAMS + ) + if cfg["defaults"]["ssperr"]: + param_collection_fit = param_collection_fit._replace( + ssperr_params=ZERO_SSPERR_PARAMS + ) + if cfg["defaults"]["merging"]: + param_collection_fit = param_collection_fit._replace( + merging_params=DEFAULT_MERGE_PARAMS + ) + + u_theta_fit = pu.get_u_theta_from_param_collection(param_collection_fit) + + # fit dirs + trainable_params = pu.get_trainable_params(fit_type=cfg["fit_type"]) + fit_save_drn = cfg["base_path"] + "/fits/" + cfg["fit_runid"] + "/" + fit_diagnostics_save_drn = ( + cfg["base_path"] + + "/fits/" + + cfg["fit_runid"] + + "/diagnostic_plots/" + + cfg["fit_type"] + ) + os.makedirs(fit_diagnostics_save_drn + "/loss", exist_ok=True) + os.makedirs(fit_diagnostics_save_drn + "/lh_N_z", exist_ok=True) + + os.system(f"cp {args.config} {fit_diagnostics_save_drn}") + + feniks_z_min = [FENIKS_Z_MIN, 1] + feniks_z_max = [1, 2] + + initial_pts = [] + start = time.time() + for epoch in range(0, cfg["epoch"]["n_it"]): + print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') + FENIKS = load_feniks.refresh_lh_centroids(FENIKS, cfg["feniks"]["lh_d_mag"]) + + # FENIKS + feniks_meta_data, feniks_fitting_data = lhu.get_zbins_lh_lc( + ran_key, + FENIKS, + feniks_z_min, + feniks_z_max, + ssp_data, + cfg["feniks"]["N_centroids"], + lh_N_z_savedir=fit_diagnostics_save_drn + "/lh_N_z", + num_halos=cfg["epoch"]["num_halos"], + ) + + loss_hist, u_theta_fit = Np_specphot_opt.fit_N_multi_z( + u_theta_fit, + trainable_params, + ran_key, + feniks_meta_data, + feniks_fitting_data, + n_steps=cfg["epoch"]["n_steps"], + step_size=cfg["epoch"]["step_size"], + ) + jax.clear_caches() + + param_collection_fit = pu.get_param_collection_from_u_theta(u_theta_fit) + lc_mock.write_diffsky_param_collection_merging( + fit_save_drn, + cfg["fit_runid"] + "_" + cfg["fit_type"], + param_collection_fit, + ) + + if epoch == 0: + STEPS = np.arange(1, cfg["epoch"]["n_steps"] + 1, 1) + + LOSS_HIST = loss_hist + + initial_pts.append((STEPS[0], LOSS_HIST[0])) + else: + steps = np.arange(STEPS[-1] + 1, STEPS[-1] + cfg["epoch"]["n_steps"] + 1, 1) + initial_pts.append((steps[0], loss_hist[0])) + + STEPS = np.concatenate((STEPS, steps)) + LOSS_HIST = np.concatenate((LOSS_HIST, loss_hist)) + + end = time.time() + elapsed = end - start + print( + f'Gradient descent took {elapsed/60:.3f} minutes for {cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]} steps.' + ) + print(f'speed: {elapsed/(cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]):.3f} s/it') + + # gradient descent figure + fig_loss, ax_loss = plt.subplots(1) + + start_step = [s[0] for s in initial_pts] + start_loss = [s[1] for s in initial_pts] + ax_loss.scatter(start_step, start_loss, s=50, c="deepskyblue") + + ax_loss.plot(STEPS, LOSS_HIST, c="deepskyblue") + ax_loss.set_ylabel("Poisson Loss") + ax_loss.set_xlabel("steps") + ts = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + plt.savefig(fit_diagnostics_save_drn + "/loss/feniks_loss_" + ts + ".png") + plt.close() diff --git a/scripts/fit_sdss.py b/scripts/fit_sdss.py index f87db9ea6..7dd28dc2a 100644 --- a/scripts/fit_sdss.py +++ b/scripts/fit_sdss.py @@ -1,6 +1,7 @@ import argparse import os import time +from collections import namedtuple from datetime import datetime import jax @@ -21,8 +22,6 @@ from diffhtwo.experimental import param_utils as pu from diffhtwo.experimental.data_loaders import load_sdss -from diffhtwo.experimental.defaults import SDSS_Z_MAX, SDSS_Z_MIN -from diffhtwo.experimental.latin_hypercube import lh_utils as lhu from diffhtwo.experimental.optimizers import Np_specphot_opt DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ @@ -49,10 +48,6 @@ emline_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) emline_wave_table = jnp.array([emline_wave_aa]) - # load sdss data - ran_key = jran.key(0) - SDSS = load_sdss.get_sdss_data(sdss_drn, ran_key, ssp_data) - # start fit dirs fit_start_drn = cfg["base_path"] + "/fits/" + cfg["start_runid"] + "/" param_collection_fit = lc_mock.load_diffsky_param_collection_merging( @@ -93,32 +88,29 @@ os.system(f"cp {args.config} {fit_diagnostics_save_drn}") - sdss_z_min = [SDSS_Z_MIN, 0.08, 0.14] - sdss_z_max = [0.08, 0.14, SDSS_Z_MAX] - + ran_key = jran.key(0) initial_pts = [] start = time.time() for epoch in range(0, cfg["epoch"]["n_it"]): print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') - sdss = load_sdss.refresh_lh_centroids(SDSS) - # SDSS - sdss_meta_data, sdss_fitting_data = lhu.get_zbins_lh_lc( + sdss = load_sdss.get_sdss_data( + sdss_drn, ran_key, - SDSS, - sdss_z_min, - sdss_z_max, ssp_data, - cfg["epoch"]["N_centroids"], - lh_N_z_savedir=fit_diagnostics_save_drn + "/lh_N_z", - num_halos=cfg["epoch"]["num_halos"], + num_halos_coarse_zbins=cfg["sdss"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["sdss"]["num_halos_fine_zbins"], + ) + remove = {"dataset_dim_labels", "mags_labels"} + SdssFitting = namedtuple("Sdss", [s for s in sdss._fields if s not in remove]) + sdss_fitting_data = SdssFitting( + **{s: getattr(sdss, s) for s in SdssFitting._fields} ) - loss_hist, u_theta_fit = Np_specphot_opt.fit_N_multi_z( + loss_hist, u_theta_fit = Np_specphot_opt.fit_N_phot_2d( u_theta_fit, trainable_params, ran_key, - sdss_meta_data, sdss_fitting_data, n_steps=cfg["epoch"]["n_steps"], step_size=cfg["epoch"]["step_size"], diff --git a/scripts/fit_sdss_feniks.py b/scripts/fit_sdss_feniks.py index a13537692..6693802f0 100644 --- a/scripts/fit_sdss_feniks.py +++ b/scripts/fit_sdss_feniks.py @@ -1,8 +1,11 @@ import argparse import os import time +from collections import namedtuple from datetime import datetime +from pathlib import Path +import jax import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np @@ -20,12 +23,6 @@ from diffhtwo.experimental import param_utils as pu from diffhtwo.experimental.data_loaders import load_feniks, load_sdss -from diffhtwo.experimental.defaults import ( - FENIKS_Z_MIN, - SDSS_Z_MAX, - SDSS_Z_MIN, -) -from diffhtwo.experimental.latin_hypercube import lh_utils as lhu from diffhtwo.experimental.optimizers import Np_specphot_opt DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ @@ -42,7 +39,6 @@ sdss_drn = cfg["base_path"] + "/sdss" feniks_drn = cfg["base_path"] + "/feniks" - ssp_filename = ( cfg["base_path"] + "/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5" @@ -51,15 +47,8 @@ # get ssp data ssp_data = load_ssp_templates(fn=ssp_filename) ssp_data = lemi.get_subset_emline_data(ssp_data, ["Ba_alpha_6563"]) - emline_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) - emline_wave_table = jnp.array([emline_wave_aa]) + halpha_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) - # load data - ran_key = jran.key(0) - SDSS = load_sdss.get_sdss_data(sdss_drn, ran_key, ssp_data) - FENIKS = load_feniks.get_feniks_data( - feniks_drn, ran_key, ssp_data, lh_d_mag=cfg["feniks"]["lh_d_mag"] - ) # start fit dirs fit_start_drn = cfg["base_path"] + "/fits/" + cfg["start_runid"] + "/" param_collection_fit = lc_mock.load_diffsky_param_collection_merging( @@ -96,98 +85,113 @@ + cfg["fit_type"] ) os.makedirs(fit_diagnostics_save_drn + "/loss", exist_ok=True) - os.makedirs(fit_diagnostics_save_drn + "/sdss_lh_N_z", exist_ok=True) - os.makedirs(fit_diagnostics_save_drn + "/feniks_lh_N_z", exist_ok=True) + os.makedirs(fit_diagnostics_save_drn + "/lh_N_z", exist_ok=True) os.system(f"cp {args.config} {fit_diagnostics_save_drn}") - sdss_z_min = [SDSS_Z_MIN, 0.1] - sdss_z_max = [0.1, SDSS_Z_MAX] - - feniks_z_min = [FENIKS_Z_MIN, 1] - feniks_z_max = [1, 2] - initial_pts = [] start = time.time() + ran_key = jran.key(0) for epoch in range(0, cfg["epoch"]["n_it"]): print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') - sdss = load_sdss.refresh_lh_centroids(SDSS) - # SDSS - sdss_meta_data, sdss_fitting_data = lhu.get_zbins_lh_lc( + # load sdss data + sdss = load_sdss.get_sdss_data( + sdss_drn, ran_key, - SDSS, - sdss_z_min, - sdss_z_max, ssp_data, - cfg["sdss"]["N_centroids"], - lh_N_z_savedir=fit_diagnostics_save_drn + "/sdss_lh_N_z", - num_halos=cfg["epoch"]["num_halos"], + num_halos_coarse_zbins=cfg["sdss"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["sdss"]["num_halos_fine_zbins"], + ) + remove = {"dataset_dim_labels", "mags_labels"} + SdssFitting = namedtuple("Sdss", [s for s in sdss._fields if s not in remove]) + sdss_fitting_data = SdssFitting( + **{s: getattr(sdss, s) for s in SdssFitting._fields} ) - # FENIKS - feniks_meta_data, feniks_fitting_data = lhu.get_zbins_lh_lc( + # load feniks data + feniks = load_feniks.get_feniks_data( + feniks_drn, ran_key, - FENIKS, - feniks_z_min, - feniks_z_max, ssp_data, - cfg["feniks"]["N_centroids"], - lh_N_z_savedir=fit_diagnostics_save_drn + "/feniks_lh_N_z", - num_halos=cfg["epoch"]["num_halos"], + num_halos_coarse_zbins=cfg["feniks"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["feniks"]["num_halos_fine_zbins"], + ) + remove = {"dataset_dim_labels", "mags_labels"} + FeniksFitting = namedtuple( + "Feniks", [f for f in feniks._fields if f not in remove] + ) + feniks_fitting_data = FeniksFitting( + **{f: getattr(feniks, f) for f in FeniksFitting._fields} ) - loss_hist, u_theta_fit = Np_specphot_opt.fit_sdss_feniks_hizels( + ( + loss_hist, + loss_sdss_hist, + loss_feniks_hist, + u_theta_fit, + ) = Np_specphot_opt.fit_sdss_feniks( u_theta_fit, trainable_params, ran_key, - sdss_meta_data, sdss_fitting_data, - feniks_meta_data, feniks_fitting_data, - # hizels, - # line_wave_table, n_steps=cfg["epoch"]["n_steps"], step_size=cfg["epoch"]["step_size"], ) + jax.clear_caches() + param_collection_fit = pu.get_param_collection_from_u_theta(u_theta_fit) lc_mock.write_diffsky_param_collection_merging( fit_save_drn, cfg["fit_runid"] + "_" + cfg["fit_type"], param_collection_fit, ) - if epoch == 0: STEPS = np.arange(1, cfg["epoch"]["n_steps"] + 1, 1) - LOSS_HIST = loss_hist - + LOSS_SDSS_HIST = loss_sdss_hist + LOSS_FENIKS_HIST = loss_feniks_hist initial_pts.append((STEPS[0], LOSS_HIST[0])) else: steps = np.arange(STEPS[-1] + 1, STEPS[-1] + cfg["epoch"]["n_steps"] + 1, 1) initial_pts.append((steps[0], loss_hist[0])) - STEPS = np.concatenate((STEPS, steps)) LOSS_HIST = np.concatenate((LOSS_HIST, loss_hist)) - + LOSS_SDSS_HIST = np.concatenate((LOSS_SDSS_HIST, loss_sdss_hist)) + LOSS_FENIKS_HIST = np.concatenate((LOSS_FENIKS_HIST, loss_feniks_hist)) end = time.time() elapsed = end - start print( f'Gradient descent took {elapsed/60:.3f} minutes for {cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]} steps.' ) print(f'speed: {elapsed/(cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]):.3f} s/it') - # gradient descent figure fig_loss, ax_loss = plt.subplots(1) - start_step = [s[0] for s in initial_pts] start_loss = [s[1] for s in initial_pts] - ax_loss.scatter(start_step, start_loss, s=50, c="deepskyblue") - - ax_loss.plot(STEPS, LOSS_HIST, c="deepskyblue") - ax_loss.set_ylabel("Poisson Loss") + ax_loss.scatter(start_step, start_loss, s=50, c="k") + ax_loss.plot(STEPS, LOSS_HIST, c="k", label="total") + ax_loss.plot( + STEPS, + LOSS_SDSS_HIST, + c="#0a7a80", + linestyle="--", + alpha=0.7, + label="sdss", + ) + ax_loss.plot( + STEPS, + LOSS_FENIKS_HIST, + c="#c87820", + linestyle="--", + alpha=0.7, + label="feniks", + ) + ax_loss.legend() + ax_loss.set_ylabel("Poisson Negative Log-Likelihood") ax_loss.set_xlabel("steps") ts = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") - plt.savefig(fit_diagnostics_save_drn + "/loss/sdss_feniks_loss_" + ts + ".png") + plt.savefig(fit_diagnostics_save_drn + "/loss/loss_" + ts + ".png") plt.close() diff --git a/scripts/fits_sdss_feniks.py b/scripts/fits_sdss_feniks.py new file mode 100644 index 000000000..779e2dbde --- /dev/null +++ b/scripts/fits_sdss_feniks.py @@ -0,0 +1,192 @@ +import argparse +import os +import time +from collections import namedtuple +from datetime import datetime +from pathlib import Path + +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import yaml +from diffsky.data_loaders.hacc_utils import lc_mock +from diffsky.merging.merging_model import DEFAULT_MERGE_PARAMS +from diffsky.param_utils.spspop_param_utils import DEFAULT_SPSPOP_PARAMS +from diffsky.ssp_err_model.defaults import ZERO_SSPERR_PARAMS +from diffstar.diffstarpop.kernels.params.params_diffstarpopfits_mgash import ( + DiffstarPop_Params_Diffstarpopfits_mgash, +) +from dsps import load_ssp_templates +from dsps.data_loaders import load_emline_info as lemi +from jax import random as jran + +from diffhtwo.experimental import param_utils as pu +from diffhtwo.experimental.data_loaders import load_feniks, load_hizels +from diffhtwo.experimental.optimizers import Np_specphot_opt + +DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ + "galacticus_in_plus_ex_situ" +] + +if __name__ == "__main__": + p = argparse.ArgumentParser() + p.add_argument("--config", default="config_diffsky.yaml") + args = p.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + # sdss_drn = cfg["base_path"] + "/sdss" + feniks_drn = cfg["base_path"] + "/feniks" + hizels_drn = Path(cfg["base_path"] + "/hizels") + ssp_filename = ( + cfg["base_path"] + + "/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5" + ) + + # get ssp data + ssp_data = load_ssp_templates(fn=ssp_filename) + ssp_data = lemi.get_subset_emline_data(ssp_data, ["Ba_alpha_6563"]) + halpha_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) + + # load feniks data + ran_key = jran.key(0) + feniks = load_feniks.get_feniks_data( + feniks_drn, + ran_key, + ssp_data, + num_halos_coarse_zbins=cfg["feniks"]["num_halos_coarse_zbins"], + num_halos_fine_zbins=cfg["feniks"]["num_halos_fine_zbins"], + ) + remove = {"dataset_dim_labels", "mags_labels"} + FeniksFitting = namedtuple("Feniks", [f for f in feniks._fields if f not in remove]) + feniks_fitting_data = FeniksFitting( + **{f: getattr(feniks, f) for f in FeniksFitting._fields} + ) + + # load hizels data + hizels_fitting_data = load_hizels.get_hizels_data( + hizels_drn, + ran_key, + ssp_data, + feniks.filter_info.tcurves, + halpha_wave_aa, + num_halos=cfg["hizels"]["num_halos"], + ) + + # start fit dirs + fit_start_drn = cfg["base_path"] + "/fits/" + cfg["start_runid"] + "/" + param_collection_fit = lc_mock.load_diffsky_param_collection_merging( + fit_start_drn, + cfg["start_runid"] + "_" + cfg["start_fit_type"], + ) + if cfg["defaults"]["diffstarpop"]: + param_collection_fit = param_collection_fit._replace( + diffstarpop_params=DIFFSTARPOP_GALACTICUS_exsitu + ) + if cfg["defaults"]["spspop"]: + param_collection_fit = param_collection_fit._replace( + spspop_params=DEFAULT_SPSPOP_PARAMS + ) + if cfg["defaults"]["ssperr"]: + param_collection_fit = param_collection_fit._replace( + ssperr_params=ZERO_SSPERR_PARAMS + ) + if cfg["defaults"]["merging"]: + param_collection_fit = param_collection_fit._replace( + merging_params=DEFAULT_MERGE_PARAMS + ) + + u_theta_fit = pu.get_u_theta_from_param_collection(param_collection_fit) + + # fit dirs + trainable_params = pu.get_trainable_params(fit_type=cfg["fit_type"]) + fit_save_drn = cfg["base_path"] + "/fits/" + cfg["fit_runid"] + "/" + fit_diagnostics_save_drn = ( + cfg["base_path"] + + "/fits/" + + cfg["fit_runid"] + + "/diagnostic_plots/" + + cfg["fit_type"] + ) + os.makedirs(fit_diagnostics_save_drn + "/loss", exist_ok=True) + os.makedirs(fit_diagnostics_save_drn + "/lh_N_z", exist_ok=True) + + os.system(f"cp {args.config} {fit_diagnostics_save_drn}") + + initial_pts = [] + start = time.time() + for epoch in range(0, cfg["epoch"]["n_it"]): + print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') + + ( + loss_hist, + loss_phot_hist, + loss_emline_hist, + u_theta_fit, + ) = Np_specphot_opt.fit_feniks_hizels( + u_theta_fit, + trainable_params, + ran_key, + feniks_fitting_data, + hizels_fitting_data, + n_steps=cfg["epoch"]["n_steps"], + step_size=cfg["epoch"]["step_size"], + ) + + jax.clear_caches() + + param_collection_fit = pu.get_param_collection_from_u_theta(u_theta_fit) + lc_mock.write_diffsky_param_collection_merging( + fit_save_drn, + cfg["fit_runid"] + "_" + cfg["fit_type"], + param_collection_fit, + ) + if epoch == 0: + STEPS = np.arange(1, cfg["epoch"]["n_steps"] + 1, 1) + LOSS_HIST = loss_hist + LOSS_PHOT_HIST = loss_phot_hist + LOSS_EMLINE_HIST = loss_emline_hist + initial_pts.append((STEPS[0], LOSS_HIST[0])) + else: + steps = np.arange(STEPS[-1] + 1, STEPS[-1] + cfg["epoch"]["n_steps"] + 1, 1) + initial_pts.append((steps[0], loss_hist[0])) + STEPS = np.concatenate((STEPS, steps)) + LOSS_HIST = np.concatenate((LOSS_HIST, loss_hist)) + LOSS_PHOT_HIST = np.concatenate((LOSS_PHOT_HIST, loss_phot_hist)) + LOSS_EMLINE_HIST = np.concatenate((LOSS_EMLINE_HIST, loss_emline_hist)) + end = time.time() + elapsed = end - start + print( + f'Gradient descent took {elapsed/60:.3f} minutes for {cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]} steps.' + ) + print(f'speed: {elapsed/(cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]):.3f} s/it') + # gradient descent figure + fig_loss, ax_loss = plt.subplots(1) + start_step = [s[0] for s in initial_pts] + start_loss = [s[1] for s in initial_pts] + ax_loss.scatter(start_step, start_loss, s=50, c="k") + ax_loss.plot(STEPS, LOSS_HIST, c="k", label="total loss") + ax_loss.plot( + STEPS, + LOSS_PHOT_HIST, + c="#0a7a80", + linestyle="--", + alpha=0.7, + label="phot loss", + ) + ax_loss.plot( + STEPS, + LOSS_EMLINE_HIST, + c="#c87820", + linestyle="--", + alpha=0.7, + label="emline loss", + ) + ax_loss.legend() + ax_loss.set_ylabel("Poisson Negative Log-Likelihood") + ax_loss.set_xlabel("steps") + ts = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + plt.savefig(fit_diagnostics_save_drn + "/loss/loss_" + ts + ".png") + plt.close() diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index 97122cda0..41b3167a1 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -3,6 +3,7 @@ from pathlib import Path import jax.numpy as jnp +import matplotlib.pyplot as plt import numpy as np import yaml from diffsky.data_loaders.hacc_utils import lc_mock @@ -21,6 +22,7 @@ SDSS_Z_MAX, SDSS_Z_MIN, ) +from diffhtwo.experimental.diagnostics import plot_sm from diffhtwo.experimental.diagnostics.plot_avpop_mono import ( make_avpop_mono_comparison_plots, ) @@ -29,6 +31,7 @@ plot_lgfburst_mh_z, ) from diffhtwo.experimental.diagnostics.plot_cen import plot_massive_cen_colors +from diffhtwo.experimental.diagnostics.plot_contour import plot_color_contours from diffhtwo.experimental.diagnostics.plot_halpha import ( plot_halpha, plot_halpha_insitu_exsitu, @@ -36,7 +39,6 @@ plot_halpha_sfr, plot_halpha_ssfr, ) -from diffhtwo.experimental.diagnostics.plot_insitu_sm import plot_insitu_sm from diffhtwo.experimental.diagnostics.plot_phot import ( plot_app_mag_funcs, plot_color_pdfs, @@ -49,6 +51,7 @@ generate_sat_plots, plot_satquench_model, ) +from diffhtwo.experimental.diagnostics.plot_smhm import plot_smhm if __name__ == "__main__": p = argparse.ArgumentParser() @@ -203,16 +206,54 @@ """ if cfg["plot_feniks"]: feniks_label = "feniks" # + cfg["model_nickname"].split("_")[0] - feniks = load_feniks.get_feniks_data(feniks_drn, ran_key, ssp_data) - feniks_zbins = np.array( - [ - [0.2, 0.5], - [0.5, 0.8], - [0.8, 1.2], - [1.2, 1.6], - [1.6, 2.0], - ] + feniks = load_feniks.get_feniks_data( + feniks_drn, + ran_key, + ssp_data, + num_halos_coarse_zbins=num_halos, + num_halos_fine_zbins=int(num_halos / 2), ) + + feniks_zbins = feniks.fine_zbins + + if cfg["plots"]["plot_smhm"]: + print("Generating FENIKS SMHM plots...") + plot_smhm( + ran_key, + param_collection_fit, + feniks_zbins, + num_halos, + ssp_data, + feniks.filter_info.tcurves, + feniks_label, + fit_diagnostics_save_drn, + plt_show=False, + ) + plot_smhm( + ran_key, + param_collection_fit, + feniks_zbins, + num_halos, + ssp_data, + feniks.filter_info.tcurves, + feniks_label, + fit_diagnostics_save_drn, + in_situ=True, + plt_show=False, + ) + + if cfg["plots"]["plot_color_contours"]: + print("Generating FENIKS color contour plots...") + plot_color_contours( + ran_key, + param_collection_fit, + feniks.colors, + feniks.filter_info.mag_thresh, + feniks.frac_cat, + feniks_label, + fit_diagnostics_save_drn, + ) + if cfg["plots"]["plot_app_mag_funcs"]: print("Generating FENIKS app mag funcs plot...") plot_app_mag_funcs( @@ -220,6 +261,7 @@ feniks_label, param_collection_fit, ran_key, + feniks_zbins, ssp_data, fit_diagnostics_save_drn, num_halos=num_halos, @@ -255,6 +297,7 @@ + str(FENIKS_Z_MAX), drn_out=fit_diagnostics_save_drn, ) + plt.close() if cfg["plots"]["plot_fburst_mh_z"]: print("Generating FENIKS lgfburst plot...") @@ -281,7 +324,25 @@ print( f"Generating FENIKS in-situ sm plot for {zbin+1}/{len(feniks_zbins)} z-bin..." ) - plot_insitu_sm( + plot_sm.plot_insitu_sm_obs( + ran_key, + param_collection_fit, + z_min, + z_max, + feniks.dataset_dim_labels, + ssp_data, + feniks.filter_info.tcurves, + feniks_label, + fit_diagnostics_save_drn, + num_halos=num_halos, + plt_show=False, + ) + + if cfg["plots"]["plot_sm"]: + print( + f"Generating FENIKS in+ex-situ sm plot for {zbin+1}/{len(feniks_zbins)} z-bin..." + ) + plot_sm.plot_sm_obs( ran_key, param_collection_fit, z_min, @@ -308,6 +369,7 @@ z_max, ssp_data, fit_diagnostics_save_drn, + num_halos=num_halos, ) if cfg["plots"]["plot_colors_mags"]: @@ -323,6 +385,7 @@ z_max, ssp_data, fit_diagnostics_save_drn, + num_halos=num_halos, ) if cfg["plots"]["plot_mags"]: @@ -335,6 +398,7 @@ z_max, ssp_data, fit_diagnostics_save_drn, + num_halos=num_halos, ) if cfg["plots"]["plot_ssperr"]: @@ -428,16 +492,66 @@ """ if cfg["plot_sdss"]: sdss_label = "sdss" # + cfg["model_nickname"].split("_")[0] - sdss = load_sdss.get_sdss_data(sdss_drn, ran_key, ssp_data) - sdss_zbins = np.array( - [ - [0.02, 0.06], - [0.06, 0.1], - [0.1, 0.14], - [0.14, 0.18], - [0.18, 0.2], - ] + sdss = load_sdss.get_sdss_data( + sdss_drn, + ran_key, + ssp_data, + num_halos_coarse_zbins=num_halos, + num_halos_fine_zbins=int(num_halos / 2), ) + sdss_zbins = sdss.fine_zbins + + if cfg["plots"]["plot_smhm"]: + print("Generating SDSS SMHM plots...") + plot_smhm( + ran_key, + param_collection_fit, + sdss_zbins, + num_halos, + ssp_data, + sdss.filter_info.tcurves, + sdss_label, + fit_diagnostics_save_drn, + plt_show=False, + ) + plot_smhm( + ran_key, + param_collection_fit, + sdss_zbins, + num_halos, + ssp_data, + sdss.filter_info.tcurves, + sdss_label, + fit_diagnostics_save_drn, + in_situ=True, + plt_show=False, + ) + + if cfg["plots"]["plot_color_contours"]: + print("Generating SDSS color contour plots...") + plot_color_contours( + ran_key, + param_collection_fit, + sdss.colors, + sdss.filter_info.mag_thresh, + sdss.frac_cat, + sdss_label, + fit_diagnostics_save_drn, + ) + + if cfg["plots"]["plot_app_mag_funcs"]: + print("Generating SDSS app mag funcs plot...") + plot_app_mag_funcs( + sdss, + sdss_label, + param_collection_fit, + ran_key, + sdss_zbins, + ssp_data, + fit_diagnostics_save_drn, + num_halos=num_halos, + plt_show=False, + ) if cfg["plots"]["plot_exsitu_frac"]: print("Generating SDSS ex-situ frac plot...") @@ -456,6 +570,7 @@ + str(SDSS_Z_MAX), drn_out=fit_diagnostics_save_drn, ) + plt.close() if cfg["plots"]["plot_fburst_mh_z"]: print("Generating SDSS lgfburst plot...") @@ -482,7 +597,25 @@ print( f"Generating SDSS in-situ sm plot for {zbin+1}/{len(sdss_zbins)} z-bin..." ) - plot_insitu_sm( + plot_sm.plot_insitu_sm_obs( + ran_key, + param_collection_fit, + z_min, + z_max, + sdss.dataset_dim_labels, + ssp_data, + sdss.filter_info.tcurves, + sdss_label, + fit_diagnostics_save_drn, + num_halos=num_halos, + plt_show=False, + ) + + if cfg["plots"]["plot_sm"]: + print( + f"Generating SDSS in+ex-situ sm plot for {zbin+1}/{len(sdss_zbins)} z-bin..." + ) + plot_sm.plot_sm_obs( ran_key, param_collection_fit, z_min, @@ -509,6 +642,7 @@ z_max, ssp_data, fit_diagnostics_save_drn, + num_halos=num_halos, ) if cfg["plots"]["plot_colors_mags"]: @@ -524,6 +658,7 @@ z_max, ssp_data, fit_diagnostics_save_drn, + num_halos=num_halos, ) if cfg["plots"]["plot_mags"]: @@ -536,6 +671,7 @@ z_max, ssp_data, fit_diagnostics_save_drn, + num_halos=num_halos, ) if cfg["plots"]["plot_ssperr"]: