Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 72 additions & 72 deletions swvo/io/RBMDataSet/interp_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from enum import Enum
from functools import partial
from multiprocessing import Pool
from typing import Literal
from typing import TYPE_CHECKING, Literal, TypeAlias, cast

import numpy as np
from numpy.typing import NDArray
Expand All @@ -19,11 +19,14 @@
from swvo.io.RBMDataSet import RBMDataSet


class TargetType(Enum):
class TargetType(Enum): # noqa: D101
TargetPairs = 0
TargetMeshGrid = 1


TARGETS: TypeAlias = list[tuple[float | int, float | int]]


def _linear_interp(
flux_left: float,
flux_right: float,
Expand All @@ -48,18 +51,14 @@ def _interp_flux_parallel(
# find left and right alpha indices
# first find the two al levels, where en points must exist

al_right_idx = np.searchsorted(
alpha_eq_model[it, :], target_al_single, side="right"
)
al_right_idx = np.searchsorted(alpha_eq_model[it, :], target_al_single, side="right")
al_left_idx = al_right_idx - 1

if al_right_idx == 0 or al_right_idx >= len(alpha_eq_model[it, :]):
result.append(np.nan)
continue

finite_idx = np.argwhere(
np.isfinite(energy[it, :]) & np.isfinite(flux[it, :, al_left_idx])
)
finite_idx = np.argwhere(np.isfinite(energy[it, :]) & np.isfinite(flux[it, :, al_left_idx]))
if finite_idx.size == 0:
result.append(np.nan)
continue
Expand All @@ -68,15 +67,9 @@ def _interp_flux_parallel(
flux_interp = np.squeeze(flux[it, finite_idx, al_left_idx])
assert np.all(np.diff(energy_interp) > 0)

flux_left = float(
np.interp(
target_en_single, energy_interp, flux_interp, left=np.nan, right=np.nan
)
)
flux_left = float(np.interp(target_en_single, energy_interp, flux_interp, left=np.nan, right=np.nan))

finite_idx = np.argwhere(
np.isfinite(energy[it, :]) & np.isfinite(flux[it, :, al_right_idx])
)
finite_idx = np.argwhere(np.isfinite(energy[it, :]) & np.isfinite(flux[it, :, al_right_idx]))
if finite_idx.size == 0:
result.append(np.nan)
continue
Expand All @@ -85,11 +78,7 @@ def _interp_flux_parallel(
flux_interp = np.squeeze(flux[it, finite_idx, al_right_idx])
assert np.all(np.diff(energy_interp) > 0)

flux_right = float(
np.interp(
target_en_single, energy_interp, flux_interp, left=np.nan, right=np.nan
)
)
flux_right = float(np.interp(target_en_single, energy_interp, flux_interp, left=np.nan, right=np.nan))

result.append(
_linear_interp(
Expand All @@ -104,13 +93,14 @@ def _interp_flux_parallel(
return result


def interp_flux(
def interp_flux( # noqa: D103
self: RBMDataSet,
target_en: float | list[float] | NDArray[np.float64],
target_al: float | list[float],
target_type: TargetType|Literal["TargetPairs", "TargetMesh"],
target_type: TargetType | Literal["TargetPairs", "TargetMesh"],
n_threads: int = 10,
) -> NDArray[np.float64]:

if not isinstance(target_en, Iterable):
target_en = [target_en]
if not isinstance(target_al, Iterable):
Expand All @@ -125,10 +115,10 @@ def interp_flux(
), "For TargetType.Pairs, the target vectors must have the same size!"

result_arr = np.empty((len(self.time), len(target_en))) # ty:ignore[invalid-argument-type]
targets = list(zip(target_en, target_al))
targets = cast("TARGETS", list(zip(target_en, target_al, strict=False)))
else:
result_arr = np.empty((len(self.time), len(target_en), len(target_al))) # ty:ignore[invalid-argument-type]
targets = list(itertools.product(target_en, target_al))
targets = cast("TARGETS", list(itertools.product(target_en, target_al)))

func = partial(
_interp_flux_parallel,
Expand Down Expand Up @@ -165,20 +155,22 @@ def interp_flux(
result_arr[i, t] = parallel_results[i][t]
else:
for ie, ia in itertools.product(
range(len(target_en)), range(len(target_al)) # ty:ignore[invalid-argument-type]
range(len(target_en)), # ty:ignore[invalid-argument-type]
range(len(target_al)), # ty:ignore[invalid-argument-type]
):
result_arr[i, ie, ia] = parallel_results[i][ie * len(target_al) + ia] # ty:ignore[invalid-argument-type]

return result_arr


def _interp_psd_parallel(psd: NDArray[np.float64],
invmu: NDArray[np.float64],
invk: NDArray[np.float64],
targets: list[tuple[float, float]],
it: int) -> list[float]:
"""
Interpolate PSD at time index `it` to (mu_target, K_target) pairs in `targets`.
def _interp_psd_parallel(
psd: NDArray[np.float64],
invmu: NDArray[np.float64],
invk: NDArray[np.float64],
targets: list[tuple[float, float]],
it: int,
) -> list[float]:
"""Interpolate PSD at time index `it` to (mu_target, K_target) pairs in `targets`.

Shapes per time slice:
psd[it] -> (nE, nA)
Expand All @@ -188,84 +180,90 @@ def _interp_psd_parallel(psd: NDArray[np.float64],
out: list[float] = []

# ---- 0) Extract this time slice
psd_i = psd[it, :, :] # (nE, nA)
mu_i = invmu[it, :, :] # (nE, nA)
K_row = invk[it, :] # (nA,)
psd_i = psd[it, :, :] # (nE, nA)
mu_i = invmu[it, :, :] # (nE, nA)
K_row = invk[it, :] # (nA,)

# ---- 1) Drop NaN K bins and the corresponding columns in PSD/mu
finite_k = np.isfinite(K_row)
if not np.any(finite_k):
# No valid K at this time -> all NaN
return [np.nan] * len(targets)

K_use = K_row[finite_k] # (nA_valid,)
psd_use = psd_i[:, finite_k] # (nE, nA_valid)
mu_use = mu_i[:, finite_k] # (nE, nA_valid)
K_use = K_row[finite_k] # (nA_valid,)
psd_use = psd_i[:, finite_k] # (nE, nA_valid)
mu_use = mu_i[:, finite_k] # (nE, nA_valid)

# If after masking we have fewer than 2 K points, we cannot bracket
if K_use.size < 2:
return [np.nan] * len(targets)

# ---- 2) Ensure K ascending for searchsorted; if descending, flip columns
if K_use[1] < K_use[0]:
K_use = K_use[::-1]
K_use = K_use[::-1]
psd_use = psd_use[:, ::-1]
mu_use = mu_use[:, ::-1]
mu_use = mu_use[:, ::-1]

# ---- 3) For each (mu*, K*) target: 1D along mu, then linear across K
for _, (mu_t, K_t) in enumerate(targets):

# 3a) Bracket in K
k_right = np.searchsorted(K_use, K_t, side='right')
k_left = k_right - 1
k_right = np.searchsorted(K_use, K_t, side="right")
k_left = k_right - 1
if k_right == 0 or k_right >= K_use.size:
out.append(np.nan)
continue

# 3b) Interp along mu at LEFT K
mu_L = mu_use[:, k_left]
mu_L = mu_use[:, k_left]
psd_L = psd_use[:, k_left]
okL = np.isfinite(mu_L) & np.isfinite(psd_L)
okL = np.isfinite(mu_L) & np.isfinite(psd_L)
if not np.any(okL):
out.append(np.nan); continue
out.append(np.nan)
continue

xL = np.asarray(mu_L[okL], dtype=float)
xL = np.asarray(mu_L[okL], dtype=float)
yL = np.asarray(psd_L[okL], dtype=float)
if xL.size < 2:
out.append(np.nan); continue
out.append(np.nan)
continue
if not np.all(np.diff(xL) > 0):
order = np.argsort(xL)
xL, yL = xL[order], yL[order]
xL, idx = np.unique(xL, return_index=True)
yL = yL[idx]
if xL.size < 2:
out.append(np.nan); continue
out.append(np.nan)
continue

psd_left = float(np.interp(mu_t, xL, yL, left=np.nan, right=np.nan))

# 3c) Interp along mu at RIGHT K
mu_R = mu_use[:, k_right]
mu_R = mu_use[:, k_right]
psd_R = psd_use[:, k_right]
okR = np.isfinite(mu_R) & np.isfinite(psd_R)
okR = np.isfinite(mu_R) & np.isfinite(psd_R)
if not np.any(okR):
out.append(np.nan); continue
out.append(np.nan)
continue

xR = np.asarray(mu_R[okR], dtype=float)
xR = np.asarray(mu_R[okR], dtype=float)
yR = np.asarray(psd_R[okR], dtype=float)
if xR.size < 2:
out.append(np.nan); continue
out.append(np.nan)
continue
if not np.all(np.diff(xR) > 0):
order = np.argsort(xR)
xR, yR = xR[order], yR[order]
xR, idx = np.unique(xR, return_index=True)
yR = yR[idx]
if xR.size < 2:
out.append(np.nan); continue
out.append(np.nan)
continue

psd_right = float(np.interp(mu_t, xR, yR, left=np.nan, right=np.nan))

if not np.isfinite(psd_left) or not np.isfinite(psd_right):
out.append(np.nan); continue
out.append(np.nan)
continue

# 3d) Linear across K to K_t
val = _linear_interp(psd_left, psd_right, K_t, K_use[k_left], K_use[k_right])
Expand All @@ -274,19 +272,19 @@ def _interp_psd_parallel(psd: NDArray[np.float64],
return out


def interp_psd(self: RBMDataSet,
target_mu: float | list[float] | NDArray[np.float64],
target_K: float | list[float] | NDArray[np.float64],
target_type: TargetType|Literal["TargetPairs", "TargetMesh"],
n_threads: int = 10) -> NDArray[np.float64]:
"""
Interpolate PSD to requested (mu, K) targets for every time.
def interp_psd(
self: RBMDataSet,
target_mu: float | list[float] | NDArray[np.float64],
target_K: float | list[float] | NDArray[np.float64],
target_type: TargetType | Literal["TargetPairs", "TargetMesh"],
n_threads: int = 10,
) -> NDArray[np.float64]:
"""Interpolate PSD to requested (mu, K) targets for every time.

Output shapes (matching interp_flux semantics):
- TargetPairs -> (time, N)
- TargetMeshGrid -> (time, n_mu, n_K)
"""

if not isinstance(target_mu, Iterable):
target_mu = [target_mu]
if not isinstance(target_K, Iterable):
Expand All @@ -296,16 +294,17 @@ def interp_psd(self: RBMDataSet,
target_type = TargetType[target_type]

if target_type == TargetType.TargetPairs:
assert len(target_mu) == len(target_K), \
"For TargetType.Pairs, mu and K vectors must have the same size!" # ty:ignore[invalid-argument-type]
assert len(target_mu) == len(target_K), "For TargetType.Pairs, mu and K vectors must have the same size!" # ty:ignore[invalid-argument-type]
result_arr = np.empty((len(self.time), len(target_mu))) # ty:ignore[invalid-argument-type]
targets = list(zip(target_mu, target_K))
targets = cast("TARGETS", list(zip(target_mu, target_K, strict=False)))
else:
result_arr = np.empty((len(self.time), len(target_mu), len(target_K))) # ty:ignore[invalid-argument-type]
targets = list(itertools.product(target_mu, target_K))
targets = cast("TARGETS", list(itertools.product(target_mu, target_K)))

# ensure needed fields are loaded (triggers lazy loader if any)
_ = self.PSD; _ = self.InvMu; _ = self.InvK
_ = self.PSD
_ = self.InvMu
_ = self.InvK

# parallel over time (same pattern as interp_flux)
func = partial(_interp_psd_parallel, self.PSD, self.InvMu, self.InvK, targets)
Expand All @@ -317,8 +316,9 @@ def interp_psd(self: RBMDataSet,
total_elements = rs._number_left # ty:ignore[unresolved-attribute]
with tqdm(total=total_elements) as t:
while True:
if rs.ready(): break
t.n = (total_elements - rs._number_left) # ty:ignore[unresolved-attribute]
if rs.ready():
break
t.n = total_elements - rs._number_left # ty:ignore[unresolved-attribute]
t.refresh()
time.sleep(1)
else:
Expand Down
Loading
Loading