diff --git a/.github/workflows/continuous_integration.yml b/.github/workflows/continuous_integration.yml index ed0ab75f34..a688d5b959 100644 --- a/.github/workflows/continuous_integration.yml +++ b/.github/workflows/continuous_integration.yml @@ -25,29 +25,14 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Install POSYDON without extras + - name: Install POSYDON with test dependencies run: | python -m pip install --upgrade pip - pip install . + pip install .[test] - name: Run all tests in posydon/unit_tests run: | - # python -m pip install --upgrade pip - # pip install . - pip install pytest - pip install pytest-cov export PATH_TO_POSYDON=./ export PATH_TO_POSYDON_DATA=./posydon/unit_tests/_data/ export MESA_DIR=./ - python -m pytest posydon/unit_tests/ \ - --cov=posydon.config \ - --cov=posydon.utils \ - --cov=posydon.grids \ - --cov=posydon.popsyn.IMFs \ - --cov=posydon.popsyn.norm_pop \ - --cov=posydon.popsyn.distributions \ - --cov=posydon.popsyn.star_formation_history \ - --cov=posydon.CLI \ - --cov-branch \ - --cov-report term-missing \ - --cov-fail-under=100 + pytest # run and coverage parameters are defined in pyproject.toml diff --git a/.gitignore b/.gitignore index 45e334668d..63e5cf387d 100644 --- a/.gitignore +++ b/.gitignore @@ -70,6 +70,7 @@ instance/ # Sphinx documentation docs/_build/ docs/_tmp/ +docs/_source/_build/ docs/_source/api_reference docs/checkautodoc *.h5 diff --git a/.gitmodules b/.gitmodules index 6cda6aff11..1cfccfceda 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,3 @@ [submodule "grid_params/POSYDON-MESA-INLISTS"] path = grid_params/POSYDON-MESA-INLISTS url = https://github.com/POSYDON-code/POSYDON-MESA-INLISTS.git -[submodule "data/POSYDON_data"] - path = data/POSYDON_data - url = https://github.com/POSYDON-code/POSYDON_data.git -[submodule "posydon/tests/data/POSYDON-UNIT-TESTS"] - path = posydon/tests/data/POSYDON-UNIT-TESTS - url = https://github.com/POSYDON-code/POSYDON-UNIT-TESTS.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c8a697f65a..5e48ccfb0f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,7 +35,7 @@ repos: - id: requirements-txt-fixer - id: trailing-whitespace - repo: https://github.com/pycqa/isort - rev: 8.0.1 + rev: 9.0.0a3 hooks: - id: isort args: diff --git a/bin/posydon-popsyn b/bin/posydon-popsyn index a0ca33d36a..7e2a0d9c25 100644 --- a/bin/posydon-popsyn +++ b/bin/posydon-popsyn @@ -71,6 +71,16 @@ if __name__ == '__main__': '--account', help='the account you would like to use', default=None) + setup_parser.add_argument( + '--max_concurrent_jobs', + help='the maximum number of concurrent jobs to run in the job array', + type=int, + default=None) + setup_parser.add_argument( + '--exclude', + help='a comma-separated list of nodes to exclude when submitting the job', + type=str, + default=None) setup_parser.set_defaults(func=setup_popsyn_function) # Check the run subcommand @@ -110,6 +120,16 @@ if __name__ == '__main__': '--account', help='the account you would like to use', default=None) + check_parser.add_argument( + '--max_concurrent_jobs', + help='the maximum number of concurrent jobs to run in the job array', + type=int, + default=None) + check_parser.add_argument( + '--exclude', + help='a comma-separated list of nodes to exclude when submitting the job', + type=str, + default=None) check_parser.set_defaults(func=check_popsyn_function) # Rescue the run subcommand (DEPRECATED) @@ -149,6 +169,16 @@ if __name__ == '__main__': '--account', help='the account you would like to use', default=None) + rescue_parser.add_argument( + '--max_concurrent_jobs', + help='the maximum number of concurrent jobs to run in the job array', + type=int, + default=None) + rescue_parser.add_argument( + '--exclude', + help='a comma-separated list of nodes to exclude when submitting the job', + type=str, + default=None) rescue_parser.set_defaults(func=rescue_popsyn_function) args = parser.parse_args() diff --git a/conda/meta.yaml b/conda/meta.yaml index 7586866e6d..751d76aec0 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -1,9 +1,8 @@ {% set name = "posydon" %} -{% set version = "2.2.8" %} package: name: "{{ name|lower }}" - version: "{{ version }}" + version: {{ GIT_DESCRIBE_TAG }} source: path: .. @@ -17,7 +16,8 @@ requirements: host: - pip - python==3.11 - - setuptools>=38.2.5 + - setuptools>=76.0.0 + - setuptools-scm>=8.0 run: - python==3.11 diff --git a/data/POSYDON_data b/data/POSYDON_data deleted file mode 160000 index e5d8d77985..0000000000 --- a/data/POSYDON_data +++ /dev/null @@ -1 +0,0 @@ -Subproject commit e5d8d77985fc1502b6b6cc0400577623d50743ab diff --git a/dev-tools/.gitignore b/dev-tools/.gitignore new file mode 100644 index 0000000000..c5bb40a839 --- /dev/null +++ b/dev-tools/.gitignore @@ -0,0 +1,7 @@ +workdirs/ +output/ +logs/ +baselines/ +test_*.h5 +*.ini +*.txt diff --git a/dev-tools/README.md b/dev-tools/README.md new file mode 100644 index 0000000000..a5984cb164 --- /dev/null +++ b/dev-tools/README.md @@ -0,0 +1,210 @@ +Validation suite for POSYDON binary evolution. Evolves a fixed set of test binaries on a candidate branch and compares results against a stored baseline to catch regressions. A baseline can be formed from any branch (`main` by default) and is represented by a set of results from `binary_suite.py`, saved HDF5 files, all stored in `dev-tools/baselines/`. + +## Quick Start + +```bash +# 1. Generate a baseline from the main branch (once) +./generate_baseline.sh main + +# 2. Validate a candidate branch against that baseline +./validate_binaries.sh feature/my-branch +``` + +Results are written to `outputs//`. After validation, check: + +- `outputs//comparison_summary.txt` for a pass/fail overview across all metallicities +- `outputs//comparison_Zsun.txt` for detailed per-metallicity diff reports +- `logs//evolve_Zsun.log` for the full evolution output of each metallicity + +By default, all eight POSYDON metallicities are run. To validate only a subset, pass a quoted space-separated list as the third argument: + +```bash +./generate_baseline.sh main "" "1 0.45" +./validate_binaries.sh feature/my-branch main "1 0.45" +``` + +To re-run comparison with different tolerances without re-evolving: + +```bash +./validate_binaries.sh feature/my-branch main "1 0.45" --skip-evolve --loose +``` + +## Scripts + +### `validate_binaries.sh` + +Top-level entry point. Evolves test binaries on a candidate branch, then compares results against an existing baseline. This script will look for baseline HDF5 files stored in `dev-tools/baseline/`, where `main` is the default ``. + +```bash +./validate_binaries.sh [baseline_branch] [metallicities] [--loose] [--rtol VALUE] [--atol VALUE] [--skip-evolve] +``` + +By default, comparison is exact (rtol=0, atol=0). Use `--loose` for relaxed floating-point tolerances (rtol=1e-12, atol=1e-15), or set `--rtol`/`--atol` explicitly as per [np.allclose](https://numpy.org/devdocs/reference/generated/numpy.allclose.html). Use `--skip-evolve` to skip the evolution step and compare existing candidate outputs against the baseline. + +### `generate_baseline.sh` + +Generates baseline HDF5 files from a designated branch name and optionally a SHA to specify a commit. + +```bash +./generate_baseline.sh [sha] [metallicities] +``` + +If you already have results from prior runs of `evolve_binaries.sh` saved as HDF5 files in `outputs//`, you can copy these directly into the baselines directory with the `--promote` option, skipping re-evolution: + +```bash +./generate_baseline.sh --promote [metallicities] +``` + +### `evolve_binaries.sh` + +Clones a POSYDON branch, creates a conda environment, installs POSYDON, and runs the binary suite at all requested metallicities. Called by `validate_binaries.sh` and `generate_baseline.sh`; can also be run standalone. Records the resolved commit SHA and branch name in each HDF5 file's metadata for provenance tracking. + +```bash +./evolve_binaries.sh [sha] [metallicities] +``` + +### `binaries_suite.py` + +Defines and evolves the set of 44 test binaries at a given metallicity. Each binary targets a specific edge case or past bug fix (e.g., matching failures, oRLO2 looping, SN type errors, NaN spins). Results are saved to an HDF5 file with a `/metadata` table that records metallicity, binary counts, `PATH_TO_POSYDON_DATA`, and optionally branch name, commit SHA, and generation timestamp (via `--branch`/`--sha`). + +```bash +python binaries_suite.py --output results.h5 --metallicity 1 +python binaries_suite.py --output results.h5 --metallicity 1 --branch main --sha abc123f +``` + +### `compare_runs.py` + +Compares two HDF5 files produced by `binaries_suite.py` and reports differences in three categories: + +- **Structural**: missing or extra binaries, evolution step count changes, binaries that newly fail or newly pass, missing HDF5 tables. +- **Qualitative**: changes to categorical columns such as state, event, step name, SN type, interpolation class, and mass transfer history. +- **Quantitative**: changes to any numeric column. By default, comparison is exact (bitwise identical floats). Use `--loose` for slightly relaxed tolerances (rtol=1e-12, atol=1e-15), or set `--rtol`/`--atol` explicitly as per [np.allclose](https://numpy.org/devdocs/reference/generated/numpy.allclose.html). + +The script also compares warning and error tables, reporting new, removed, or changed warnings per binary. The report header includes provenance metadata (branch, commit SHA, generation time, POSYDON data path) read from each file when available. + +```bash +python compare_runs.py baseline.h5 candidate.h5 [--loose] [--rtol VALUE] [--atol VALUE] [--verbose] +``` + +### `binaries_params.ini` + +Configuration file for `SimulationProperties`. Defines the POSYDON evolution steps, supernova prescriptions, common envelope parameters, and output column selections. Metallicity is overridden at runtime by `binaries_suite.py`. + +## Running Scripts Manually + +The shell scripts handle cloning, environment setup, orchestration, and execution. If you already have POSYDON installed in your current environment, you can execute the Python scripts directly. + +### Evolving binaries + +```bash +# Evolve all 44 test binaries at solar metallicity +python binaries_suite.py --output my_results.h5 --metallicity 1 + +# Evolve at a specific metallicity with verbose output +python binaries_suite.py --output my_results.h5 --metallicity 0.01 --verbose + +# Use a custom ini file +python binaries_suite.py --output my_results.h5 --metallicity 1 --ini /path/to/custom.ini + +# Record branch/SHA provenance in HDF5 metadata (done automatically by evolve_binaries.sh) +python binaries_suite.py --output my_results.h5 --metallicity 1 --branch main --sha abc123f +``` + +The output HDF5 contains three tables: `evolution` (per-step binary data), `errors` (binaries that failed), and `warnings` (warnings raised during evolution). The `/metadata` table records metallicity, binary counts, `PATH_TO_POSYDON_DATA`, and optionally branch, commit SHA, and generation timestamp. + +### Comparing two result files + +```bash +# Exact comparison +python compare_runs.py file_a.h5 file_b.h5 + +# Relaxed tolerances +python compare_runs.py file_a.h5 file_b.h5 --loose + +# Custom tolerances with verbose diagnostics +python compare_runs.py file_a.h5 file_b.h5 --rtol 1e-8 --atol 1e-12 --verbose +``` + +The two files do not need to come from the shell pipeline; any pair of HDF5 files produced by `binaries_suite.py` can be compared. + +## Directory Structure + +``` +dev-tools/ +├── README.md +├── validate_binaries.sh # full validation pipeline +├── generate_baseline.sh # create or promote baselines +├── evolve_binaries.sh # clone, install, and run suite +├── binaries_suite.py # test binary definitions and evolution +├── binaries_params.ini # SimulationProperties configuration +├── compare_runs.py # diff two HDF5 result files +├── baselines/ # stored baseline HDF5 files (per branch) +├── outputs/ # candidate evolution results (per branch) +├── logs/ # per-metallicity evolution logs (per branch) +└── workdirs/ # cloned repos and conda environments (per branch) +``` + +## Interpreting Results + +The comparison report groups differences into four categories. Here's how to read them: + +**Structural** differences (missing/extra binaries, step count changes, newly failing/passing) almost always indicate a real change. A binary that newly fails or changes its number of evolution steps means the code is following a different evolutionary path. These warrant investigation regardless of tolerance settings. + +**Qualitative** differences (state, event, step name, SN type) also represent real behavioral changes. Even a single qualitative diff means the binary is being classified differently, e.g. a different mass transfer history or a changed SN type. These are never tolerance-dependent. + +**Quantitative** differences are more nuanced. With exact comparison (the default), any floating-point difference is reported. This is useful for detecting unintended changes, but expected after compiler/platform changes or numpy version bumps. If you see many quantitative diffs but zero structural/qualitative diffs, the evolution paths are the same and the differences are likely numerical noise — re-run with `--loose` or a custom `--rtol` to confirm. If quantitative diffs persist at `--rtol 1e-6` or larger, something meaningful has changed. + +**Warning** differences are informational. New warnings may indicate a physics edge case being hit differently, or changes to warnings in the code, but are not failures on their own. + +A healthy validation run after a non-physics code change should show zero structural and qualitative differences, and minimal quantitative changes. After an intentional physics change, expect significant diffs. If binaries unrelated to intentional physics changes show structural or qualitative diffs, that may indicate a problem with implementation. + +## Tolerance Design + +By default, comparison is exact (`rtol=0, atol=0`): any bitwise difference in a float is reported. The `--loose` flag sets `rtol=1e-12, atol=1e-15`, which is appropriate for filtering out platform-level floating-point noise while still catching meaningful changes. + +For custom tolerances, `--rtol` and `--atol` follow the semantics of `np.allclose`: a value passes if `abs(baseline - candidate) <= atol + rtol * abs(baseline)`. In practice, `rtol` dominates for most columns (masses, periods, separations are all large numbers), while `atol` only matters near zero (e.g., eccentricity, certain hydrogen fractions). + +Known limitation: when `baseline == 0` and `candidate != 0`, `rtol`-based comparison produces `0 + rtol * 0 = 0`, so any nonzero candidate value fails. This is correct behavior (a zero-to-nonzero change is meaningful), but be aware that the reverse (both values very small but nonzero) may pass even if the relative change is large, since `atol` provides a floor. For most POSYDON quantities this is not an issue, but it matters for quantities that are genuinely expected to be zero (e.g., eccentricity at ZAMS for circular binaries). + +A single global tolerance works well for catching regressions but is a blunt instrument for columns spanning many orders of magnitude. Per-column or per-quantity scaling is a possible future improvement but is not currently implemented. + +The `--loose` defaults (`rtol=1e-12, atol=1e-15`) were chosen just above float64 machine epsilon and may need to be adjusted if there are parts of the code that are non-deterministic. If parts of the POSYDON pipeline in the branches being tested introduce stochasticity (e.g. unseeded RNG), the irreducible noise floor may be higher. To calibrate, run the same branch against itself and check what tolerance is needed for a clean pass: + +```bash +# Evolve the same branch twice under different output names +python binaries_suite.py --output /tmp/run_a.h5 --metallicity 1 +python binaries_suite.py --output /tmp/run_b.h5 --metallicity 1 + +# Compare — any diffs here are the stochasticity floor +python compare_runs.py /tmp/run_a.h5 /tmp/run_b.h5 + +# Find the tolerance that absorbs the noise +python compare_runs.py /tmp/run_a.h5 /tmp/run_b.h5 --rtol 1e-10 +``` + +The `--loose` defaults should sit just above whatever self-comparison noise you observe. If the self-comparison is clean at exact, the current defaults are fine. + +**RNG reproducibility.** Several POSYDON evolution steps (Bondi-Hoyle accretion in `step_detached` and `MesaGridStep`, SN kicks in `step_SN`) use random number generation internally. Without a fixed seed, these produce nondeterministic results that appear as spurious `S1_lg_mdot` diffs in the validation suite. To ensure reproducibility, set the `entropy` parameter to a fixed integer in `binaries_params.ini`. This seeds the RNG passed to each step (see PR#826). + +## Updating the Baseline + +The baseline should be regenerated when the "expected correct" output changes. Typical triggers: + +- **After a release or version tag.** Generate a baseline from the release tag so future development is compared against the release state: `./generate_baseline.sh v2.3` +- **After merging an intentional physics change.** If a PR deliberately changes evolution outcomes (e.g., a new SN prescription), validate the PR branch first to confirm only the expected binaries are affected, then regenerate the baseline from the updated main branch. +- **After updating POSYDON data grids.** Grid changes will alter interpolated values. Regenerate the baseline and record the new `PATH_TO_POSYDON_DATA` in `baseline_info.txt`. + +Do not regenerate the baseline to silence unexpected diffs. If a validation run shows differences you don't understand, investigate them before updating the baseline. + +The `--promote` flag on `generate_baseline.sh` is a convenience for skipping re-evolution when you've already run the suite and are satisfied with the outputs: `./generate_baseline.sh --promote main "1 0.45"`. + +## Adding New Test Binaries + +New binaries are added by appending entries to the `get_test_binaries()` function in `binaries_suite.py`. Each entry is a tuple of `(star1_kwargs, star2_kwargs, binary_kwargs, description)`. + +When adding a binary: + +- Choose initial conditions that reliably trigger the edge case or evolutionary pathway you want to test. Verify it does so at multiple metallicities if possible, since grid coverage varies. +- Use a descriptive string that references the PR or issue number if the binary guards a specific fix (e.g., `"PR574 - stepCE fix"`). +- After adding the binary, regenerate the baseline so it includes the new binary's expected output. +- The binary ID is assigned by list position. Appending to the end avoids changing IDs of existing binaries, which would invalidate old baselines against new code for no reason. diff --git a/dev-tools/evolve_binaries.sh b/dev-tools/evolve_binaries.sh deleted file mode 100755 index fc7a20fb3b..0000000000 --- a/dev-tools/evolve_binaries.sh +++ /dev/null @@ -1,84 +0,0 @@ -#!/bin/bash - -# Script usage: ./evolve_binaries.sh -# This script clones the POSYDON repo to the specified branch (defaults to 'main'), -# copies evolve_binaries.py, runs it, and saves output to evolve_binaries.out - -# Set default branch to 'main' if not provided -BRANCH=${1:-main} -REPO_URL="https://github.com/POSYDON-code/POSYDON" - -if [[ -n "$2" ]]; then - SHA=$2 - WORK_DIR="POSYDON_${BRANCH}_${SHA}" -else - WORK_DIR="POSYDON_$BRANCH" -fi - -# Remove existing directory if it exists -if [ -d "$WORK_DIR" ]; then - echo "🗑️ Removing existing directory: $WORK_DIR" - rm -rf "$WORK_DIR" -fi - -echo "📁 Creating working directory: $WORK_DIR" -# Create the working directory -mkdir -p "$WORK_DIR" - -FULL_PATH="$(realpath "$WORK_DIR")" -CLONE_DIR="$FULL_PATH/POSYDON" - -echo "📋 Copying script_data folder" -# copy the script_data folder -cp -r "./script_data" "$WORK_DIR" - -cd "$WORK_DIR" - -# Initialize conda for bash -echo "🔧 Initializing conda" -# Source conda's shell integration -if [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then - source "$HOME/miniconda3/etc/profile.d/conda.sh" -elif [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then - source "$HOME/anaconda3/etc/profile.d/conda.sh" -elif [ -f "/opt/homebrew/Caskroom/miniconda/base/etc/profile.d/conda.sh" ]; then - source "/opt/homebrew/Caskroom/miniconda/base/etc/profile.d/conda.sh" -else - echo -e "\033[31mError: Could not find conda installation. Please check your conda setup.\033[0m" - exit 1 -fi - -# Clone the repository to the specified branch -echo "🔄 Cloning POSYDON repository (branch: $BRANCH)" -if ! git clone -b "$BRANCH" "$REPO_URL" "$CLONE_DIR" 2>&1 | sed 's/^/ /'; then - echo -e "\033[31mError: Failed to clone branch '$BRANCH'. Please check if the branch exists.\033[0m" - exit 1 -fi - -# if SHA is provided, checkout that commit -if [[ -n "$SHA" ]]; then - echo "🔄 Checking out commit: $SHA" - cd "$CLONE_DIR" - if ! git checkout "$SHA" 2>&1 | sed 's/^/ /'; then - echo -e "\033[31mError: Failed to checkout commit '$SHA'. Please check if the commit exists.\033[0m" - exit 1 - fi - cd - -fi - -# Create conda environment for POSYDON v2 -echo "🐍 Creating conda environment" -conda create --prefix="$FULL_PATH/conda_env" python=3.11 -y -q 2>&1 | sed 's/^/ /' - -echo "⚡ Activating conda environment" -conda activate "$FULL_PATH/conda_env" - -# install POSYDON manually -echo "📦 Installing POSYDON" -pip install -e "$CLONE_DIR" -q 2>&1 | sed 's/^/ /' - -echo "🚀 Running evolve_binaries.py" -# # Run the Python script and capture output (stdout and stderr) -python script_data/1Zsun_binaries_suite.py > $FULL_PATH/evolve_binaries_$BRANCH.out 2>&1 - -echo -e "✅ Script completed. Output saved to \n$FULL_PATH/evolve_binaries_$BRANCH.out" diff --git a/dev-tools/generate_baseline.sh b/dev-tools/generate_baseline.sh new file mode 100755 index 0000000000..2f1d9025c3 --- /dev/null +++ b/dev-tools/generate_baseline.sh @@ -0,0 +1,176 @@ +#!/bin/bash +# ============================================================================= +# generate_baseline.sh — Generate baseline HDF5 files from a designated branch. +# +# This runs the binary validation suite against a chosen branch (or commit) +# and saves the results as the baseline for future comparisons. +# +# Usage: +# ./generate_baseline.sh [sha] [metallicities] +# ./generate_baseline.sh --promote [metallicities] +# +# Examples: +# ./generate_baseline.sh main # evolve + save baseline, all Z +# ./generate_baseline.sh v2.1.0 # baseline from a release tag +# ./generate_baseline.sh main abc123f # baseline from a specific commit +# ./generate_baseline.sh main "" "1 0.45" # baseline for subset of Z +# ./generate_baseline.sh --promote main # promote existing outputs to baseline +# ./generate_baseline.sh --promote main "1 0.45" # promote subset of existing outputs +# +# Output: +# baselines//baseline_Zsun.h5 — one file per metallicity +# baselines//baseline_info.txt — records branch, commit SHA, date +# ============================================================================= + +set -euo pipefail + +# ── Parse arguments ─────────────────────────────────────────────────────── +PROMOTE=false +if [ "${1:-}" = "--promote" ]; then + PROMOTE=true + shift +fi + +BRANCH=${1:-main} +if [ "$PROMOTE" = true ]; then + SHA="" + METALLICITIES=${2:-"2 1 0.45 0.2 0.1 0.01 0.001 0.0001"} +else + SHA=${2:-} + METALLICITIES=${3:-"2 1 0.45 0.2 0.1 0.01 0.001 0.0001"} +fi + +DEV_TOOLS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SCRIPT_DIR="${DEV_TOOLS_DIR}/script_data" +SAFE_BRANCH="${BRANCH//\//_}" +BASELINE_DIR="$SCRIPT_DIR/baselines/${SAFE_BRANCH}" +BINARY_CANDIDATE_DIR="$SCRIPT_DIR/output/binary_star_tests/${SAFE_BRANCH}" + +echo "============================================================" +echo " POSYDON Binary Validation — Generating Baseline" +echo " Branch: $BRANCH" +if [ "$PROMOTE" = true ]; then + echo " Mode: --promote (using existing outputs)" +else + echo " SHA: ${SHA:-HEAD}" +fi +echo " Metallicities: $METALLICITIES" +echo " Output dir: $BASELINE_DIR" +echo "============================================================" + +# ── Step 1: Evolve binaries (skip if --promote) ────────────────────────── +if [ "$PROMOTE" = true ]; then + echo "" + echo "Step 1: SKIPPED (--promote: using existing outputs in $BINARY_CANDIDATE_DIR)" + + if [ ! -d "$BINARY_CANDIDATE_DIR" ]; then + echo "ERROR: No outputs found at $BINARY_CANDIDATE_DIR" >&2 + echo "Run evolve_binaries.sh first, or drop --promote to evolve from scratch." >&2 + exit 1 + fi +else + echo "" + echo "Step 1: Evolving binaries on branch '$BRANCH'..." + "${DEV_TOOLS_DIR}/run_test_suite.sh" "$BRANCH" "$SHA" "$METALLICITIES" +fi + +# ── Step 2: Copy results into the baselines directory ──────────────────── +echo "" +echo "Step 2: Copying results to baseline directory..." + +mkdir -p "$BASELINE_DIR" + +COPIED=0 + +for Z in $METALLICITIES; do + SRC="$BINARY_CANDIDATE_DIR/candidate_${Z}Zsun.h5" + DST="$BASELINE_DIR/baseline_${Z}Zsun.h5" + + if [ -f "$SRC" ]; then + cp "$SRC" "$DST" + echo " Saved: $DST" + COPIED=$((COPIED + 1)) + else + echo " WARNING: Missing output for Z=${Z}: $SRC" >&2 + fi +done + +# ── Step 3: Record baseline metadata ───────────────────────────────────── +CLONE_DIR="$SCRIPT_DIR/workdirs/POSYDON_${SAFE_BRANCH}/POSYDON" +ACTUAL_SHA="" +if [ -d "$CLONE_DIR" ]; then + ACTUAL_SHA=$(cd "$CLONE_DIR" && git rev-parse HEAD 2>/dev/null || echo "unknown") +fi + +# Extract PATH_TO_POSYDON_DATA from the first available baseline HDF5 file +POSYDON_DATA_PATH="unknown" +for Z in $METALLICITIES; do + H5="$BASELINE_DIR/baseline_${Z}Zsun.h5" + if [ -f "$H5" ]; then + POSYDON_DATA_PATH=$(python -c " +import pandas as pd +with pd.HDFStore('$H5', mode='r') as s: + print(s['/metadata']['path_to_posydon_data'].iloc[0]) +" 2>&1) || { + echo " WARNING: Could not read POSYDON data path from $H5" + echo " ($POSYDON_DATA_PATH)" + POSYDON_DATA_PATH="unknown" + } + break + fi +done + +# Check completeness from HDF5 metadata +INCOMPLETE="" +for Z in $METALLICITIES; do + DST="$BASELINE_DIR/baseline_${Z}Zsun.h5" + if [ -f "$DST" ]; then + MISSING=$(python3 -c " +import pandas as pd, sys +try: + with pd.HDFStore('$DST', mode='r') as s: + m = s['/metadata'] + n = int(m['n_missing'].iloc[0]) + if n > 0: + print(f'Z={Z}: {n} missing — {m[\"missing_ids\"].iloc[0]}') +except Exception as e: + print(f'Z={Z}: could not read metadata ({e})', file=sys.stderr) +" 2>/dev/null) + if [ -n "$MISSING" ]; then + echo " ⚠️ $MISSING" + INCOMPLETE="${INCOMPLETE} ${MISSING}\n" + fi + fi +done + +INFO_FILE="$BASELINE_DIR/baseline_info.txt" +cat > "$INFO_FILE" << EOF +POSYDON Binary Validation Baseline +=================================== +Branch: $BRANCH +Commit SHA: ${ACTUAL_SHA:-unknown} +Requested SHA: ${SHA:-HEAD} +Mode: $([ "$PROMOTE" = true ] && echo "promoted from existing outputs" || echo "evolved from scratch") +Generated: $(date -u '+%Y-%m-%d %H:%M:%S UTC') +Metallicities: $METALLICITIES +Files: $COPIED +POSYDON data: $POSYDON_DATA_PATH +EOF + +if [ -n "$INCOMPLETE" ]; then + printf "\nINCOMPLETE BASELINES:\n%b\n" "$INCOMPLETE" >> "$INFO_FILE" + echo "" + echo "⚠️ WARNING: Some baselines have missing binaries. See $INFO_FILE" +fi + +echo "" +echo "============================================================" +echo " Baseline generated: $COPIED file(s)" +echo " Info: $INFO_FILE" +echo " Directory: $BASELINE_DIR" +echo "============================================================" + +if [ $COPIED -eq 0 ]; then + echo "ERROR: No baseline files were created!" >&2 + exit 1 +fi diff --git a/dev-tools/run_test_suite.sh b/dev-tools/run_test_suite.sh new file mode 100755 index 0000000000..fea8bcd621 --- /dev/null +++ b/dev-tools/run_test_suite.sh @@ -0,0 +1,225 @@ +#!/bin/bash + +# ============================================================================= +# evolve_binaries.sh — Clone a POSYDON branch, install it, and run the +# binary validation suite at all requested metallicities. +# +# Usage: +# ./evolve_binaries.sh [sha] [metallicities] +# +# Examples: +# ./evolve_binaries.sh main # all metallicities +# ./evolve_binaries.sh feature/my-fix abc123f # specific commit +# ./evolve_binaries.sh main "" "1 0.45 0.1" # subset of metallicities +# +# Output structure: +# outputs//candidate_Zsun.h5 — evolution results per metallicity +# logs//evolve_Zsun.log — log per metallicity +# workdirs/POSYDON_/ — cloned repo + conda env +# +# The resolved commit SHA and branch name are recorded in each HDF5 file's +# /metadata table for provenance tracking in comparison reports. +# ============================================================================= + +set -euo pipefail +# Load git if needed +if ! command -v git >/dev/null 2>&1; then + if command -v module >/dev/null 2>&1; then + module load git + fi +fi + +# ── Configuration ────────────────────────────────────────────────────────── +ALL_METALLICITIES="2 1 0.45 0.2 0.1 0.01 0.001 0.0001" + +BRANCH=${1:-main} +SHA=${2:-} +METALLICITIES=${3:-$ALL_METALLICITIES} + +REPO_URL="https://github.com/POSYDON-code/POSYDON" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/script_data" + +# Sanitize branch name for filesystem +SAFE_BRANCH="${BRANCH//\//_}" + + +# Directories (all relative to SCRIPT_DIR, the dev-tools root) +WORK_DIR="$SCRIPT_DIR/workdirs/POSYDON_${SAFE_BRANCH}" +BINARY_OUTPUT_DIR="$SCRIPT_DIR/output/binary_star_tests/${SAFE_BRANCH}" +POP_OUTPUT_DIR="$SCRIPT_DIR/output/population_tests/${SAFE_BRANCH}" +LOG_DIR="$SCRIPT_DIR/logs/${SAFE_BRANCH}" +CLONE_DIR="$WORK_DIR/POSYDON" + +mkdir -p ${LOG_DIR} ${BINARY_OUTPUT_DIR} + +# ── Conda Setup ──────────────────────────────────────────────────────────── +echo "🔧 Initializing conda" +CONDA_SH="" +for candidate in \ + "$HOME/miniconda3/etc/profile.d/conda.sh" \ + "$HOME/anaconda3/etc/profile.d/conda.sh" \ + "/opt/homebrew/Caskroom/miniconda/base/etc/profile.d/conda.sh"; do + if [ -f "$candidate" ]; then + CONDA_SH="$candidate" + break + fi +done + +if [ -z "$CONDA_SH" ]; then + if command -v conda >/dev/null 2>&1; then + CONDA_SH="$(conda info --base)/etc/profile.d/conda.sh" + else + echo "ERROR: Could not find conda installation." >&2 + exit 1 + fi +fi +source "$CONDA_SH" + +# ── Clone Repository ────────────────────────────────────────────────────── +if [ -d "$WORK_DIR" ]; then + echo "🗑️ Removing existing work directory: $WORK_DIR" + rm -rf "$WORK_DIR" +fi +mkdir -p "$WORK_DIR" + +echo "🔄 Cloning POSYDON repository (branch: $BRANCH)" +if ! git clone -b "$BRANCH" "$REPO_URL" "$CLONE_DIR" 2>&1 | sed 's/^/ /'; then + echo "ERROR: Failed to clone branch '$BRANCH'." >&2 + exit 1 +fi + +# if SHA is provided, checkout that commit +if [[ -n "$SHA" ]]; then + echo "🔄 Checking out commit: $SHA" + cd "$CLONE_DIR" + if ! git checkout "$SHA" 2>&1 | sed 's/^/ /'; then + echo -e "\033[31mError: Failed to checkout commit '$SHA'. Please check if the commit exists.\033[0m" + exit 1 + fi + cd - +fi + +# ── Create Conda Environment ───────────────────────────────────────────── +ENV_PREFIX="$WORK_DIR/conda_env" + +echo "🐍 Creating conda environment at $ENV_PREFIX" +conda create --prefix="$ENV_PREFIX" python=3.11 -y -q 2>&1 | sed 's/^/ /' +conda activate "$ENV_PREFIX" + +echo "📦 Installing POSYDON" +pip install -e "$CLONE_DIR" -q 2>&1 | sed 's/^/ /' + +# ── Run Suite for Each Metallicity ──────────────────────────────────────── +SUITE_SCRIPT="$SCRIPT_DIR/src/binaries_suite.py" +FAILED=0 + +# Resolve the actual commit SHA for metadata +ACTUAL_SHA=$(cd "$CLONE_DIR" && git rev-parse HEAD 2>/dev/null || echo "unknown") +echo " Resolved SHA: $ACTUAL_SHA" + +# override environment's PATH_TO_POSYDON variable to point to the +# current branch's clone for these tests +#PATH_TO_POSYDON=$CLONE_DIR + +# copy this branch's default .ini file to perform tests +DEFAULT_INI="${CLONE_DIR}/posydon/popsyn/population_params_default.ini" +TEST_INI="${SCRIPT_DIR}/inlists/${SAFE_BRANCH}_test_params.ini" +cp $DEFAULT_INI $TEST_INI +sed -i 's/^\([[:space:]]*\)entropy *= *.*/\1entropy = 0/' $TEST_INI + +for Z in $METALLICITIES; do + OUTPUT_FILE="$BINARY_OUTPUT_DIR/candidate_${Z}Zsun.h5" + LOG_FILE="$LOG_DIR/evolve_${Z}Zsun.log" + + echo "" + echo "============================================================" + echo " 🚀 Evolving binaries for Z = ${Z} Zsun" + echo " Output: $OUTPUT_FILE" + echo " Log: $LOG_FILE" + echo "============================================================" + + python "$SUITE_SCRIPT" \ + --metallicity "$Z" \ + --output "$OUTPUT_FILE" \ + --branch "$BRANCH" \ + --sha "$ACTUAL_SHA" \ + 2>&1 | tee "$LOG_FILE" + EXIT_CODE=${PIPESTATUS[0]} + + if [ $EXIT_CODE -eq 137 ]; then + echo "ERROR: Process killed (likely OOM) for Z=${Z}. Exit code 137 (SIGKILL)." >&2 + echo " Consider increasing job memory." >&2 + FAILED=$((FAILED + 1)) + elif [ $EXIT_CODE -ne 0 ]; then + echo "WARNING: Suite failed for Z=${Z} (exit code $EXIT_CODE). Check $LOG_FILE" >&2 + FAILED=$((FAILED + 1)) + elif [ ! -f "$OUTPUT_FILE" ]; then + echo "WARNING: Output file not created for Z=${Z}" >&2 + FAILED=$((FAILED + 1)) + else + echo " Z=${Z} Zsun complete." + fi + +done + +# Run population synthesis tests and capture output (stdout and stderr) +SUITE_SCRIPT="$SCRIPT_DIR/src/popsynth_suite.py" +FAILED=0 +LOG_FILE="$LOG_DIR/evolve_populations.log" + +MULTIZ_INI="${SCRIPT_DIR}/inlists/${SAFE_BRANCH}_test_multiZ_params.ini" +cp $TEST_INI $MULTIZ_INI +sed -i 's/^\([[:space:]]*\)metallicities *= *\[.*\]/\1metallicities = [2., 1., 1e-4]/' $MULTIZ_INI + +echo "" +echo "============================================================" +echo " 🚀 Evolving populations" +echo " Output: $POP_OUTPUT_DIR" +echo " Log: $LOG_FILE" +echo "============================================================" + +python "$SUITE_SCRIPT" \ + --output "$POP_OUTPUT_DIR" \ + --ini "$TEST_INI" \ + --multiz "$MULTIZ_INI" \ + 2>&1 | tee "$LOG_FILE" +EXIT_CODE=${PIPESTATUS[0]} + +if [ $EXIT_CODE -eq 137 ]; then + echo "ERROR: Process killed (likely OOM) for popsynth_suite.py. Exit code 137 (SIGKILL)." >&2 + echo " Consider increasing job memory." >&2 + FAILED=$((FAILED + 1)) +elif [ $EXIT_CODE -ne 0 ]; then + echo "WARNING: popsynth_suite.py failed (exit code $EXIT_CODE). Check $LOG_FILE" >&2 + FAILED=$((FAILED + 1)) +elif [ ! -f "$OUTPUT_FILE" ]; then + echo "WARNING: Output file not created for popsynth_suite.py" >&2 + FAILED=$((FAILED + 1)) +else + echo " popsynth_suite.py completed." +fi + +# ── Deactivate Environment ──────────────────────────────────────────────── +conda deactivate + +echo "" +echo "============================================================" +if [ $FAILED -eq 0 ]; then + echo "✅ All metallicities completed successfully." +else + echo "Completed with $FAILED failure(s)." +fi +echo " Outputs in: $BINARY_OUTPUT_DIR/" +echo "============================================================" +echo "" +echo "============================================================" +if [ $FAILED -eq 0 ]; then + echo "✅ All population synthesis tests completed successfully." +else + echo "Completed with $FAILED failure(s)." +fi +echo " Outputs in: $POP_OUTPUT_DIR/" +echo "============================================================" + + +exit $FAILED diff --git a/dev-tools/script_data/1Zsun_binaries_params.ini b/dev-tools/script_data/1Zsun_binaries_params.ini deleted file mode 100644 index f5400f3117..0000000000 --- a/dev-tools/script_data/1Zsun_binaries_params.ini +++ /dev/null @@ -1,548 +0,0 @@ -# POSYDON default BinaryPopulation inifile, use ConfigParser syntax - -[environment_variables] - PATH_TO_POSYDON = '' - - -;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -;;;;;;;; SimulationProperties ;;;;;;;;;; -;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; - -[flow] - import = ['posydon.binary_evol.flow_chart', 'flow_chart'] - # builtin posydon flow - absolute_import = None - # If given, use an absolute filepath to user defined flow: ['', ''] - -[step_HMS_HMS] - import = ['posydon.binary_evol.MESA.step_mesa', 'MS_MS_step'] - # builtin posydon step - absolute_import = None - # If given, use an absolute filepath to user defined step: ['', ''] - interpolation_path = None - # found by default - interpolation_filename = None - # found by default - interpolation_method = 'linear3c_kNN' - # 'nearest_neighbour' 'linear3c_kNN' '1NN_1NN' - save_initial_conditions = True - # only for interpolation_method='nearest_neighbour' - track_interpolation = False - # True False - stop_method = 'stop_at_max_time' - # 'stop_at_end' 'stop_at_max_time' 'stop_at_condition' - stop_star = 'star_1' - # only for stop_method='stop_at_condition' 'star_1' 'star_2' - stop_var_name = None - # only for stop_method='stop_at_condition' str - stop_value = None - # only for stop_method='stop_at_condition' float - stop_interpolate = True - # True False - verbose = False - # True False - - -[step_CO_HeMS] - import = ['posydon.binary_evol.MESA.step_mesa', 'CO_HeMS_step'] - # builtin posydon step - absolute_import = None - # If given, use an absolute filepath to user defined step: ['', ''] - interpolation_path = None - # found by default - interpolation_filename = None - # found by default - interpolation_method = 'linear3c_kNN' - # 'nearest_neighbour' 'linear3c_kNN' '1NN_1NN' - save_initial_conditions = True - # only for interpolation_method='nearest_neighbour' - track_interpolation = False - # True False - stop_method = 'stop_at_max_time' - # 'stop_at_end' 'stop_at_max_time' 'stop_at_condition' - stop_star = 'star_1' - # only for stop_method='stop_at_condition' 'star_1' 'star_2' - stop_var_name = None - # only for stop_method='stop_at_condition' str - stop_value = None - # only for stop_method='stop_at_condition' float - stop_interpolate = True - # True False - verbose = False - # True False - -[step_CO_HMS_RLO] - import = ['posydon.binary_evol.MESA.step_mesa', 'CO_HMS_RLO_step'] - # builtin posydon step - absolute_import = None - # If given, use an absolute filepath to user defined step: ['', ''] - interpolation_path = None - # found by default - interpolation_filename = None - # found by default - interpolation_method = 'linear3c_kNN' - # 'nearest_neighbour' 'linear3c_kNN' '1NN_1NN' - save_initial_conditions = True - # only for interpolation_method='nearest_neighbour' - track_interpolation = False - # True False - stop_method = 'stop_at_max_time' - # 'stop_at_end' 'stop_at_max_time' 'stop_at_condition' - stop_star = 'star_1' - # only for stop_method='stop_at_condition' 'star_1' 'star_2' - stop_var_name = None - # only for stop_method='stop_at_condition' str - stop_value = None - # only for stop_method='stop_at_condition' float - stop_interpolate = True - # True False - verbose = False - # True False - -[step_CO_HeMS_RLO] - import = ['posydon.binary_evol.MESA.step_mesa', 'CO_HeMS_RLO_step'] - # builtin posydon step - absolute_import = None - # If given, use an absolute filepath to user defined step: ['', ''] - interpolation_path = None - # found by default - interpolation_filename = None - # found by default - interpolation_method = 'linear3c_kNN' - # 'nearest_neighbour' 'linear3c_kNN' '1NN_1NN' - save_initial_conditions = True - # only for interpolation_method='nearest_neighbour' - track_interpolation = False - # True False - stop_method = 'stop_at_max_time' - # 'stop_at_end' 'stop_at_max_time' 'stop_at_condition' - stop_star = 'star_1' - # only for stop_method='stop_at_condition' 'star_1' 'star_2' - stop_var_name = None - # only for stop_method='stop_at_condition' str - stop_value = None - # only for stop_method='stop_at_condition' float - stop_interpolate = True - # True False - verbose = False - # True False - - -[step_detached] - import = ['posydon.binary_evol.DT.step_detached', 'detached_step'] - # builtin posydon step - absolute_import = None - # If given, use an absolute filepath to user defined step: ['', ''] - matching_method = 'minimize' - #'minimize' 'root' - do_wind_loss = True - # True, False - do_tides = True - # True, False - do_gravitational_radiation = True - # True, False - do_magnetic_braking = True - # True, False - do_stellar_evolution_and_spin_from_winds = True - # True, False - RLO_orbit_at_orbit_with_same_am = False - # True, False - #record_matching = False - # True, False - verbose = False - # True, False - -[step_disrupted] - import = ['posydon.binary_evol.DT.step_disrupted','DisruptedStep'] - # builtin posydon step - absolute_import = None - # If given, use an absolute filepath to user defined step: ['', ''] - -[step_merged] - import = ['posydon.binary_evol.DT.step_merged','MergedStep'] - # builtin posydon step - absolute_import = None - # If given, use an absolute filepath to user defined step: ['', ''] - -[step_initially_single] - import = ['posydon.binary_evol.DT.step_initially_single','InitiallySingleStep'] - # builtin posydon step - absolute_import = None - # If given, use an absolute filepath to user defined step: ['', ''] - -[step_CE] - import = ['posydon.binary_evol.CE.step_CEE', 'StepCEE'] - # builtin posydon step - absolute_import = None - # If given, use an absolute filepath to user defined step: ['', ''] - prescription='alpha-lambda' - # 'alpha-lambda' - common_envelope_efficiency=1.0 - # float in (0, inf) - common_envelope_option_for_lambda='lambda_from_grid_final_values' - # (1) 'default_lambda', (2) 'lambda_from_grid_final_values', - # (3) 'lambda_from_profile_gravitational', - # (4) 'lambda_from_profile_gravitational_plus_internal', - # (5) 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - common_envelope_lambda_default=0.5 - # float in (0, inf) used only for option (1) - common_envelope_option_for_HG_star="optimistic" - # 'optimistic', 'pessimistic' - common_envelope_alpha_thermal=1.0 - # float in (0, inf) used only for option for (4), (5) - core_definition_H_fraction=0.3 - # 0.01, 0.1, 0.3 - core_definition_He_fraction=0.1 - # 0.1 - CEE_tolerance_err = 0.001 - # float (0, inf) - common_envelope_option_after_succ_CEE = 'two_phases_stableMT' - # 'two_phases_stableMT' 'one_phase_variable_core_definition' - # 'two_phases_windloss' - verbose = False - # True False - -[step_SN] - import = ['posydon.binary_evol.SN.step_SN', 'StepSN'] - # builtin posydon step - absolute_import = None - # 'package' kwarg for importlib.import_module - mechanism = 'Fryer+12-delayed' - # v2 interpolators support: 'Fryer+12-rapid', 'Fryer+12-delayed', - # 'Sukhbold+16-engine', 'Patton&Sukhbold20-engine' - # need profiles: 'direct' - engine = '' - # 'N20' or 'W20' for 'Sukhbold+16-engine', 'Patton&Sukhbold20-engine' - # '' for the others - PISN = "Hendriks+23" - # v2 interpolators support: "Hendriks+23" - # other options: None, "Marchant+19" - PISN_CO_shift = 0.0 - # Only when using Hendriks+23 - # float (-inf,inf) - # v2 interpolators support: 0.0 - PPI_extra_mass_loss = -20.0 - # Only when using Hendriks+23 - # float (-inf,inf) - # v2 interpolators support: 0.0 or -20.0 - ECSN = "Tauris+15" - # "Tauris+15", "Podsiadlowski+04" - conserve_hydrogen_envelope = False - # True, False - conserve_hydrogen_PPI = False - # Only when using Hendriks+23 - # True, False - max_neutrino_mass_loss = 0.5 - # float (0,inf) - # v2 interpolators support: 0.5 - max_NS_mass = 2.5 - # float (0,inf) - # v2 interpolators support: 2.5 - use_interp_values = True - # True, False - use_profiles = True - # True, False - use_core_masses = True - # True, False - allow_spin_None = False - # True, False - approx_at_he_depletion = False - # True, False - kick = True - # True, False - kick_normalisation = 'one_over_mass' - # "one_minus_fallback", "one_over_mass", "NS_one_minus_fallback_BH_one", - # "one", "zero", "asym_ej", "linear", "log_normal" - sigma_kick_CCSN_NS = 265.0 - # float (0,inf) - sigma_kick_CCSN_BH = 265.0 - # float (0,inf) - sigma_kick_ECSN = 20.0 - # float (0,inf) - verbose = False - # True False - -[step_dco] - import = ['posydon.binary_evol.DT.double_CO', 'DoubleCO'] - # builtin posydon step - absolute_import = None - # If given, use an absolute filepath to user defined step: ['', ''] - n_o_steps_history = None - -[step_end] - import = ['posydon.binary_evol.step_end', 'step_end'] - # builtin posydon step - absolute_import = None - # If given, use an absolute filepath to user defined step: ['', ''] - -[extra_hooks] - import_1 = ['posydon.binary_evol.simulationproperties', 'TimingHooks'] - # builtin posydon hook - absolute_import_1 = None - # If given, use an absolute filepath to user defined step: ['', ''] - kwargs_1 = {} - - import_2 = ['posydon.binary_evol.simulationproperties', 'StepNamesHooks'] - # builtin posydon hook - absolute_import_2 = None - # If given, use an absolute filepath to user defined step: ['', ''] - kwargs_2 = {} - - -;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -;;;;;;;; BinaryPopulation ;;;;;;;;;; -;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; - -[BinaryPopulation_options] - optimize_ram = True - # save population in batches - ram_per_cpu = None - # set maximum ram per cpu before batch saving (GB) - dump_rate = 2000 - # batch save after evolving N binaries - # this should be at least 500 for populations of 100,000 binaries or more - temp_directory = 'batches' - # folder for keeping batch files - tqdm = False - # progress bar - breakdown_to_df = True - # convert BinaryStars into DataFrames after evolution - use_MPI = False - # use only for local MPI runs - metallicity = [1.] #[2., 1., 0.45, 0.2, 0.1, 0.01, 0.001, 0.0001] - # In units of solar metallicity - error_checking_verbose = False - # if True, write all POSYDON errors to stderr at runtime, default=False - warnings_verbose = False - # if True, write all POSYDON warnings to stderr at runtime, default=False - history_verbose = False - # if True, record extra functional steps in the output DataFrames - # (These steps represent internal workings of POSYDON rather than physical phases of evolution) - entropy = None - # `None` uses system entropy (recommended) - number_of_binaries = 10 - # int - binary_fraction_scheme = 'const' - #'const' 'Moe_17' - binary_fraction_const = 1.0 - # float 0< fraction <=1 - star_formation = 'burst' - # 'constant' 'burst' 'custom_linear' 'custom_log10' 'custom_linear_histogram' 'custom_log10_histogram' - max_simulation_time = 13.8e9 - # float (0,inf) - - read_samples_from_file = '' - # path to file to read initial parameters from (if empty string get random samples) - primary_mass_scheme = 'Kroupa2001' - # 'Salpeter', 'Kroupa1993', 'Kroupa2001' - primary_mass_min = 7.0 - # float (0,130) - primary_mass_max = 150.0 - # float (0,130) - secondary_mass_scheme = 'flat_mass_ratio' - # 'flat_mass_ratio', 'q=1' - secondary_mass_min = 0.5 - # float (0,130) - secondary_mass_max = 150.0 - # float (0,130) - orbital_scheme = 'period' - # 'separation', 'period' - orbital_period_scheme = 'Sana+12_period_extended' - # used only for orbital_scheme = 'period' - orbital_period_min = 0.75 - # float (0,inf) - orbital_period_max = 6000.0 - # float (0,inf) - #orbital_separation_scheme = 'log_uniform' - # used only for orbital_scheme = 'separation', 'log_uniform', 'log_normal' - #orbital_separation_min = 5.0 - # float (0,inf) - #orbital_separation_max = 1e5 - # float (0,inf) - #log_orbital_separation_mean = None - # float (0,inf) used only for orbital_separation_scheme ='log_normal' - #log_orbital_separation_sigma = None - # float (0,inf) used only for orbital_separation_scheme ='log_normal' - eccentricity_scheme = 'zero' - # 'zero' 'thermal' 'uniform' - - -;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -;;;;;;;; Saving Output ;;;;;;;;;; -;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; - -[BinaryStar_output] - extra_columns = {'step_names':'string', 'step_times':'float64'} - # 'step_times' with from posydon.binary_evol.simulationproperties import TimingHooks - - # LIST BINARY PROPERTIES - only_select_columns=[ - 'state', - 'event', - 'time', - #'separation', - 'orbital_period', - 'eccentricity', - #'V_sys', - #'rl_relative_overflow_1', - #'rl_relative_overflow_2', - 'lg_mtransfer_rate', - #'mass_transfer_case', - #'trap_radius', - #'acc_radius', - #'t_sync_rad_1', - #'t_sync_conv_1', - #'t_sync_rad_2', - #'t_sync_conv_2', - #'nearest_neighbour_distance', - ] - scalar_names=[ - 'interp_class_HMS_HMS', - 'interp_class_CO_HMS_RLO', - 'interp_class_CO_HeMS', - 'interp_class_CO_HeMS_RLO', - 'mt_history_HMS_HMS', - 'mt_history_CO_HMS_RLO', - 'mt_history_CO_HeMS', - 'mt_history_CO_HeMS_RLO', - ] - -[SingleStar_1_output] - # LIST STAR PROPERTIES TO SAVE - include_S1=True - # True, False - only_select_columns=[ - 'state', - #'metallicity', - 'mass', - 'log_R', - 'log_L', - 'lg_mdot', - #'lg_system_mdot', - #'lg_wind_mdot', - 'he_core_mass', - 'he_core_radius', - #'c_core_mass', - #'c_core_radius', - #'o_core_mass', - #'o_core_radius', - 'co_core_mass', - 'co_core_radius', - 'center_h1', - 'center_he4', - #'center_c12', - #'center_n14', - #'center_o16', - 'surface_h1', - 'surface_he4', - #'surface_c12', - #'surface_n14', - #'surface_o16', - #'log_LH', - #'log_LHe', - #'log_LZ', - #'log_Lnuc', - #'c12_c12', - #'center_gamma', - #'avg_c_in_c_core', - #'surf_avg_omega', - 'surf_avg_omega_div_omega_crit', - #'total_moment_of_inertia', - #'log_total_angular_momentum', - 'spin', - #'conv_env_top_mass', - #'conv_env_bot_mass', - #'conv_env_top_radius', - #'conv_env_bot_radius', - #'conv_env_turnover_time_g', - #'conv_env_turnover_time_l_b', - #'conv_env_turnover_time_l_t', - #'envelope_binding_energy', - #'mass_conv_reg_fortides', - #'thickness_conv_reg_fortides', - #'radius_conv_reg_fortides', - #'lambda_CE_1cent', - #'lambda_CE_10cent', - #'lambda_CE_30cent', - #'lambda_CE_pure_He_star_10cent', - #'profile', - #'total_mass_h1', - #'total_mass_he4', - ] - scalar_names=[ - 'natal_kick_array', - 'SN_type', - 'f_fb', - 'spin_orbit_tilt_first_SN', - 'spin_orbit_tilt_second_SN', - ] - -[SingleStar_2_output] - # LIST STAR PROPERTIES TO SAVE - include_S2 = True - # True, False - only_select_columns = [ - 'state', - #'metallicity', - 'mass', - 'log_R', - 'log_L', - 'lg_mdot', - #'lg_system_mdot', - #'lg_wind_mdot', - 'he_core_mass', - 'he_core_radius', - #'c_core_mass', - #'c_core_radius', - #'o_core_mass', - #'o_core_radius', - 'co_core_mass', - 'co_core_radius', - 'center_h1', - 'center_he4', - #'center_c12', - #'center_n14', - #'center_o16', - 'surface_h1', - 'surface_he4', - #'surface_c12', - #'surface_n14', - #'surface_o16', - #'log_LH', - #'log_LHe', - #'log_LZ', - #'log_Lnuc', - #'c12_c12', - #'center_gamma', - #'avg_c_in_c_core', - #'surf_avg_omega', - 'surf_avg_omega_div_omega_crit', - #'total_moment_of_inertia', - #'log_total_angular_momentum', - 'spin', - #'conv_env_top_mass', - #'conv_env_bot_mass', - #'conv_env_top_radius', - #'conv_env_bot_radius', - #'conv_env_turnover_time_g', - #'conv_env_turnover_time_l_b', - #'conv_env_turnover_time_l_t', - #'envelope_binding_energy', - #'mass_conv_reg_fortides', - #'thickness_conv_reg_fortides', - #'radius_conv_reg_fortides', - #'lambda_CE_1cent', - #'lambda_CE_10cent', - #'lambda_CE_30cent', - #'lambda_CE_pure_He_star_10cent', - #'profile', - #'total_mass_h1', - #'total_mass_he4', - ] - scalar_names=[ - 'natal_kick_array', - 'SN_type', - 'f_fb', - 'spin_orbit_tilt_first_SN', - 'spin_orbit_tilt_second_SN', - ] diff --git a/dev-tools/script_data/1Zsun_binaries_suite.py b/dev-tools/script_data/1Zsun_binaries_suite.py deleted file mode 100644 index 5b2c416c3c..0000000000 --- a/dev-tools/script_data/1Zsun_binaries_suite.py +++ /dev/null @@ -1,768 +0,0 @@ -#!/usr/bin/env python3 -""" -Script to evolve a few binaries. -Used for validation of the branch. - -Author: Max Briel -""" - -import argparse -import os -import signal -import sys -import warnings - -from posydon.binary_evol.binarystar import BinaryStar, SingleStar -from posydon.binary_evol.simulationproperties import SimulationProperties -from posydon.popsyn.io import simprop_kwargs_from_ini -from posydon.utils.common_functions import orbital_separation_from_period - -target_rows = 12 -line_length = 140 -columns_to_show = ['step_names', 'state', 'event', 'S1_state', 'S1_mass', 'S2_state', 'S2_mass', 'orbital_period'] - -def load_inlist(verbose): - - sim_kwargs = simprop_kwargs_from_ini('script_data/1Zsun_binaries_params.ini', verbose=verbose) - metallicity = {'metallicity':1, 'verbose':verbose} - - sim_kwargs['step_HMS_HMS'][1].update(metallicity) - sim_kwargs['step_CO_HeMS'][1].update(metallicity) - sim_kwargs['step_CO_HMS_RLO'][1].update(metallicity) - sim_kwargs['step_CO_HeMS_RLO'][1].update(metallicity) - sim_kwargs['step_detached'][1].update(metallicity) - sim_kwargs['step_disrupted'][1].update(metallicity) - sim_kwargs['step_merged'][1].update(metallicity) - sim_kwargs['step_initially_single'][1].update(metallicity) - - sim_prop = SimulationProperties(**sim_kwargs) - - sim_prop.load_steps(verbose=verbose) - return sim_prop - -def write_binary_to_screen(binary): - """Writes a binary DataFrame prettily to the screen - - Args: - binary: BinaryStar object with evolved data - """ - df = binary.to_df(**{'extra_columns':{'step_names':'str'}}) - - # Filter to only existing columns - available_columns = [col for col in columns_to_show if col in df.columns] - df_filtered = df[available_columns] - - # Reset index to use a counter instead of NaN - df_filtered = df_filtered.reset_index(drop=True) - - print("=" * line_length) - - # Print the DataFrame - df_string = df_filtered.to_string(index=True, float_format='%.3f') - print(df_string) - - # Add empty lines to reach exactly 10 rows of output - current_rows = len(df_filtered) + 1 # add one for header - - if current_rows < target_rows: - # Calculate the width of the output to print empty lines of the same width - lines = df_string.split('\n') - if len(lines) > 1: - # Use the width of the data lines (skip header) - empty_lines_needed = target_rows - current_rows - for i in range(empty_lines_needed): - print("") - - print("-" * line_length) - - -def print_failed_binary(binary,e, max_error_lines=3): - - print("=" * line_length) - print(f"🚨 Binary Evolution Failed!") - print(f"Exception: {type(e).__name__}") - print(f"Message: {e}") - - # Get the binary's current state and limit output - try: - df = binary.to_df(**{'extra_columns':{'step_names':'str'}}) - if len(df) > 0: - # Select only the desired columns - - available_columns = [col for col in columns_to_show if col in df.columns] - df_filtered = df[available_columns] - - # Reset index to use a counter instead of NaN - df_filtered = df_filtered.reset_index(drop=True) - - # Limit to max_error_lines - if len(df_filtered) > max_error_lines: - df_filtered = df_filtered.tail(max_error_lines) - print(f"\nShowing last {max_error_lines} evolution steps before failure:") - else: - print(f"\nEvolution steps before failure ({len(df_filtered)} steps):") - - df_string = df_filtered.to_string(index=True, float_format='%.3f') - print(df_string) - - current_rows = len(df_filtered) + 1 + 5 # add one for header - empty_lines_needed = target_rows - current_rows - for i in range(empty_lines_needed): - print("") - else: - print("\nNo evolution steps recorded before failure.") - except Exception as inner_e: - print(f"\nCould not retrieve binary state: {inner_e}") - - print("-" * line_length) - -def evolve_binary(binary): - - # Capture warnings during evolution - captured_warnings = [] - - def warning_handler(message, category, filename, lineno, file=None, line=None): - captured_warnings.append({ - 'message': str(message), - 'category': category.__name__, - 'filename': filename, - 'lineno': lineno - }) - - # Set up warning capture - old_showwarning = warnings.showwarning - warnings.showwarning = warning_handler - - try: - binary.evolve() - # Display the evolution summary for successful evolution - write_binary_to_screen(binary) - - # Show warnings if any were captured - if captured_warnings: - print(f"⚠️ {len(captured_warnings)} warning(s) raised during evolution:") - for i, warning in enumerate(captured_warnings[:3], 1): # Show max 3 warnings - print(f" {i}. {warning['category']}: {warning['message']}") - if len(captured_warnings) > 3: - print(f" ... and {len(captured_warnings) - 3} more warning(s)") - elif len(captured_warnings) <= 3: - for i in range(4-len(captured_warnings)): - print("") - else: - print(f"No warning(s) raised during evolution\n\n\n\n") - print("=" * line_length) - - except Exception as e: - - # turn off binary alarm in case of exception - signal.alarm(0) - - print_failed_binary(binary, e) - - # Show warnings if any were captured before the exception - if captured_warnings: - print(f"\n⚠️ {len(captured_warnings)} warning(s) raised before failure:") - for i, warning in enumerate(captured_warnings[:3], 1): # Show max 3 warnings - print(f" {i}. {warning['category']}: {warning['message']}") - if len(captured_warnings) > 3: - print(f" ... and {len(captured_warnings) - 3} more warning(s)") - else: - print(f"No warning(s) raised during evolution\n\n\n\n") - - print("=" * line_length) - finally: - # Always turn off binary alarm and restore warning handler - signal.alarm(0) - warnings.showwarning = old_showwarning - - -def evolve_binaries(verbose): - """Evolves a few binaries to validate their output - """ - sim_prop = load_inlist(verbose) - - ######################################## - # Failing binary in matching - ######################################## - star_1 = SingleStar(**{'mass': 11.948472796094759, 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [231.97383621190582, 5.927334890264575, 1.5990566013567014, 6.137994236518587]}) - star_2 = SingleStar(**{'mass': 7.636958434479617, 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 190925.99636740884,'eccentricity': 0.0}, properties = sim_prop) - evolve_binary(binary) - ######################################## - # Failing binary in matching - ######################################## - star_1 = SingleStar(**{'mass': 30.169861921689556, 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [77.96834852144123, 0.05021460132555987, 2.3146518208348152, 1.733054979982291]}) - star_2 = SingleStar(**{'mass': 10.972734402996027, 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 20479.71919353725,'eccentricity': 0.0}, properties = sim_prop) - evolve_binary(binary) - ######################################## - # flipped S1 and S2 ? - ######################################## - star_1 = SingleStar(**{'mass': 9.474917413943635, 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [133.5713935237759, 4.398754864537542, 2.703102872841114, 1.4633904612711142]}) - star_2 = SingleStar(**{'mass': 9.311073918196263, 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 18.605997832086413,'eccentricity': 0.0}, properties = sim_prop) - - evolve_binary(binary) - ######################################## - # flipped S1 and S2 - ######################################## - star_1 = SingleStar(**{'mass': 10.438541, 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - star_2 = SingleStar(**{'mass': 1.400713, 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [0.0, 4.190728383757787, 1.1521129697118118, 5.015343794234789]}) - binary = BinaryStar(star_1, star_2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 9.824025,'eccentricity': 0.0}, properties = sim_prop) - evolve_binary(binary) - ######################################## - # flipped S1 and S2 - ######################################## - star_1= SingleStar(**{'mass': 9.845907 , 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [0.0, 4.190728383757787, 1.1521129697118118, 5.015343794234789]}) - star_2 = SingleStar(**{'mass': 9.611029, 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 3.820571,'eccentricity': 0.0}, properties = sim_prop) - evolve_binary(binary) - ######################################## - # Normal binary evolution - ######################################## - star_1= SingleStar(**{'mass': 30.845907 , 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [0.0, 4.190728383757787, 1.1521129697118118, 5.015343794234789]}) - star_2 = SingleStar(**{'mass': 30.611029, 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 30.820571,'eccentricity': 0.0}, properties = sim_prop) - evolve_binary(binary) - ######################################## - # Normal binary - ######################################## - star_1= SingleStar(**{'mass': 9.213534679594247 , 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [327.5906384501521, 1.7707176050073297, 1.573225822966838, 1.6757313876001914]}) - star_2 = SingleStar(**{'mass': 7.209878522799272, 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 63123.74544474666,'eccentricity': 0.0}, properties = sim_prop) - evolve_binary(binary) - ######################################## - # Normal binary - ######################################## - star_1= SingleStar(**{'mass': 9.561158487732602 , 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [317.5423844462847, 2.9095984678057603, 1.754121288652108, 2.3693917842468784]}) - star_2 = SingleStar(**{'mass': 9.382732464319286, 'state': 'H-rich_Core_H_burning','metallicity':1, - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 27.77657038557851,'eccentricity': 0.0}, properties = sim_prop) - evolve_binary(binary) - ######################################## - # Normal binary - ######################################## - star1 = SingleStar(**{'mass': 7.552858,#29829485, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [40.91509926587841, 2.6295454150818256, 1.6718337470964977, 6.0408769315244895]}) - star2 = SingleStar(**{'mass': 6.742063, #481560266, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star1, star2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 17.957531550841225, 'eccentricity': 0.0,}, - properties=sim_prop) - evolve_binary(binary) - ######################################## - # High BH spin options - ######################################## - star_1 = SingleStar(**{'mass': 31.616785, 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [10, 4.190728383757787, 1.1521129697118118, 5.015343794234789]}) - star_2 = SingleStar(**{'mass': 26.874267, 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 501.99252706449792,'eccentricity': 0.0}, properties = sim_prop) - evolve_binary(binary) - ######################################## - # Original a>1 spin error - ######################################## - star_1 = SingleStar(**{'mass': 18.107506844123645, 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [528.2970725443025, 4.190728383757787, 1.1521129697118118, 5.015343794234789]}) - star_2 = SingleStar(**{'mass': 15.641392951875442, 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 151.99252706449792,'eccentricity': 0.0}, properties = sim_prop) - ######################################## - # FIXED disrupted crash - ######################################## - STAR1 = SingleStar(**{'mass': 52.967313, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass': 36.306444, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':12.877004, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # FIXED error with SN type - ######################################## - STAR1 = SingleStar(**{'mass': 17.782576, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass':3.273864, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':4513.150157, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # FIXED oRLO2 looping - ######################################## - STAR1 = SingleStar(**{'mass': 170.638207, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [4.921294, 4.31745, 1.777768, 3.509656]}) - STAR2 = SingleStar(**{'mass':37.917852, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':113.352736, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # Redirect to step_CO_HeMS (H-rich non-burning?) - ######################################## - star_1 = SingleStar(**{'mass': 8.333579, 'state': 'H-rich_Core_H_burning',\ - 'natal_kick_array': [17.125568, 4.101834, 0.917541, 3.961291]}) - star_2 = SingleStar(**{'mass' : 8.208376, 'state' : 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period': 66.870417, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(binary) - ######################################## - # FIXED oRLO2 looping - ######################################## - star_1 = SingleStar(**{'mass': 16.921378, 'state': 'H-rich_Core_H_burning',\ - 'natal_kick_array': [268.837139, 5.773527, 2.568105, 2.519068]}) - star_2 = SingleStar(**{'mass' : 16.286318, 'state' : 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period': 37.958768, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(binary) - ######################################## - # FIXED? step_detached failure - ######################################## - STAR1 = SingleStar(**{'mass': 19.787769, 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [24.464803, 0.666314, 1.954698, 5.598975]}) - STAR2 = SingleStar(**{'mass': 7.638741, 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':3007.865561, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # Disrupted binary - ######################################## - star_1 = SingleStar(**{'mass': 16.921378, 'state': 'H-rich_Core_H_burning',\ - 'natal_kick_array': [268.837139, 5.773527, 2.568105, 2.519068]}) - star_2 = SingleStar(**{'mass' : 16.286318, 'state' : 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(star_1, star_2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':3007.865561, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # FIXED Detached binary failure (low mass) - ######################################## - STAR1 = SingleStar(**{'mass': 9, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass':0.8, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':4513.150157, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # FIXED SN_TYPE = None crash - ######################################## - STAR1 = SingleStar(**{'mass': 17.782576, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass':3.273864, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':4513.150157, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # FIXED SN_TYPE errors - ######################################## - STAR1 = SingleStar(**{'mass': 6.782576, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass':3.273864, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':4513.150157, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # FIXED SN_TYPE errors - ######################################## - STAR1 = SingleStar(**{'mass': 40.638207, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [30.921294, 4.31745, 1.777768, 3.509656]}) - STAR2 = SingleStar(**{'mass':37.917852, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':2113.352736, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # FIXED ECSN errors? - ######################################## - STAR1 = SingleStar(**{'mass': 12.376778, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [80, 4.31745, 1.777768, 3.509656]}) - STAR2 = SingleStar(**{'mass': 9.711216, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':79.83702, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # Interpolator masses?? - ######################################## - STAR1 = SingleStar(**{'mass': 7.592921, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass':5.038679 , - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':5.537807, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # Interpolator masses? - ######################################## - star_1 = SingleStar(**{'mass': 38.741115, - 'state': 'H-rich_Core_H_burning',\ - 'natal_kick_array': [21.113771, 2.060135, 2.224789, 4.089729]}) - star_2 = SingleStar(**{'mass': 27.776178, - 'state': 'H-rich_Core_H_burning',\ - 'natal_kick_array': [282.712103, 0.296252, 1.628433, 5.623812]}) - - BINARY = BinaryStar(star_1, star_2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period': 93.387072, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # FIXED NaN spin - ######################################## - STAR1 = SingleStar(**{'mass': 70.066924, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0], - 'metallicity':1}) - STAR2 = SingleStar(**{'mass': 34.183110, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0], - 'metallicity':1}) - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':5.931492e+03, - 'separation': orbital_separation_from_period(5.931492e+03, STAR1.mass, STAR2.mass), - 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # FIXED NaN spin - ######################################## - STAR1 = SingleStar(**{'mass': 28.837286, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass': 6.874867, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':35.609894, - 'separation': orbital_separation_from_period(35.609894, STAR1.mass, STAR2.mass), - 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # oRLO2 issue - ######################################## - STAR1 = SingleStar(**{'mass':29.580210, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass': 28.814626, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':40.437993, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - - ######################################## - # oRLO2 issue - ######################################## - STAR1 = SingleStar(**{'mass':67.126795, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass': 19.622908, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':1484.768582, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - - ######################################## - # oRLO2 issue - ######################################## - STAR1 = SingleStar(**{'mass': 58.947503, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass': 56.660506, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':2011.300659, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - - ######################################## - # oRLO2 issue - ######################################## - STAR1 = SingleStar(**{'mass': 170.638207, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[47.979957374424956, 5.317304576107798, 2.7259013166068145, 4.700929589520818]}) - STAR2 = SingleStar(**{'mass': 37.917852, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':113.352736, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - - ######################################## - # oRLO2 issue - ######################################## - STAR1 = SingleStar(**{'mass': 109.540207, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass': 84.344530, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':5.651896, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - - ######################################## - # redirect - ######################################## - STAR1 = SingleStar(**{'mass': 13.889634, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass':0.490231, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':14513.150157, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # redirect - ######################################## - STAR1 = SingleStar(**{'mass': 9, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass':0.8, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':4513.150157, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - ######################################## - # Max time - ######################################## - star_1 = SingleStar(**{'mass': 103.07996766780799, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.2965418610971261, 2.0789170290719117, 3.207488023705968]}) - star_2 = SingleStar(**{'mass': 83.66522615073987, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 1449.1101985875678,'eccentricity': 0.0}, properties = sim_prop) - evolve_binary(binary) - ######################################## - # Max time - ######################################## - star_1 = SingleStar(**{'mass': 8.860934140643465, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [11.818027275431337, 2.812412688633058, 0.4998731824233789, 2.9272630485628643]}) - star_2 = SingleStar(**{'mass': 8.584716012668551, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - binary = BinaryStar(star_1, star_2, **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', - 'orbital_period': 20.82030114750744,'eccentricity': 0.0}, properties = sim_prop) - evolve_binary(binary) - ######################################## - # PR421 - ######################################## - STAR1 = SingleStar(**{'mass': 24.035366, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass': 23.187355, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':18.865029, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - - ######################################## - # CE class - ######################################## - STAR1 = SingleStar(**{'mass':33.964274, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass': 28.98149, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':82.370989, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - - ######################################## - # PR574 - stepCE fix - ######################################## - STAR1 = SingleStar(**{'mass':29.580210, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - STAR2 = SingleStar(**{'mass': 28.814626*0.4, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array':[0.0, 0.0, 0.0, 0.0]}) - - BINARY = BinaryStar(STAR1, STAR2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':300.437993, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(BINARY) - - ######################################## - # e_ZAMS error - ######################################## - star1 = SingleStar(**{'mass': 8.161885721822461, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - star2 = SingleStar(**{'mass': 3.5907829421526154, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - - binary = BinaryStar(star1, star2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period': 36.873457164644144, 'eccentricity': 0.0}, - properties = sim_prop) - evolve_binary(binary) - - ######################################## - # e_ZAMS error - ######################################## - star1 = SingleStar(**{'mass': 35.24755025317775, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [19.755993125895806, 0.37149222852233904, 1.6588846085306563, - 1.434617029858906]}) - star2 = SingleStar(**{'mass': 30.000450298072902, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - - binary = BinaryStar(star1, star2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period': 24060.02101364665, 'eccentricity': 0.8085077857996965}, - properties = sim_prop) - evolve_binary(binary) - - ######################################## - # e_ZAMS error - ######################################## - star1 = SingleStar(**{'mass': 11.862930493162692, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - star2 = SingleStar(**{'mass': 1.4739109294156703, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - - binary = BinaryStar(star1, star2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period': 4111.083887312003, 'eccentricity':0.0}, - properties = sim_prop) - evolve_binary(binary) - - ######################################## - # e_ZAMS error - ######################################## - star1 = SingleStar(**{'mass': 8.527361341212108, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - star2 = SingleStar(**{'mass': 0.7061748406821822, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - - binary = BinaryStar(star1, star2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period': 2521.1927287891444, 'eccentricity':0.0}, - properties = sim_prop) - evolve_binary(binary) - - ######################################## - # e_ZAMS error - ######################################## - star1 = SingleStar(**{'mass': 13.661942533447398 ,#29829485, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - star2 = SingleStar(**{'mass': 4.466151109802313 , #481560266, - 'state': 'H-rich_Core_H_burning', - 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}) - - binary = BinaryStar(star1, star2, - **{'time': 0.0, 'state': 'detached', 'event': 'ZAMS', 'orbital_period':3110.1346707516914, 'eccentricity':0.0}, - properties = sim_prop) - evolve_binary(binary) - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Evolve binaries for validation.') - parser.add_argument('--verbose', '-v', action='store_true', default=False, - help='Enable verbose output (default: False)') - args = parser.parse_args() - - evolve_binaries(verbose=args.verbose) diff --git a/dev-tools/script_data/baselines/.gitkeep b/dev-tools/script_data/baselines/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dev-tools/script_data/inlists/.gitkeep b/dev-tools/script_data/inlists/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dev-tools/script_data/logs/.gitkeep b/dev-tools/script_data/logs/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dev-tools/script_data/output/.gitkeep b/dev-tools/script_data/output/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dev-tools/script_data/output/binary_star_tests/.gitkeep b/dev-tools/script_data/output/binary_star_tests/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dev-tools/script_data/output/population_tests/.gitkeep b/dev-tools/script_data/output/population_tests/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dev-tools/script_data/src/binaries_suite.py b/dev-tools/script_data/src/binaries_suite.py new file mode 100644 index 0000000000..92c8982b86 --- /dev/null +++ b/dev-tools/script_data/src/binaries_suite.py @@ -0,0 +1,274 @@ +""" +Script to evolve a few binaries. +Used for validation of the branch. + +Author: Max Briel +""" + +import argparse +import os +import shutil +import signal +import warnings +from datetime import datetime, timezone + +import numpy as np +import pandas as pd +from binary_test_cases import get_test_binaries +from formatting import AVAILABLE_METALLICITIES, LINE_LENGTH +from utils import print_failed_binary, print_warnings, write_binary_to_screen + +from posydon.binary_evol.binarystar import BinaryStar, SingleStar +from posydon.binary_evol.simulationproperties import SimulationProperties +from posydon.config import PATH_TO_POSYDON, PATH_TO_POSYDON_DATA +from posydon.utils.common_functions import orbital_separation_from_period + + +def load_inlist(ini_path, metallicity, verbose): + + if ini_path is None: + # copy the .ini file from the POSYDON version installed in current environment + default_ini_fn = os.path.join(PATH_TO_POSYDON, "posydon/popsyn/population_params_default.ini") + test_ini_fn = os.path.join(PATH_TO_POSYDON, "dev-tools/script_data/inlists/test_params.ini") + ini_path = shutil.copyfile(default_ini_fn, test_ini_fn) + + print(f"Reading inlist: {ini_path}") + sim_prop = SimulationProperties.from_ini(ini_path) + + # TODO: try to create/pass RNG with a fixed seed + RNG = np.random.default_rng(0) + try: + sim_prop.load_steps(verbose=verbose, RNG=RNG, metallicity=metallicity) + except TypeError as e: + sim_prop.load_steps(verbose=verbose, metallicity=metallicity) + + return sim_prop + +def evolve_binary(binary, binary_id): + + # Capture warnings during evolution + captured_warnings = [] + + def warning_handler(message, category, filename, lineno, file=None, line=None): + captured_warnings.append({ + "binary_id": int(binary_id), + "category": category.__name__, + "message": str(message), + "filename": filename, + "lineno": lineno + }) + + print(f"Binary {binary_id}") + evolution_df = None + error_df = None + + # Set up warning capture + old_showwarning = warnings.showwarning + warnings.showwarning = warning_handler + + try: + binary.evolve() + # Display the evolution summary for successful evolution + write_binary_to_screen(binary) + evolution_df = binary.to_df(extra_columns={'step_names':'str'}) + + # Show warnings if any were captured + print_warnings(captured_warnings) + print("=" * LINE_LENGTH) + + except Exception as e: + + # turn off binary alarm in case of exception + signal.alarm(0) + + print_failed_binary(binary, e) + error_df = pd.DataFrame([{ + "binary_id": int(binary_id), + "exception_type": type(e).__name__, + "exception_message": str(e) + }]) + + # Show warnings if any were captured before the exception + print_warnings(captured_warnings) + + print("=" * LINE_LENGTH) + finally: + # Always turn off binary alarm and restore warning handler + signal.alarm(0) + warnings.showwarning = old_showwarning + + # Ensure we always have a dataframe + if evolution_df is not None: + # Decode bytes columns if needed + for col in evolution_df.select_dtypes([object]): + if evolution_df[col].apply(lambda x: isinstance(x, bytes)).any(): + evolution_df[col] = evolution_df[col].apply( + lambda x: x.decode('utf-8') if isinstance(x, bytes) else x + ) + + # Always ensure binary_id exists + if "binary_id" not in evolution_df.columns: + evolution_df["binary_id"] = int(binary_id) + + # Defragment the DataFrame from POSYDON's column-by-column construction + evolution_df = evolution_df.copy() + + # Save warnings + if captured_warnings: + print(f"⚠️ {len(captured_warnings)} warning(s) raised during evolution:") + for i, w in enumerate(captured_warnings[:3], 1): + print(f" {i}. {w['category']}: {w['message'][:80]}") + if len(captured_warnings) > 3: + print(f" ... and {len(captured_warnings) - 3} more warning(s)") + else: + print(f" No warning(s) raised during evolution") + + print(f"✅ Finished binary {binary_id}") + print("=" * LINE_LENGTH) + + return evolution_df, error_df, captured_warnings + +def create_binary(s1_kw, s2_kw, bin_kw, sim_prop): + + star_1 = SingleStar(**s1_kw) + star_2 = SingleStar(**s2_kw) + + # Add separation from period if not explicitly provided + if 'separation' not in bin_kw and 'orbital_period' in bin_kw: + bin_kw['separation'] = orbital_separation_from_period( + bin_kw['orbital_period'], star_1.mass, star_2.mass + ) + + return BinaryStar(star_1, star_2, **bin_kw, properties=sim_prop) + +def evolve_binaries(metallicity, output_path, verbose, ini_path=None, branch=None, sha=None): + """Evolves the test binary suite at the given metallicity and saves results. + + Args: + metallicity: float, metallicity in solar units + verbose: bool + output_path: str, path to save HDF5 output + ini_path: str, path to ini file (auto-detected if None) + """ + print(f"{'=' * LINE_LENGTH}") + print(f" Evolving test binaries at Z = {metallicity} Zsun") + print(f" Output: {output_path}") + print(f"{'=' * LINE_LENGTH}\n") + + sim_prop = load_inlist(ini_path, metallicity, verbose) + test_binaries = get_test_binaries(metallicity) + + # Collect all results in memory, then write once at the end. + # This avoids repeated HDFStore.append() calls, each of which + # reconciles schemas, checks string sizing, and flushes to disk. + all_evolution_dfs = [] + all_error_dfs = [] + all_warning_dfs = [] + + for binary_id, (s1_kw, s2_kw, bin_kw, description) in enumerate(test_binaries): + print(f"\n[{binary_id}/{len(test_binaries)-1}] {description}") + + binary = create_binary(s1_kw, s2_kw, bin_kw, sim_prop) + + evo_df, err_df, warn_list = evolve_binary(binary, binary_id) + + if evo_df is not None: + all_evolution_dfs.append(evo_df) + if err_df is not None: + all_error_dfs.append(err_df) + if warn_list: + all_warning_dfs.append(pd.DataFrame(warn_list)) + + # ── Completeness check ────────────────────────────────────────── + expected_ids = set(range(len(test_binaries))) + evolved_ids = set() + errored_ids = set() + + for df in all_evolution_dfs: + if 'binary_id' in df.columns: + evolved_ids.update(df['binary_id'].unique()) + for df in all_error_dfs: + if 'binary_id' in df.columns: + errored_ids.update(df['binary_id'].unique()) + + accounted_ids = evolved_ids | errored_ids + missing_ids = sorted(expected_ids - accounted_ids) + + if missing_ids: + print(f"\n⚠️ WARNING: {len(missing_ids)} binary(ies) unaccounted for: {missing_ids}") + print(f" These produced neither evolution output nor a caught error.") + + # ── Single-pass HDF5 write ────────────────────────────────────────── + with pd.HDFStore(output_path, mode="w") as h5file: + # Save metadata + meta_df = pd.DataFrame([{ + 'metallicity': metallicity, + 'n_binaries': len(test_binaries), + 'n_evolved': len(evolved_ids), + 'n_errored': len(errored_ids), + 'n_missing': len(missing_ids), + 'missing_ids': str(missing_ids) if missing_ids else '', + 'path_to_posydon_data': PATH_TO_POSYDON_DATA, + 'branch': branch or '', + 'commit_sha': sha or '', + 'generated_at': datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC'), + }]) + h5file.put("metadata", meta_df, format="table") + + if all_evolution_dfs: + combined_evo = pd.concat(all_evolution_dfs, ignore_index=True) + string_cols = combined_evo.select_dtypes([object]).columns + min_itemsize = {col: 500 for col in string_cols} + h5file.put("evolution", combined_evo, format="table", + data_columns=True, min_itemsize=min_itemsize) + + if all_error_dfs: + combined_err = pd.concat(all_error_dfs, ignore_index=True) + err_string_cols = combined_err.select_dtypes([object]).columns + err_min_itemsize = {col: 1000 for col in err_string_cols} + h5file.put("errors", combined_err, format="table", + min_itemsize=err_min_itemsize) + + if all_warning_dfs: + combined_warn = pd.concat(all_warning_dfs, ignore_index=True) + warn_string_cols = combined_warn.select_dtypes([object]).columns + warn_min_itemsize = {col: 1000 for col in warn_string_cols} + h5file.put("warnings", combined_warn, format="table", + min_itemsize=warn_min_itemsize) + + print(f"\n{'=' * LINE_LENGTH}") + print(f" All {len(test_binaries)} binaries complete. Results saved to {output_path}") + print(f"{'=' * LINE_LENGTH}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Evolve test binaries for POSYDON branch validation.', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument('--verbose', '-v', action='store_true', default=False, + help='Enable verbose output') + parser.add_argument('--output', type=str, required=True, + help='Path to save HDF5 output') + parser.add_argument('--metallicity', '-Z', type=float, default=1.0, + help=f'Metallicity in solar units. Available: {AVAILABLE_METALLICITIES}') + parser.add_argument('--ini', type=str, default=None, + help='Path to params ini file (auto-detected if not given)') + parser.add_argument('--branch', type=str, default=None, + help='Branch name to record in HDF5 metadata') + parser.add_argument('--sha', type=str, default=None, + help='Commit SHA to record in HDF5 metadata') + args = parser.parse_args() + + if args.metallicity not in AVAILABLE_METALLICITIES: + print(f"WARNING: Metallicity {args.metallicity} not in standard set {AVAILABLE_METALLICITIES}.") + print(f"Proceeding anyway, but POSYDON grids may not exist for this value.") + + evolve_binaries( + metallicity=args.metallicity, + verbose=args.verbose, + output_path=args.output, + ini_path=args.ini, + branch=args.branch, + sha=args.sha, + ) diff --git a/dev-tools/script_data/src/binary_test_cases.py b/dev-tools/script_data/src/binary_test_cases.py new file mode 100644 index 0000000000..36444a57a0 --- /dev/null +++ b/dev-tools/script_data/src/binary_test_cases.py @@ -0,0 +1,403 @@ +def get_test_binaries(metallicity): + + Z = metallicity + + test_binaries = [ + # 0: Failing binary in matching + ({'mass': 11.948472796094759, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [231.97383621190582, 5.927334890264575, 1.5990566013567014, 6.137994236518587]}, + {'mass': 7.636958434479617, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 190925.99636740884, 'eccentricity': 0.0}, + "Failing binary in matching"), + + # 1: Failing binary in matching + ({'mass': 30.169861921689556, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [77.96834852144123, 0.05021460132555987, 2.3146518208348152, 1.733054979982291]}, + {'mass': 10.972734402996027, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 20479.71919353725, 'eccentricity': 0.0}, + "Failing binary in matching"), + + # 2: Flipped S1 and S2 (near-equal mass) + ({'mass': 9.474917413943635, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [133.5713935237759, 4.398754864537542, 2.703102872841114, 1.4633904612711142]}, + {'mass': 9.311073918196263, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 18.605997832086413, 'eccentricity': 0.0}, + "Flipped S1 and S2 (near-equal mass)"), + + # 3: Flipped S1 and S2 (high mass ratio) + ({'mass': 10.438541, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 1.400713, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 4.190728383757787, 1.1521129697118118, 5.015343794234789]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 9.824025, 'eccentricity': 0.0}, + "Flipped S1 and S2 (high mass ratio)"), + + # 4: Flipped S1 and S2 + ({'mass': 9.845907, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 4.190728383757787, 1.1521129697118118, 5.015343794234789]}, + {'mass': 9.611029, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 3.820571, 'eccentricity': 0.0}, + "Flipped S1 and S2"), + + # 5: Normal binary evolution (high mass) + ({'mass': 30.845907, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 4.190728383757787, 1.1521129697118118, 5.015343794234789]}, + {'mass': 30.611029, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 30.820571, 'eccentricity': 0.0}, + "Normal binary evolution (high mass)"), + + # 6: Normal binary (wide orbit) + ({'mass': 9.213534679594247, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [327.5906384501521, 1.7707176050073297, 1.573225822966838, 1.6757313876001914]}, + {'mass': 7.209878522799272, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 63123.74544474666, 'eccentricity': 0.0}, + "Normal binary (wide orbit)"), + + # 7: Normal binary (near-equal mass, close) + ({'mass': 9.561158487732602, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [317.5423844462847, 2.9095984678057603, 1.754121288652108, 2.3693917842468784]}, + {'mass': 9.382732464319286, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 27.77657038557851, 'eccentricity': 0.0}, + "Normal binary (near-equal mass, close)"), + + # 8: Normal binary + ({'mass': 7.552858, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [40.91509926587841, 2.6295454150818256, 1.6718337470964977, 6.0408769315244895]}, + {'mass': 6.742063, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 17.957531550841225, 'eccentricity': 0.0}, + "Normal binary"), + + # 9: High BH spin options + ({'mass': 31.616785, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [10, 4.190728383757787, 1.1521129697118118, 5.015343794234789]}, + {'mass': 26.874267, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 501.99252706449792, 'eccentricity': 0.0}, + "High BH spin options"), + + # 10: Original a>1 spin error + ({'mass': 18.107506844123645, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [528.2970725443025, 4.190728383757787, 1.1521129697118118, 5.015343794234789]}, + {'mass': 15.641392951875442, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 151.99252706449792, 'eccentricity': 0.0}, + "Original a>1 spin error"), + + # 11: FIXED disrupted crash + ({'mass': 52.967313, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 36.306444, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 12.877004, 'eccentricity': 0.0}, + "FIXED disrupted crash"), + + # 12: FIXED error with SN type + ({'mass': 17.782576, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 3.273864, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 4513.150157, 'eccentricity': 0.0}, + "FIXED error with SN type"), + + # 13: FIXED oRLO2 looping + ({'mass': 170.638207, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [4.921294, 4.31745, 1.777768, 3.509656]}, + {'mass': 37.917852, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 113.352736, 'eccentricity': 0.0}, + "FIXED oRLO2 looping"), + + # 14: Redirect to step_CO_HeMS + ({'mass': 8.333579, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [17.125568, 4.101834, 0.917541, 3.961291]}, + {'mass': 8.208376, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 66.870417, 'eccentricity': 0.0}, + "Redirect to step_CO_HeMS"), + + # 15: FIXED oRLO2 looping + ({'mass': 16.921378, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [268.837139, 5.773527, 2.568105, 2.519068]}, + {'mass': 16.286318, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 37.958768, 'eccentricity': 0.0}, + "FIXED oRLO2 looping"), + + # 16: FIXED step_detached failure + ({'mass': 19.787769, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [24.464803, 0.666314, 1.954698, 5.598975]}, + {'mass': 7.638741, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 3007.865561, 'eccentricity': 0.0}, + "FIXED step_detached failure"), + + # 17: Disrupted binary + ({'mass': 16.921378, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [268.837139, 5.773527, 2.568105, 2.519068]}, + {'mass': 16.286318, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 3007.865561, 'eccentricity': 0.0}, + "Disrupted binary"), + + # 18: FIXED Detached binary failure (low mass) + ({'mass': 9, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 0.8, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 4513.150157, 'eccentricity': 0.0}, + "FIXED Detached binary failure (low mass)"), + + # 19: FIXED SN_TYPE = None crash + ({'mass': 17.782576, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 3.273864, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 4513.150157, 'eccentricity': 0.0}, + "FIXED SN_TYPE = None crash"), + + # 20: FIXED SN_TYPE errors (low mass primary) + ({'mass': 6.782576, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 3.273864, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 4513.150157, 'eccentricity': 0.0}, + "FIXED SN_TYPE errors (low mass primary)"), + + # 21: FIXED SN_TYPE errors (high mass) + ({'mass': 40.638207, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [30.921294, 4.31745, 1.777768, 3.509656]}, + {'mass': 37.917852, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 2113.352736, 'eccentricity': 0.0}, + "FIXED SN_TYPE errors (high mass)"), + + # 22: FIXED ECSN errors + ({'mass': 12.376778, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [80, 4.31745, 1.777768, 3.509656]}, + {'mass': 9.711216, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 79.83702, 'eccentricity': 0.0}, + "FIXED ECSN errors"), + + # 23: Interpolator masses (close) + ({'mass': 7.592921, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 5.038679, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 5.537807, 'eccentricity': 0.0}, + "Interpolator masses (close)"), + + # 24: Interpolator masses (both kicked) + ({'mass': 38.741115, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [21.113771, 2.060135, 2.224789, 4.089729]}, + {'mass': 27.776178, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [282.712103, 0.296252, 1.628433, 5.623812]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 93.387072, 'eccentricity': 0.0}, + "Interpolator masses (both kicked)"), + + # 25: FIXED NaN spin (very high mass, wide) + ({'mass': 70.066924, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 34.183110, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 5.931492e+03, 'eccentricity': 0.0}, + "FIXED NaN spin (very high mass, wide)"), + + # 26: FIXED NaN spin + ({'mass': 28.837286, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 6.874867, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 35.609894, 'eccentricity': 0.0}, + "FIXED NaN spin"), + + # 27: oRLO2 issue (near-equal mass) + ({'mass': 29.580210, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 28.814626, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 40.437993, 'eccentricity': 0.0}, + "oRLO2 issue (near-equal mass)"), + + # 28: oRLO2 issue (high mass ratio) + ({'mass': 67.126795, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 19.622908, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 1484.768582, 'eccentricity': 0.0}, + "oRLO2 issue (high mass ratio)"), + + # 29: oRLO2 issue (very high mass, near-equal) + ({'mass': 58.947503, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 56.660506, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 2011.300659, 'eccentricity': 0.0}, + "oRLO2 issue (very high mass, near-equal)"), + + # 30: oRLO2 issue (extreme mass, kicked) + ({'mass': 170.638207, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [47.979957374424956, 5.317304576107798, 2.7259013166068145, 4.700929589520818]}, + {'mass': 37.917852, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 113.352736, 'eccentricity': 0.0}, + "oRLO2 issue (extreme mass, kicked)"), + + # 31: oRLO2 issue (very high mass, close) + ({'mass': 109.540207, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 84.344530, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 5.651896, 'eccentricity': 0.0}, + "oRLO2 issue (very high mass, close)"), + + # 32: Redirect (extreme mass ratio) + ({'mass': 13.889634, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 0.490231, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 14513.150157, 'eccentricity': 0.0}, + "Redirect (extreme mass ratio)"), + + # 33: Redirect (low mass secondary) + ({'mass': 9, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 0.8, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 4513.150157, 'eccentricity': 0.0}, + "Redirect (low mass secondary)"), + + # 34: Max time (very high mass, wide) + ({'mass': 103.07996766780799, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.2965418610971261, 2.0789170290719117, 3.207488023705968]}, + {'mass': 83.66522615073987, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 1449.1101985875678, 'eccentricity': 0.0}, + "Max time (very high mass, wide)"), + + # 35: Max time + ({'mass': 8.860934140643465, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [11.818027275431337, 2.812412688633058, 0.4998731824233789, 2.9272630485628643]}, + {'mass': 8.584716012668551, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 20.82030114750744, 'eccentricity': 0.0}, + "Max time"), + + # 36: PR421 + ({'mass': 24.035366, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 23.187355, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 18.865029, 'eccentricity': 0.0}, + "PR421"), + + # 37: CE class + ({'mass': 33.964274, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 28.98149, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 82.370989, 'eccentricity': 0.0}, + "CE class"), + + # 38: PR574 - stepCE fix + ({'mass': 29.580210, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 28.814626 * 0.4, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 300.437993, 'eccentricity': 0.0}, + "PR574 - stepCE fix"), + + # 39: e_ZAMS error + ({'mass': 8.161885721822461, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 3.5907829421526154, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 36.873457164644144, 'eccentricity': 0.0}, + "e_ZAMS error"), + + # 40: e_ZAMS error (eccentric) + ({'mass': 35.24755025317775, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [19.755993125895806, 0.37149222852233904, 1.6588846085306563, 1.434617029858906]}, + {'mass': 30.000450298072902, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 24060.02101364665, 'eccentricity': 0.8085077857996965}, + "e_ZAMS error (eccentric)"), + + # 41: e_ZAMS error (high mass ratio) + ({'mass': 11.862930493162692, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 1.4739109294156703, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 4111.083887312003, 'eccentricity': 0.0}, + "e_ZAMS error (high mass ratio)"), + + # 42: e_ZAMS error (extreme mass ratio) + ({'mass': 8.527361341212108, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 0.7061748406821822, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 2521.1927287891444, 'eccentricity': 0.0}, + "e_ZAMS error (extreme mass ratio)"), + + # 43: e_ZAMS error + ({'mass': 13.661942533447398, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'mass': 4.466151109802313, 'state': 'H-rich_Core_H_burning', 'metallicity': Z, + 'natal_kick_array': [0.0, 0.0, 0.0, 0.0]}, + {'time': 0.0, 'state': 'detached', 'event': 'ZAMS', + 'orbital_period': 3110.1346707516914, 'eccentricity': 0.0}, + "e_ZAMS error"), + ] + + return test_binaries diff --git a/dev-tools/script_data/src/compare_runs.py b/dev-tools/script_data/src/compare_runs.py new file mode 100644 index 0000000000..c815092228 --- /dev/null +++ b/dev-tools/script_data/src/compare_runs.py @@ -0,0 +1,501 @@ +#!/usr/bin/env python3 +""" +Compare evolution outcomes of test binaries saved by binaries_suite.py. + +Reports three categories of differences: + 1. QUANTITATIVE: any numeric difference beyond floating-point representation + 2. QUALITATIVE: changes to categorical/string columns (states, events, step names, SN types, etc.) + 3. WARNINGS & ERRORS: changes to warnings raised or binaries that error out + +The report header includes provenance metadata (branch, commit SHA, generation +time, POSYDON data path) read from each HDF5 file's /metadata table when +available. Older files without these fields are handled gracefully. + +By default, uses exact comparison (atol=0, rtol=0). The --loose flag enables +a small tolerance for cases where minor floating-point differences are expected. + +Usage: + python compare_runs.py baseline.h5 candidate.h5 + python compare_runs.py baseline.h5 candidate.h5 --verbose + python compare_runs.py baseline.h5 candidate.h5 --loose + +Authors: Elizabeth Teng +""" + +import argparse +import os +import sys + +import numpy as np +import pandas as pd + +# Columns that represent qualitative (categorical) evolution properties. +# Any column matching these names will be compared as exact string matches +# and reported under "QUALITATIVE" differences. +QUALITATIVE_COLUMNS = { + 'state', 'event', 'step_names', 'S1_state', 'S2_state', + 'SN_type', 'S1_SN_type', 'S2_SN_type', + 'interp_class_HMS_HMS', 'interp_class_CO_HMS_RLO', + 'interp_class_CO_HeMS', 'interp_class_CO_HeMS_RLO', + 'mt_history_HMS_HMS', 'mt_history_CO_HMS_RLO', + 'mt_history_CO_HeMS', 'mt_history_CO_HeMS_RLO', + 'mass_transfer_case', +} + + +def classify_column(col, dtype): + """Classify a column as 'qualitative' or 'quantitative'.""" + if col in QUALITATIVE_COLUMNS: + return 'qualitative' + if pd.api.types.is_numeric_dtype(dtype): + return 'quantitative' + # Catch-all: treat remaining object/string columns as qualitative + return 'qualitative' + + +def compare_evolution_tables(base_df, cand_df, rtol, atol, + base_error_ids=None, cand_error_ids=None): + """Compare two evolution DataFrames, reporting per-binary diffs. + + Args: + base_error_ids: set of binary IDs that errored in the baseline run. + cand_error_ids: set of binary IDs that errored in the candidate run. + Binaries present in these sets are excluded from MISSING/EXTRA + reporting here, since they are already covered by compare_errors_tables. + + Returns: + dict with keys 'quantitative', 'qualitative', 'structural' + each mapping to a list of diff strings. + """ + quant_diffs = [] + qual_diffs = [] + struct_diffs = [] + + # Check that binary_id columns exist + if 'binary_id' not in base_df.columns or 'binary_id' not in cand_df.columns: + struct_diffs.append("'binary_id' column missing; cannot do per-binary comparison") + return {'quantitative': quant_diffs, 'qualitative': qual_diffs, 'structural': struct_diffs} + + base_ids = set(base_df['binary_id'].unique()) + cand_ids = set(cand_df['binary_id'].unique()) + + # Missing/extra binaries (excluding those already reported under errors) + for bid in sorted(base_ids - cand_ids): + if bid in cand_error_ids: + continue # candidate errored; reported by compare_errors_tables + struct_diffs.append(f"Binary {bid}: MISSING in candidate") + for bid in sorted(cand_ids - base_ids): + if bid in base_error_ids: + continue # baseline errored; reported by compare_errors_tables + struct_diffs.append(f"Binary {bid}: EXTRA in candidate") + + common_ids = sorted(base_ids & cand_ids) + + for bid in common_ids: + b = base_df[base_df['binary_id'] == bid].reset_index(drop=True) + c = cand_df[cand_df['binary_id'] == bid].reset_index(drop=True) + + # ── Step count ──────────────────────────────────────────────── + if len(b) != len(c): + struct_diffs.append( + f"Binary {bid}: evolution step count differs " + f"(baseline={len(b)}, candidate={len(c)})" + ) + + # ── Column presence ─────────────────────────────────────────── + base_only_cols = set(b.columns) - set(c.columns) - {'binary_id'} + cand_only_cols = set(c.columns) - set(b.columns) - {'binary_id'} + if base_only_cols: + struct_diffs.append(f"Binary {bid}: columns only in baseline: {sorted(base_only_cols)}") + if cand_only_cols: + struct_diffs.append(f"Binary {bid}: columns only in candidate: {sorted(cand_only_cols)}") + + # ── Per-column comparison ───────────────────────────────────── + common_cols = sorted(set(b.columns) & set(c.columns) - {'binary_id'}) + min_rows = min(len(b), len(c)) + + for col in common_cols: + b_col = b[col].iloc[:min_rows] + c_col = c[col].iloc[:min_rows] + col_type = classify_column(col, b_col.dtype) + + if col_type == 'quantitative': + b_arr = b_col.to_numpy(dtype=float) + c_arr = c_col.to_numpy(dtype=float) + + # NaN handling: both NaN = match, one NaN = mismatch + both_nan = np.isnan(b_arr) & np.isnan(c_arr) + one_nan = np.isnan(b_arr) ^ np.isnan(c_arr) + + if one_nan.any(): + nan_steps = np.where(one_nan)[0].tolist() + direction = [] + for s in nan_steps[:5]: + bv = "NaN" if np.isnan(b_arr[s]) else f"{b_arr[s]:.6g}" + cv = "NaN" if np.isnan(c_arr[s]) else f"{c_arr[s]:.6g}" + direction.append(f"step {s}: {bv} -> {cv}") + quant_diffs.append( + f"Binary {bid}, '{col}': NaN mismatch at {len(nan_steps)} step(s): " + + "; ".join(direction) + ) + + # Compare non-NaN values + valid = ~(np.isnan(b_arr) | np.isnan(c_arr)) + if valid.any(): + b_valid = b_arr[valid] + c_valid = c_arr[valid] + not_equal = b_valid != c_valid + + if rtol == 0 and atol == 0: + # Exact comparison + if not_equal.any(): + diff_indices = np.where(valid)[0][not_equal] + abs_diff = np.abs(b_valid[not_equal] - c_valid[not_equal]) + worst = np.argmax(abs_diff) + worst_step = diff_indices[worst] + quant_diffs.append( + f"Binary {bid}, '{col}': {not_equal.sum()} value(s) differ. " + f"Largest abs diff = {abs_diff[worst]:.6e} " + f"at step {worst_step} " + f"(baseline={b_valid[not_equal][worst]:.15g}, " + f"candidate={c_valid[not_equal][worst]:.15g})" + ) + else: + # Tolerance-based comparison + if not np.allclose(b_valid, c_valid, rtol=rtol, atol=atol): + abs_diff = np.abs(b_valid - c_valid) + with np.errstate(divide='ignore', invalid='ignore'): + denom = np.maximum(np.abs(b_valid), atol) + rel_diff = abs_diff / denom + worst = np.argmax(abs_diff) + worst_step = np.where(valid)[0][worst] + quant_diffs.append( + f"Binary {bid}, '{col}': numeric mismatch " + f"(max abs diff = {abs_diff[worst]:.6e}, " + f"max rel diff = {rel_diff[worst]:.6e}, " + f"at step {worst_step}, " + f"baseline={b_valid[worst]:.15g}, " + f"candidate={c_valid[worst]:.15g})" + ) + + else: + # Qualitative comparison: exact string match + b_str = b_col.astype(str).values + c_str = c_col.astype(str).values + mismatches = np.where(b_str != c_str)[0] + if len(mismatches) > 0: + details = [] + for s in mismatches[:5]: + details.append(f"step {s}: '{b_str[s]}' -> '{c_str[s]}'") + qual_diffs.append( + f"Binary {bid}, '{col}': {len(mismatches)} step(s) differ: " + + "; ".join(details) + ) + + return {'quantitative': quant_diffs, 'qualitative': qual_diffs, 'structural': struct_diffs} + + +def compare_warnings_tables(base_df, cand_df): + """Compare warning tables between baseline and candidate. + + Returns list of diff strings. + """ + diffs = [] + + if base_df is None and cand_df is None: + return diffs + if base_df is None: + diffs.append(f"Candidate has {len(cand_df)} warning(s), baseline has none") + return diffs + if cand_df is None: + diffs.append(f"Baseline has {len(base_df)} warning(s), candidate has none") + return diffs + + if len(base_df) != len(cand_df): + diffs.append(f"Total warning count differs (baseline={len(base_df)}, candidate={len(cand_df)})") + + # Per-binary warning comparison + if 'binary_id' in base_df.columns and 'binary_id' in cand_df.columns: + base_grouped = base_df.groupby('binary_id') + cand_grouped = cand_df.groupby('binary_id') + all_ids = sorted(set(base_df['binary_id'].unique()) | set(cand_df['binary_id'].unique())) + + for bid in all_ids: + b_warnings = base_grouped.get_group(bid) if bid in base_grouped.groups else pd.DataFrame() + c_warnings = cand_grouped.get_group(bid) if bid in cand_grouped.groups else pd.DataFrame() + + b_count = len(b_warnings) + c_count = len(c_warnings) + + if b_count == 0 and c_count > 0: + cats = c_warnings['category'].unique().tolist() if 'category' in c_warnings.columns else ['unknown'] + diffs.append(f"Binary {bid}: {c_count} NEW warning(s) in candidate ({', '.join(str(c) for c in cats)})") + elif b_count > 0 and c_count == 0: + diffs.append(f"Binary {bid}: {b_count} warning(s) REMOVED in candidate") + elif b_count != c_count: + diffs.append(f"Binary {bid}: warning count changed ({b_count} -> {c_count})") + elif b_count > 0: + # Same count — check if warning categories or messages changed + if 'category' in b_warnings.columns and 'category' in c_warnings.columns: + b_cats = sorted(b_warnings['category'].astype(str).tolist()) + c_cats = sorted(c_warnings['category'].astype(str).tolist()) + if b_cats != c_cats: + diffs.append(f"Binary {bid}: warning categories changed ({b_cats} -> {c_cats})") + + if 'message' in b_warnings.columns and 'message' in c_warnings.columns: + b_msgs = sorted(b_warnings['message'].astype(str).tolist()) + c_msgs = sorted(c_warnings['message'].astype(str).tolist()) + if b_msgs != c_msgs: + diffs.append(f"Binary {bid}: warning messages changed") + + return diffs + + +def compare_errors_tables(base_df, cand_df): + """Compare error tables between baseline and candidate. + + Returns list of diff strings. + """ + diffs = [] + + if base_df is None and cand_df is None: + return diffs + if base_df is None and cand_df is not None: + cand_ids = sorted(cand_df['binary_id'].unique()) if 'binary_id' in cand_df.columns else [] + diffs.append(f"Candidate has {len(cand_df)} error(s) (binaries {cand_ids}), baseline has none") + return diffs + if base_df is not None and cand_df is None: + base_ids = sorted(base_df['binary_id'].unique()) if 'binary_id' in base_df.columns else [] + diffs.append(f"Baseline has {len(base_df)} error(s) (binaries {base_ids}), candidate has none") + return diffs + + # Compare per-binary errors + if 'binary_id' in base_df.columns and 'binary_id' in cand_df.columns: + base_ids = set(base_df['binary_id'].unique()) + cand_ids = set(cand_df['binary_id'].unique()) + + for bid in sorted(cand_ids - base_ids): + row = cand_df[cand_df['binary_id'] == bid].iloc[0] + exc = row.get('exception_type', 'unknown') + diffs.append(f"Binary {bid}: NEWLY FAILING in candidate ({exc})") + + for bid in sorted(base_ids - cand_ids): + diffs.append(f"Binary {bid}: NEWLY PASSING in candidate (was failing in baseline)") + + for bid in sorted(base_ids & cand_ids): + b_row = base_df[base_df['binary_id'] == bid].iloc[0] + c_row = cand_df[cand_df['binary_id'] == bid].iloc[0] + b_exc = str(b_row.get('exception_type', '')) + c_exc = str(c_row.get('exception_type', '')) + b_msg = str(b_row.get('exception_message', '')) + c_msg = str(c_row.get('exception_message', '')) + if b_exc != c_exc: + diffs.append(f"Binary {bid}: error type changed ('{b_exc}' -> '{c_exc}')") + if b_msg != c_msg: + diffs.append(f"Binary {bid}: error message changed") + + return diffs + + +def read_table_safe(store, key): + """Read a table from HDFStore, returning None if it doesn't exist.""" + try: + if key in store: + return store[key] + except Exception: + pass + return None + + +def main(): + parser = argparse.ArgumentParser( + description="Compare baseline and candidate binary evolution HDF5 files.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +By default, uses EXACT comparison (any numeric difference is reported). +Use --loose to allow small floating-point tolerances (rtol=1e-12, atol=1e-15). + """, + ) + parser.add_argument("baseline", help="Path to baseline HDF5 file") + parser.add_argument("candidate", help="Path to candidate HDF5 file") + parser.add_argument("--loose", action="store_true", + help="Allow small floating-point tolerance (rtol=1e-12, atol=1e-15)") + parser.add_argument("--rtol", type=float, default=None, + help="Override relative tolerance (default: 0, or 1e-12 with --loose)") + parser.add_argument("--atol", type=float, default=None, + help="Override absolute tolerance (default: 0, or 1e-15 with --loose)") + parser.add_argument("--verbose", "-v", action="store_true", + help="Print extra diagnostic info") + args = parser.parse_args() + + # Set tolerances + if args.loose: + rtol = args.rtol if args.rtol is not None else 1e-12 + atol = args.atol if args.atol is not None else 1e-15 + else: + rtol = args.rtol if args.rtol is not None else 0.0 + atol = args.atol if args.atol is not None else 0.0 + + for f in [args.baseline, args.candidate]: + if not os.path.exists(f): + print(f"ERROR: File not found: {f}", file=sys.stderr) + sys.exit(2) + + quant_diffs = [] + qual_diffs = [] + struct_diffs = [] + warn_diffs = [] + + try: + with pd.HDFStore(args.baseline, mode='r') as base_store, \ + pd.HDFStore(args.candidate, mode='r') as cand_store: + + base_keys = set(base_store.keys()) + cand_keys = set(cand_store.keys()) + + if args.verbose: + print(f"Baseline keys: {sorted(base_keys)}") + print(f"Candidate keys: {sorted(cand_keys)}") + + # ── Errors table (read early so IDs are available for evolution comparison) + base_err = read_table_safe(base_store, '/errors') + cand_err = read_table_safe(cand_store, '/errors') + + base_error_ids = set(base_err['binary_id'].unique()) \ + if base_err is not None and 'binary_id' in base_err.columns else set() + cand_error_ids = set(cand_err['binary_id'].unique()) \ + if cand_err is not None and 'binary_id' in cand_err.columns else set() + + # ── Evolution table ─────────────────────────────────────── + base_evol = read_table_safe(base_store, '/evolution') + cand_evol = read_table_safe(cand_store, '/evolution') + + if base_evol is None and cand_evol is None: + struct_diffs.append("Neither file contains an 'evolution' table") + elif base_evol is None: + struct_diffs.append("Baseline missing 'evolution' table") + elif cand_evol is None: + struct_diffs.append("Candidate missing 'evolution' table") + else: + if args.verbose: + n_base = base_evol['binary_id'].nunique() if 'binary_id' in base_evol.columns else '?' + n_cand = cand_evol['binary_id'].nunique() if 'binary_id' in cand_evol.columns else '?' + print(f"Baseline: {n_base} binaries, {len(base_evol)} total rows") + print(f"Candidate: {n_cand} binaries, {len(cand_evol)} total rows") + + evol_results = compare_evolution_tables(base_evol, cand_evol, rtol, atol, + base_error_ids, cand_error_ids) + quant_diffs.extend(evol_results['quantitative']) + qual_diffs.extend(evol_results['qualitative']) + struct_diffs.extend(evol_results['structural']) + + # ── Warnings table ──────────────────────────────────────── + base_warn = read_table_safe(base_store, '/warnings') + cand_warn = read_table_safe(cand_store, '/warnings') + warn_diffs.extend(compare_warnings_tables(base_warn, cand_warn)) + + # ── Errors table (comparison) ────────────────────────────────────────── + error_diffs = compare_errors_tables(base_err, cand_err) + struct_diffs.extend(error_diffs) + + # ── Extra/missing top-level keys ────────────────────────── + ignored_keys = {'/evolution', '/warnings', '/errors', '/metadata'} + for k in sorted(base_keys - cand_keys): + if k not in ignored_keys: + struct_diffs.append(f"Table '{k}' missing in candidate") + for k in sorted(cand_keys - base_keys): + if k not in ignored_keys: + struct_diffs.append(f"Table '{k}' extra in candidate") + + except Exception as e: + print(f"ERROR reading HDF5 files: {e}", file=sys.stderr) + sys.exit(2) + + # ── Report ──────────────────────────────────────────────────────────── + total_diffs = len(quant_diffs) + len(qual_diffs) + len(struct_diffs) + len(warn_diffs) + tol_label = f"rtol={rtol}, atol={atol}" if rtol > 0 or atol > 0 else "EXACT (rtol=0, atol=0)" + + # Read metadata for the report header + def _read_meta(filepath): + """Extract metadata fields from an HDF5 file, returning a dict.""" + meta = {} + try: + with pd.HDFStore(filepath, mode='r') as store: + if '/metadata' in store: + m = store['/metadata'] + for field in ('branch', 'commit_sha', 'generated_at', + 'path_to_posydon_data'): + if field in m.columns: + meta[field] = str(m[field].iloc[0]) + except Exception: + pass + return meta + + base_meta = _read_meta(args.baseline) + cand_meta = _read_meta(args.candidate) + + print("=" * 70) + print("POSYDON Binary Validation — Comparison Report") + print(f" Baseline: {args.baseline}") + if base_meta: + if base_meta.get('branch'): + print(f" Branch: {base_meta['branch']}") + if base_meta.get('commit_sha'): + print(f" SHA: {base_meta['commit_sha']}") + if base_meta.get('generated_at'): + print(f" Generated: {base_meta['generated_at']}") + if base_meta.get('path_to_posydon_data'): + print(f" Data path: {base_meta['path_to_posydon_data']}") + print(f" Candidate: {args.candidate}") + if cand_meta: + if cand_meta.get('branch'): + print(f" Branch: {cand_meta['branch']}") + if cand_meta.get('commit_sha'): + print(f" SHA: {cand_meta['commit_sha']}") + if cand_meta.get('generated_at'): + print(f" Generated: {cand_meta['generated_at']}") + if cand_meta.get('path_to_posydon_data'): + print(f" Data path: {cand_meta['path_to_posydon_data']}") + print(f" Tolerances: {tol_label}") + print("=" * 70) + + if struct_diffs: + print(f"\n--- STRUCTURAL ({len(struct_diffs)}) ---") + print(" (missing/extra binaries, step count changes, newly failing/passing, errors)\n") + for d in struct_diffs: + print(f" - {d}") + + if qual_diffs: + print(f"\n--- QUALITATIVE ({len(qual_diffs)}) ---") + print(" (state, event, step name, SN type, interpolation class changes)\n") + for d in qual_diffs: + print(f" - {d}") + + if quant_diffs: + print(f"\n--- QUANTITATIVE ({len(quant_diffs)}) ---") + print(" (any numeric value change)\n") + for d in quant_diffs: + print(f" - {d}") + + if warn_diffs: + print(f"\n--- WARNINGS ({len(warn_diffs)}) ---") + print(" (new, removed, or changed warnings)\n") + for d in warn_diffs: + print(f" - {d}") + + print("\n" + "=" * 70) + if total_diffs == 0: + print("RESULT: IDENTICAL — candidate matches baseline exactly.") + sys.exit(0) + else: + print(f"RESULT: {total_diffs} DIFFERENCE(S) DETECTED") + print(f" Structural: {len(struct_diffs)}") + print(f" Qualitative: {len(qual_diffs)}") + print(f" Quantitative: {len(quant_diffs)}") + print(f" Warnings: {len(warn_diffs)}") + print("=" * 70) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/dev-tools/script_data/src/formatting.py b/dev-tools/script_data/src/formatting.py new file mode 100644 index 0000000000..61a34fccfb --- /dev/null +++ b/dev-tools/script_data/src/formatting.py @@ -0,0 +1,4 @@ +target_rows = 12 +LINE_LENGTH = 140 +columns_to_show = ['step_names', 'state', 'event', 'S1_state', 'S1_mass', 'S2_state', 'S2_mass', 'orbital_period'] +AVAILABLE_METALLICITIES = [2., 1., 0.45, 0.2, 0.1, 0.01, 0.001, 0.0001] diff --git a/dev-tools/script_data/src/popsynth_suite.py b/dev-tools/script_data/src/popsynth_suite.py new file mode 100644 index 0000000000..97bdbb172f --- /dev/null +++ b/dev-tools/script_data/src/popsynth_suite.py @@ -0,0 +1,224 @@ +import argparse +import os +import shutil +import subprocess +import traceback +import warnings + +from formatting import LINE_LENGTH, columns_to_show +from pandas.testing import assert_frame_equal +from utils import print_pop_settings, print_warnings + +from posydon.config import PATH_TO_POSYDON +from posydon.popsyn.binarypopulation import BinaryPopulation +from posydon.popsyn.synthetic_population import Population, PopulationRunner + +script_dir = os.path.dirname(os.path.abspath(__file__)) + +def test_binpop_evolve(population, popevo_kwargs, verbose=False): + + # this function runs a pop.evolve test given a set of kwargs + + # Capture warnings during evolution + captured_warnings = [] + + def warning_handler(message, category, filename, lineno, file=None, line=None): + captured_warnings.append({ + 'message': str(message), + 'category': category.__name__, + 'filename': filename, + 'lineno': lineno + }) + + # Set up warning capture + old_showwarning = warnings.showwarning + warnings.showwarning = warning_handler + + try: + + print("Running BinaryPopulation.evolve() with settings:") + for key, val in popevo_kwargs.items(): + print(f"\t {key} : {val}") + + print("\nEvolving BinaryPopulation...\n") + population.evolve(**popevo_kwargs) + + print_warnings(captured_warnings) + print("✅ BinaryPopulation evolved successfully.") + print("=" * LINE_LENGTH) + + except Exception as e: + print_warnings(captured_warnings) + print(f"🚨 BinaryPopulation evolution failed!\n") + traceback.print_exc(limit=3) + print("\n") + print("=" * LINE_LENGTH) + + return population + +def compare_io_to_ram(loaded_pop, pop_in_ram): + + # this function compares binaries saved/loaded to any stored in RAM + + # check that binaries match between pop runs w/ fixed entropy + # and that saved/loaded binaries match those from a memory loaded run + N = pop_in_ram.number_of_binaries + df_from_ram = pop_in_ram.to_df() + ram_dflist = [df_from_ram.loc[i] for i in range(N)] + print("🔍 Checking that binaries in RAM match those retrieved from I/O...") + for i, ram_df in enumerate(ram_dflist): + io_df = loaded_pop.history[i] + ram_df = ram_df[columns_to_show] + io_df = io_df[columns_to_show] + try: + assert_frame_equal(ram_df, io_df) + except AssertionError as e: + print("🚨 A binary from I/O does not equal the same binary stored in RAM:") + print(e) + print("\nBinary in RAM:\n", ram_df) + print("\nBinary from I/O:\n", io_df) + print("=" * LINE_LENGTH) + return + if i == len(loaded_pop.history): + break + + print("✅ Binaries from I/O match those in RAM.") + print("=" * LINE_LENGTH) + +def print_testinfo(test_title, population, popevo_kwargs): + + # prints some info about the tests + + # print test title str + numchar = (LINE_LENGTH - len(test_title)) // 2 + print("=" * numchar + test_title + "=" * numchar) + + optimize_ram = popevo_kwargs.get("optimize_ram", False) + breakdown_to_df = popevo_kwargs.get("breakdown_to_df", False) + N_binaries = population.number_of_binaries + + # print expected behavior based on settings + if not breakdown_to_df and not optimize_ram: + print(f"🚀 Evolving a population and storing {N_binaries} binaries in RAM.") + elif breakdown_to_df and not optimize_ram: + print("🚀 Evolving a population and saving binaries to a hdf5 file.") + elif optimize_ram and not breakdown_to_df: + num_batch_files = population.number_of_binaries // population.kwargs["dump_rate"] + print(f"🚀 Evolving a population and saving to {num_batch_files} batch files.") + +def check_test(pop_in_ram, out_path, load_pop=False): + + # checks that a test went OK + + # if we have population in RAM, check that the number + # stored in RAM matches the number we expected to run + if pop_in_ram and not load_pop: + num_binaries = pop_in_ram.number_of_binaries + num_in_ram = len(pop_in_ram.manager.binaries) + print(f"🔍 Checking that we have {num_binaries} binaries in RAM...") + assert num_binaries == num_in_ram, \ + f"🚨 Number of binaries in RAM ({num_in_ram}) " \ + f"does not equal the number specified to run ({num_binaries})." + print(f"✅ Successfully ran and stored {num_in_ram} binaries in RAM.") + + elif pop_in_ram and load_pop: + save_fn = os.path.join(out_path, "batches", "evolution.combined.h5") + loaded_pop = Population(save_fn) + compare_io_to_ram(loaded_pop, pop_in_ram) + +def test_popruns(ini_path, multiz_path, out_path, verbose): + + # primary function + + print("Performing population run tests...") + print(f"Reading inlist: {ini_path}") + pop = BinaryPopulation.from_ini(ini_path, verbose=verbose) + pop.kwargs.update({"temp_directory": os.path.join(out_path, "batches")}) + print_pop_settings(pop) + + # DO TESTS: + + # test simple run, stays in RAM + kwargs = {"optimize_ram":False, "breakdown_to_df":False, "tqdm":True} + print_testinfo("TEST: 01", pop, kwargs) + pop_in_ram = test_binpop_evolve(pop, kwargs, verbose=verbose) + check_test(pop_in_ram, out_path, load_pop=False) + + # test same but w/ saving/loading binaries + kwargs = {"optimize_ram":False, "breakdown_to_df":True, "tqdm":True} + print_testinfo("TEST: 02", pop, kwargs) + _ = test_binpop_evolve(pop, kwargs, verbose=verbose) + check_test(pop_in_ram, out_path, load_pop=True) + + # test optimize RAM run w/ batch saving + kwargs = {"optimize_ram":True, "breakdown_to_df":False, "tqdm":True} + print_testinfo("TEST: 03", pop, kwargs) + _ = test_binpop_evolve(pop, kwargs, verbose=verbose) + check_test(pop_in_ram, out_path, load_pop=True) + + # TEST POPRUNNER + # This is can be RAM heavy (may fail esp. on personal computers) + # Using flush on print here since we are running subprocceses and want them to + # show in order with shell stdout. + # ================================================================================ + os.chdir(out_path) + test_str = " TEST: 04 " + numchar = (LINE_LENGTH - len(test_str)) // 2 + print("=" * numchar + test_str + "=" * numchar, flush=True) + print("Test PopulationRunner with multiple metallicities...", flush=True) + print(f"Reading inlist: {multiz_path}", flush=True) + poprun = PopulationRunner(multiz_path, verbose=True) + print('\t Number of binary populations:', len(poprun.binary_populations), flush=True) + print('\t Metallicities:', poprun.solar_metallicities, flush=True) + print('\t Number of binaries (per pop):', poprun.binary_populations[0].number_of_binaries, flush=True) + print("🚀 Evolving PopulationRunner...", flush=True) + poprun.evolve(overwrite=True) + print("✅ PopulationRunner evolved successfully.", flush=True) + print("=" * LINE_LENGTH, flush=True) + + # TEST PIPELINE + # This is can also be RAM heavy + # ================================================================================ + test_str = " TEST: 05 " + numchar = (LINE_LENGTH - len(test_str)) // 2 + print("=" * numchar + test_str + "=" * numchar, flush=True) + print("Test posydon-popsyn pipeline for multiple metallicities...", flush=True) + shutil.copy(os.path.join(script_dir, "setup_poprun.sh"), out_path) + subprocess.run(["bash", "setup_poprun.sh", multiz_path], check=True) + # mimic SLURM job array env vars, as if jobs submitted with --job_array=1 + # this is needed to test merge_metallicity.py, which looks for jobs per task ID + # to merge. + os.environ["SLURM_ARRAY_JOB_ID"] = "0" + os.environ["SLURM_ARRAY_TASK_MIN"] = "0" + os.environ["SLURM_ARRAY_TASK_ID"] = "0" + os.environ["SLURM_ARRAY_TASK_COUNT"] = "1" + + for metallicity in poprun.solar_metallicities: + subprocess.run(["echo", f"🚀 Running pipeline for metallicity {metallicity}..."]) + subprocess.run(["python", "run_metallicity.py", str(metallicity)], check=True) + subprocess.run(["python", "merge_metallicity.py", str(metallicity)], check=True) + + print("✅ Successfully evolved multiple populations with posydon-popsyn.", flush=True) + + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description='Evolve test binary populations for POSYDON branch validation.', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument('--verbose', '-v', action='store_true', default=False, + help='Enable verbose output') + parser.add_argument('--output', '-o', type=str, required=True, + help='Path to save population synthesis output') + parser.add_argument('--ini', type=str, default=None, + help='Path to params ini file (auto-detected if not given)') + parser.add_argument('--multiz', type=str, default=None, + help='Path to params ini file with multiple metallicities (auto-detected if not given)') + args = parser.parse_args() + + test_popruns(ini_path=args.ini, + multiz_path=args.multiz, + out_path=args.output, + verbose=args.verbose) diff --git a/dev-tools/script_data/src/setup_poprun.sh b/dev-tools/script_data/src/setup_poprun.sh new file mode 100644 index 0000000000..ff797aeb54 --- /dev/null +++ b/dev-tools/script_data/src/setup_poprun.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +INI_FILE="$1" +posydon-popsyn setup ${INI_FILE} --job_array=1 --walltime=00:40:00 --partition=partition --account=account --email=email@domain.com --mem_per_cpu=10G diff --git a/dev-tools/script_data/src/utils.py b/dev-tools/script_data/src/utils.py new file mode 100644 index 0000000000..7482377531 --- /dev/null +++ b/dev-tools/script_data/src/utils.py @@ -0,0 +1,112 @@ +from formatting import LINE_LENGTH, columns_to_show, target_rows + + +def print_warnings(captured_warnings): + # Show warnings if any were captured + if captured_warnings: + print(f"⚠️ {len(captured_warnings)} warning(s) raised during evolution:") + for i, warning in enumerate(captured_warnings[:3], 1): # Show max 3 warnings + print(f" {i}. {warning['category']}: {warning['message']}") + if len(captured_warnings) > 3: + print(f" ... and {len(captured_warnings) - 3} more warning(s)") + elif len(captured_warnings) <= 3: + for i in range(4-len(captured_warnings)): + print("") + else: + print(f"No warning(s) raised during evolution\n\n") + +def print_pop_settings(population): + + print("\nPopulation settings:") + + ignore_kwargs = ["extra_columns", "only_select_columns", "scalar_names", + "include_S1", "S1_kwargs", "include_S2", "S2_kwargs", + "population_properties", "warnings_verbose", "history_verbose", + "error_checking_verbose", "use_MPI", "read_samples_from_file", + "RANK", "size", "optimize_ram", "ram_per_cpu", + "dump_rate", "tqdm", "breakdown_to_df"] + + for key, val in population.kwargs.items(): + if key in ignore_kwargs: + continue + else: + print(f"\t {key} : {val}") + + print("\n") + + +def write_binary_to_screen(binary): + """Writes a binary DataFrame prettily to the screen + + Args: + binary: BinaryStar object with evolved data + """ + df = binary.to_df(**{'extra_columns':{'step_names':'str'}}) + + # Filter to only existing columns + available_columns = [col for col in columns_to_show if col in df.columns] + df_filtered = df[available_columns] + + # Reset index to use a counter instead of NaN + df_filtered = df_filtered.reset_index(drop=True) + + print("=" * LINE_LENGTH) + + # Print the DataFrame + df_string = df_filtered.to_string(index=True, float_format='%.3f') + print(df_string) + + # Add empty lines to reach exactly 10 rows of output + current_rows = len(df_filtered) + 1 # add one for header + + if current_rows < target_rows: + # Calculate the width of the output to print empty lines of the same width + lines = df_string.split('\n') + if len(lines) > 1: + # Use the width of the data lines (skip header) + empty_lines_needed = target_rows - current_rows + for i in range(empty_lines_needed): + print("") + + print("-" * LINE_LENGTH) + + +def print_failed_binary(binary, e, max_error_lines=3): + + print("=" * LINE_LENGTH) + print(f"🚨 Binary Evolution Failed!") + print(f"Exception: {type(e).__name__}") + print(f"Message: {e}") + + # Get the binary's current state and limit output + try: + df = binary.to_df(**{'extra_columns':{'step_names':'str'}}) + if len(df) > 0: + # Select only the desired columns + + available_columns = [col for col in columns_to_show if col in df.columns] + df_filtered = df[available_columns] + + # Reset index to use a counter instead of NaN + df_filtered = df_filtered.reset_index(drop=True) + + # Limit to max_error_lines + if len(df_filtered) > max_error_lines: + df_filtered = df_filtered.tail(max_error_lines) + print(f"\nShowing last {max_error_lines} evolution steps before failure:") + else: + print(f"\nEvolution steps before failure ({len(df_filtered)} steps):") + + df_string = df_filtered.to_string(index=True, float_format='%.3f') + print(df_string) + + current_rows = len(df_filtered) + 1 + 5 # add one for header + empty_lines_needed = target_rows - current_rows + for i in range(empty_lines_needed): + print("") + else: + print("\nNo evolution steps recorded before failure.") + except Exception as inner_e: + print(f"\nCould not retrieve binary state: {inner_e}") + + print("-" * LINE_LENGTH) diff --git a/dev-tools/script_data/workdirs/.gitkeep b/dev-tools/script_data/workdirs/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dev-tools/validate_binaries.sh b/dev-tools/validate_binaries.sh new file mode 100755 index 0000000000..797bb99f3b --- /dev/null +++ b/dev-tools/validate_binaries.sh @@ -0,0 +1,241 @@ +#!/bin/bash +# ============================================================================= +# validate_binaries.sh: Run the full validation pipeline: +# 1. Evolve test binaries on a candidate branch +# 2. Compare against baseline files +# +# Usage: +# ./validate_binaries.sh [baseline_branch] [metallicities] +# [--loose] [--rtol VALUE] [--atol VALUE] [--skip-evolve] +# +# Positional arguments: +# candidate_branch Branch or tag to validate (required) +# baseline_branch Branch or tag to compare against (default: main) +# metallicities Space-separated list of Z values, quoted +# (default: "2 1 0.45 0.2 0.1 0.01 0.001 0.0001") +# +# Tolerance flags (passed through to compare_runs.py): +# --loose Use relaxed floating-point tolerances +# (rtol=1e-12, atol=1e-15 unless overridden) +# --rtol VALUE Set explicit relative tolerance as per np.allclose +# --atol VALUE Set explicit absolute tolerance as per np.allclose +# +# --rtol and --atol can be combined with --loose (explicit values take +# precedence over the --loose defaults) or used on their own without --loose. +# +# Other flags: +# --skip-evolve Skip Step 1 (evolution) and compare existing outputs +# against baseline. Candidate files must already exist. +# +# Examples: +# ./validate_binaries.sh feature/new-SN # compare vs main, exact +# ./validate_binaries.sh feature/new-SN v2.1.0 # compare vs v2.1.0, exact +# ./validate_binaries.sh feature/new-SN main "1 0.45" # subset of metallicities +# ./validate_binaries.sh feature/new-SN --loose # relaxed tolerances +# ./validate_binaries.sh feature/new-SN main --rtol 1e-8 # custom rtol, default atol +# ./validate_binaries.sh feature/new-SN main "1 0.45" --loose --atol 1e-10 +# ./validate_binaries.sh feature/new-SN --skip-evolve # compare existing outputs only +# +# Prerequisites: +# Run generate_baseline.sh first to create baseline files. +# +# Output: +# outputs//comparison_Zsun.txt — per-metallicity comparison reports +# outputs//comparison_summary.txt — overall summary +# ============================================================================= + +set -euo pipefail + +# ── Parse arguments ─────────────────────────────────────────────────────── + +CANDIDATE_BRANCH=${1:?Usage: ./validate_binaries.sh [baseline_branch] [metallicities] [--loose] [--rtol VALUE] [--atol VALUE]} +BASELINE_BRANCH=${2:-main} +METALLICITIES=${3:-"2 1 0.45 0.2 0.1 0.01 0.001 0.0001"} +shift $(( $# < 3 ? $# : 3 )) + +LOOSE=false +RTOL="" +ATOL="" +SKIP_EVOLVE=false + +while [[ $# -gt 0 ]]; do + case "$1" in + --loose) + LOOSE=true + shift + ;; + --rtol) + RTOL="${2:?--rtol requires a value}" + shift 2 + ;; + --atol) + ATOL="${2:?--atol requires a value}" + shift 2 + ;; + --skip-evolve) + SKIP_EVOLVE=true + shift + ;; + *) + echo "ERROR: Unknown option: $1" >&2 + exit 1 + ;; + esac +done + +DEV_TOOLS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SCRIPT_DIR="${DEV_TOOLS_DIR}/script_data" +SRC_DIR="${SCRIPT_DIR}/src" +SAFE_CANDIDATE="${CANDIDATE_BRANCH//\//_}" +SAFE_BASELINE="${BASELINE_BRANCH//\//_}" + +BASELINE_DIR="$SCRIPT_DIR/baselines/${SAFE_BASELINE}" +BINARY_OUTPUT_DIR="$SCRIPT_DIR/output/binary_star_tests/${SAFE_CANDIDATE}" +SUMMARY_FILE="$BINARY_OUTPUT_DIR/comparison_summary.txt" + +# ── Build compare_runs.py flags ─────────────────────────────────────────── +COMPARE_FLAGS="" +if [ "$LOOSE" = "true" ]; then + COMPARE_FLAGS="$COMPARE_FLAGS --loose" +fi +if [ -n "$RTOL" ]; then + COMPARE_FLAGS="$COMPARE_FLAGS --rtol $RTOL" +fi +if [ -n "$ATOL" ]; then + COMPARE_FLAGS="$COMPARE_FLAGS --atol $ATOL" +fi + +# Build a human-readable tolerance label for the summary +if [ -n "$RTOL" ] || [ -n "$ATOL" ]; then + TOL_LABEL="rtol=${RTOL:-default}, atol=${ATOL:-default}" + if [ "$LOOSE" = "true" ]; then + TOL_LABEL="$TOL_LABEL (--loose)" + fi +elif [ "$LOOSE" = "true" ]; then + TOL_LABEL="--loose (rtol=1e-12, atol=1e-15)" +else + TOL_LABEL="EXACT (rtol=0, atol=0)" +fi + +echo "============================================================" +echo " POSYDON Binary Validation" +echo " Candidate: $CANDIDATE_BRANCH" +echo " Baseline: $BASELINE_BRANCH" +echo " Metallicities: $METALLICITIES" +echo "============================================================" + +# ── Verify baseline exists ──────────────────────────────────────────────── +if [ ! -d "$BASELINE_DIR" ]; then + echo "ERROR: Baseline directory not found: $BASELINE_DIR" >&2 + echo "Run generate_baseline.sh first:" >&2 + echo " ./generate_baseline.sh $BASELINE_BRANCH" >&2 + exit 1 +fi + +# Check that at least one baseline file exists +BASELINE_COUNT=0 +for Z in $METALLICITIES; do + if [ -f "$BASELINE_DIR/baseline_${Z}Zsun.h5" ]; then + BASELINE_COUNT=$((BASELINE_COUNT + 1)) + fi +done +if [ $BASELINE_COUNT -eq 0 ]; then + echo "ERROR: No baseline files found in $BASELINE_DIR for requested metallicities." >&2 + exit 1 +fi +echo " Found $BASELINE_COUNT baseline file(s)." + +# ── Step 1: Evolve binaries on candidate branch ────────────────────────── +if [ "$SKIP_EVOLVE" = true ]; then + echo "" + echo "Step 1: SKIPPED (--skip-evolve: using existing outputs in $OUTPUT_DIR)" + if [ ! -d "$OUTPUT_DIR" ]; then + echo "ERROR: No outputs found at $OUTPUT_DIR" >&2 + echo "Run evolve_binaries.sh first, or drop --skip-evolve to evolve from scratch." >&2 + exit 1 + fi +else + echo "" + echo "Step 1: Evolving binaries on candidate branch '$CANDIDATE_BRANCH'..." + "$DEV_TOOLS_DIR/run_test_suite.sh" "$CANDIDATE_BRANCH" "" "$METALLICITIES" +fi + +# ── Step 2: Compare each metallicity ───────────────────────────────────── +echo "" +echo "Step 2: Comparing results..." + +TOTAL=0 +PASS=0 +FAIL=0 +SKIP=0 + +# Initialize summary +mkdir -p "$BINARY_OUTPUT_DIR" +cat > "$SUMMARY_FILE" << EOF +POSYDON Binary Validation — Comparison Summary +================================================ +Candidate branch: $CANDIDATE_BRANCH +Baseline branch: $BASELINE_BRANCH +Tolerances: $TOL_LABEL +Date: $(date -u '+%Y-%m-%d %H:%M:%S UTC') +================================================ + +EOF + +for Z in $METALLICITIES; do + TOTAL=$((TOTAL + 1)) + + BASELINE_FILE="$BASELINE_DIR/baseline_${Z}Zsun.h5" + CANDIDATE_FILE="$BINARY_OUTPUT_DIR/candidate_${Z}Zsun.h5" + COMPARISON_FILE="$BINARY_OUTPUT_DIR/comparison_${Z}Zsun.txt" + + echo "" + echo "--- Z = ${Z} Zsun ---" + + if [ ! -f "$BASELINE_FILE" ]; then + echo " SKIP: No baseline file for Z=${Z}" + echo "Z = ${Z} Zsun: SKIPPED (no baseline)" >> "$SUMMARY_FILE" + SKIP=$((SKIP + 1)) + continue + fi + + if [ ! -f "$CANDIDATE_FILE" ]; then + echo " FAIL: No candidate file for Z=${Z}" + echo "Z = ${Z} Zsun: FAIL (no candidate output)" >> "$SUMMARY_FILE" + FAIL=$((FAIL + 1)) + continue + fi + + # $COMPARE_FLAGS is intentionally unquoted so it word-splits into + # separate arguments for compare_runs.py. + if python "$SRC_DIR/compare_runs.py" "$BASELINE_FILE" "$CANDIDATE_FILE" \ + $COMPARE_FLAGS \ + 2>&1 | tee "$COMPARISON_FILE"; then + echo " PASS: No differences" + echo "Z = ${Z} Zsun: PASS" >> "$SUMMARY_FILE" + PASS=$((PASS + 1)) + else + echo " DIFFERENCES DETECTED — see $COMPARISON_FILE" + echo "Z = ${Z} Zsun: DIFFERENCES DETECTED (see comparison_${Z}Zsun.txt)" >> "$SUMMARY_FILE" + FAIL=$((FAIL + 1)) + fi +done + +# ── Final Summary ───────────────────────────────────────────────────────── +cat >> "$SUMMARY_FILE" << EOF + +================================================ +TOTAL: $TOTAL | PASS: $PASS | FAIL: $FAIL | SKIP: $SKIP +EOF + +echo "" +echo "============================================================" +echo " Validation Summary" +echo " TOTAL: $TOTAL | PASS: $PASS | FAIL: $FAIL | SKIP: $SKIP" +echo " Full summary: $SUMMARY_FILE" +echo "============================================================" + +if [ $FAIL -gt 0 ]; then + exit 1 +fi +exit 0 diff --git a/docs/_source/components-overview/pop_syn/single_star.rst b/docs/_source/components-overview/pop_syn/single_star.rst index a6454e5b6f..60cef6e1e8 100644 --- a/docs/_source/components-overview/pop_syn/single_star.rst +++ b/docs/_source/components-overview/pop_syn/single_star.rst @@ -32,7 +32,7 @@ The star properties are defined as follows * - ``state`` - The state of the star, see state options. * - ``metallicity`` - - Fractional metal content (Z) of the star. + - Ratio to solar metallicity (Z/Z_sun, e.g., 1.0 for solar metallicity). * - ``mass`` - Stellar mass in M_sun. * - ``log_R`` @@ -103,20 +103,29 @@ The star properties are defined as follows Additional scalar properties are added during the evolution depending on which steps the star has undergone. These properties are not stored in the history. .. list-table:: Additional output - :header-rows: 1 - :widths: 50 150 + :header-rows: 1 + :widths: 50 150 * - Properties - Descriptions * - ``natal_kick_array`` - | The natal kick array for the star if it has undergone a SN. - | contains: - - * velocity - * theta - * phi - * mean anomaly - + | This has been replaced with the individual properties below. + | ``natal_kick_array`` contains: + + * velocity (km/s) + * azimuthal angle phi (radians) + * polar angle theta (radians) + * mean anomaly (radians) + + * - ``natal_kick_velocity`` + - The magnitude of the natal kick velocity in km/s. + * - ``natal_kick_phi`` + - The natal kick azimuthal angle phi in radians. + * - ``natal_kick_theta`` + - The natal kick polar angle theta in radians. + * - ``natal_kick_mean_anomaly`` + - The natal kick mean anomaly in radians. * - ``SN_type`` - The supernova type of the star. * - ``f_fb`` diff --git a/docs/_source/getting-started/installation-guide.rst b/docs/_source/getting-started/installation-guide.rst index fec6909bf6..6b9414ae6f 100644 --- a/docs/_source/getting-started/installation-guide.rst +++ b/docs/_source/getting-started/installation-guide.rst @@ -14,6 +14,18 @@ Installing POSYDON Anaconda (Recommended) ---------------------- +.. important:: + **Conda Version Requirements**: POSYDON requires a recent version of conda (version >= 23.1.0) with the libmamba solver for efficient dependency resolution. Older conda versions (especially those prior to v23.1.0) may take an extremely long time (hours or more) to resolve dependencies and may fail to complete installation. + + To check your conda version and solver configuration: + + .. code-block:: bash + + conda --version + conda config --show solver + + If you're using an older conda version or experiencing slow installation, please see the :ref:`troubleshooting guide ` for detailed instructions on updating conda or configuring the libmamba solver. + 1. **Install Anaconda** If you haven't already, download and install Anaconda from `Anaconda's official website `_. diff --git a/docs/_source/troubleshooting-faqs/installation-issues.rst b/docs/_source/troubleshooting-faqs/installation-issues.rst index aeed294d70..89abf4ac0a 100644 --- a/docs/_source/troubleshooting-faqs/installation-issues.rst +++ b/docs/_source/troubleshooting-faqs/installation-issues.rst @@ -5,20 +5,49 @@ Common Installation Issues From time to time, users might encounter issues during the installation of POSYDON. This page aims to address common installation problems and offer solutions. If your problem isn't covered here, please `report the issue `_ so we can assist you and possibly update this page for the benefit of others. -1. **Slow `conda` solving:** +1. **Slow `conda` solving or installation taking hours:** -`conda` can be very slow and sometimes gets stuck on "Verifying transaction" or "Executing transaction", especially when installing packages on a cluster. -It creates many small files, which can be difficult for HPC clusters to handle. -One way to speed up the installation is to use the `mamba` package manager, which is a drop-in replacement for `conda` but is much faster (`click here for more details `_). -Please proceed at your own discretion, as this has not been fully vetted. Alternatively, you can install the `libmamba` solver to speed up the solving process for new installations in a `conda` environment. -To install the `libmamba` solver, run the following command in your `conda` environment of choice or `base` `conda` environment: + **Conda Version Requirements**: POSYDON has a complex dependency tree, and older conda versions (especially those prior to v23.1.0) use very slow dependency solvers that can take hours or may never complete the installation process. Modern conda versions (>= 23.1.0) include the libmamba solver by default, which resolves dependencies efficiently in seconds to minutes. -```bash -conda install conda-libmamba-solver -``` + **Check your conda version and solver:** -This will install the `libmamba` solver, which is a drop-in replacement for the default `conda` solver. -This should speed up solving the environment and installing packages but is not guaranteed to work in all cases. + .. code-block:: bash + + conda --version + conda config --show solver + + **Solutions:** + + a. **Update conda** (Recommended): If you have administrative access or can install conda locally, we strongly recommend updating to the latest conda version (2025 or later), which includes the fast libmamba solver by default: + + .. code-block:: bash + + # Update conda in your base environment + conda update -n base conda + + # Or install a fresh conda distribution from https://www.anaconda.com/download + + b. **Install and configure the libmamba solver** (For existing conda installations): + + If you cannot update conda but have version 4.12 or later, you can install and configure the libmamba solver: + + .. code-block:: bash + + # Install the libmamba solver + conda install -n base conda-libmamba-solver + + # Set libmamba as the default solver + conda config --set solver libmamba + + # Verify the configuration + conda config --show solver + + c. **Use mamba** (Alternative): Another option is to use the `mamba` package manager, which is a drop-in replacement for `conda` but is much faster (`click here for more details `_). However, this approach has not been fully vetted with POSYDON. + + .. note:: + If you're on an HPC cluster with an old system-wide conda installation (e.g., conda 2021.11), you may need to install a recent conda version locally in your home directory rather than using the system version. + + `conda` can also be slow and sometimes gets stuck on "Verifying transaction" or "Executing transaction", especially when installing packages on a cluster, as it creates many small files which can be difficult for HPC clusters to handle. The libmamba solver helps with this issue as well. 2. **Failed Dependencies**: - **Description**: Sometimes, certain dependencies might fail to install or conflict with pre-existing ones. diff --git a/posydon/CLI/io.py b/posydon/CLI/io.py index afa4eda0fd..d9ec806f9d 100644 --- a/posydon/CLI/io.py +++ b/posydon/CLI/io.py @@ -165,6 +165,8 @@ def create_slurm_array(metallicity, walltime, account, mem_per_cpu, + max_concurrent_jobs, + exclude, path_to_posydon, path_to_posydon_data): '''Creates the slurm array script for population synthesis job arrays. @@ -208,11 +210,16 @@ def create_slurm_array(metallicity, "#SBATCH --mail-type=FAIL", f"#SBATCH --mail-user={email}" ]) + if exclude is not None: + optional_directives.append(f"#SBATCH --exclude={exclude}") optional_section = "\n".join(optional_directives) if optional_section: optional_section += "\n" + if max_concurrent_jobs is not None: + job_array_length = f"{job_array_length}%{max_concurrent_jobs}" + text_pre = textwrap.dedent(f'''\ #!/bin/bash #SBATCH --array=0-{job_array_length} @@ -322,6 +329,8 @@ def create_slurm_rescue(metallicity, walltime, account, mem_per_cpu, + max_concurrent_jobs, + exclude, path_to_posydon, path_to_posydon_data): '''Creates the slurm rescue script for resubmitting failed population synthesis jobs. @@ -369,11 +378,16 @@ def create_slurm_rescue(metallicity, "#SBATCH --mail-type=FAIL", f"#SBATCH --mail-user={email}" ]) + if exclude is not None: + optional_directives.append(f"#SBATCH --exclude={exclude}") optional_section = "\n".join(optional_directives) if optional_section: optional_section += "\n" + if max_concurrent_jobs is not None: + job_array_str = f"{job_array_str}%{max_concurrent_jobs}" + text_pre = textwrap.dedent(f'''\ #!/bin/bash #SBATCH --array={job_array_str} @@ -427,6 +441,7 @@ def create_slurm_scripts(metallicity, args): # pragma: no cover ''' create_slurm_array(metallicity, args.job_array, args.partition, args.email, args.walltime, args.account, args.mem_per_cpu, + args.max_concurrent_jobs, args.exclude, PATH_TO_POSYDON, os.path.dirname(PATH_TO_POSYDON_DATA)) @@ -508,12 +523,19 @@ def create_batch_rescue_script(args, batch_status): mem_per_cpu = None path_to_posydon = None path_to_posydon_data = None + max_concurrent_jobs = None + exclude = None for line in lines: if line.startswith('#SBATCH --array='): array_range = line.split('=')[1].strip() - if '-' in array_range: - start, end = map(int, array_range.split('-')) + if '%' in array_range: + tmp_array_range = array_range.split('%')[0] + max_concurrent_jobs = int(array_range.split('%')[1]) + else: + tmp_array_range = array_range + if '-' in tmp_array_range: + start, end = map(int, tmp_array_range.split('-')) job_array_length = end - start + 1 elif line.startswith("#SBATCH --time="): walltime = line.split('=')[1].strip() @@ -525,6 +547,8 @@ def create_batch_rescue_script(args, batch_status): account = line.split('=')[1].strip() elif line.startswith("#SBATCH --mail-user="): email = line.split('=')[1].strip() + elif line.startswith("#SBATCH --exclude="): + exclude = line.split('=')[1].strip() elif line.startswith("export PATH_TO_POSYDON="): path_to_posydon = line.split('=')[1].strip() elif line.startswith("export PATH_TO_POSYDON_DATA="): @@ -541,6 +565,10 @@ def create_batch_rescue_script(args, batch_status): account = args.account if args.email is not None: email = args.email + if args.max_concurrent_jobs is not None: + max_concurrent_jobs = args.max_concurrent_jobs + if args.exclude is not None: + exclude = args.exclude # Create the rescue script create_slurm_rescue( @@ -552,6 +580,8 @@ def create_batch_rescue_script(args, batch_status): walltime=walltime, account=account, mem_per_cpu=mem_per_cpu, + max_concurrent_jobs=max_concurrent_jobs, + exclude=exclude, path_to_posydon=path_to_posydon, path_to_posydon_data=path_to_posydon_data ) diff --git a/posydon/CLI/popsyn/check.py b/posydon/CLI/popsyn/check.py index d70c4417a3..519817f0ad 100644 --- a/posydon/CLI/popsyn/check.py +++ b/posydon/CLI/popsyn/check.py @@ -324,6 +324,9 @@ def get_expected_batch_count(run_folder, str_met): for line in f: if line.startswith('#SBATCH --array='): array_range = line.split('=')[1].strip() + # remove any job limit specifiers + if '%' in array_range: + array_range = array_range.split('%')[0] if '-' in array_range: start, end = map(int, array_range.split('-')) return end - start + 1 diff --git a/posydon/CLI/popsyn/setup.py b/posydon/CLI/popsyn/setup.py index 0c088f8a6e..f60887a5f0 100644 --- a/posydon/CLI/popsyn/setup.py +++ b/posydon/CLI/popsyn/setup.py @@ -14,6 +14,7 @@ from posydon.grids.SN_MODELS import get_SN_MODEL_NAME from posydon.popsyn.io import binarypop_kwargs_from_ini, simprop_kwargs_from_ini from posydon.utils.common_functions import convert_metallicity_to_string +from posydon.utils.posydonwarning import Pwarn def check_SN_MODEL_validity(ini_file, verbose_on_fail=True): @@ -86,6 +87,22 @@ def setup_popsyn_function(args): validate_ini_file(args.ini_file) synpop_params = binarypop_kwargs_from_ini(args.ini_file) + # warn if mass ratio q = M2/M1 could fall below 0.05 + if synpop_params['secondary_mass_scheme'] == 'flat_mass_ratio': + q_min_possible = (synpop_params['secondary_mass_min'] + / synpop_params['primary_mass_max']) + if q_min_possible < 0.05: + Pwarn( + f"With secondary_mass_min=" + f"{synpop_params['secondary_mass_min']} and " + f"primary_mass_max=" + f"{synpop_params['primary_mass_max']}, the mass " + f"ratio q=M2/M1 can be as low as {q_min_possible:.4f}. " + f"Some binaries with q<0.05 might fall outside the POSYDON " + f"default grids.", + "InappropriateValueWarning" + ) + metallicities = synpop_params['metallicities'] if synpop_params['number_of_binaries'] / args.job_array < 1: raise ValueError("The number of binaries is less than the job array" diff --git a/posydon/__init__.py b/posydon/__init__.py index bf13af499c..d3c2c0325a 100644 --- a/posydon/__init__.py +++ b/posydon/__init__.py @@ -1,6 +1,11 @@ -from ._version import get_versions +from importlib.metadata import PackageNotFoundError, version + +try: + __version__ = version("posydon") +except PackageNotFoundError: + # Package is not installed + __version__ = "unknown" -__version__ = get_versions()['version'] __author__ = "Tassos Fragos " __credits__ = [ "Emmanouil Zapartas ", @@ -19,5 +24,3 @@ "Ying Qin <", "Aaron Dotter ", ] - -del get_versions diff --git a/posydon/_version.py b/posydon/_version.py deleted file mode 100644 index 030e6a8b1b..0000000000 --- a/posydon/_version.py +++ /dev/null @@ -1,532 +0,0 @@ - -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) - -"""Git implementation of _version.py.""" - -__authors__ = [ - "Scott Coughlin ", - "Matthias Kruckow ", -] - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "v" - cfg.parentdir_prefix = "" - cfg.versionfile_source = "posydon/_version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Get decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date, rc = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root) - if rc != 0: - if verbose: - print("Retry 'git show'") - date, rc = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root) - if date is None: - raise NotThisMethod("'git show' failed") - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} diff --git a/posydon/binary_evol/CE/step_CEE.py b/posydon/binary_evol/CE/step_CEE.py index a4c899d1a5..dcbcd7f5c8 100644 --- a/posydon/binary_evol/CE/step_CEE.py +++ b/posydon/binary_evol/CE/step_CEE.py @@ -33,12 +33,10 @@ "Matthias Kruckow ", ] - import numpy as np import pandas as pd from posydon.binary_evol.binarystar import BINARYPROPERTIES -from posydon.binary_evol.DT.track_match import TrackMatcher from posydon.binary_evol.flow_chart import ( STAR_STATES_CO, STAR_STATES_H_RICH, @@ -58,68 +56,65 @@ from posydon.utils.constants import Zsun from posydon.utils.posydonwarning import Pwarn -MODEL = {"prescription": 'alpha-lambda', - "common_envelope_efficiency": 1.0, - "common_envelope_lambda_default": 0.5, - "common_envelope_option_for_lambda": 'lambda_from_grid_final_values', - "common_envelope_option_for_HG_star": "optimistic", - "common_envelope_alpha_thermal": 1.0, - "core_definition_H_fraction": 0.3, # with 0.01 no CE BBHs - "core_definition_He_fraction": 0.1, - "CEE_tolerance_err": 0.001, - "verbose": False, - "common_envelope_option_after_succ_CEE": 'two_phases_stableMT', - "mass_loss_during_CEE_merged": False, # If False, then no mass loss from this step for a merged star - # If True, then we remove mass according to the alpha-lambda prescription - # assuming a final separation where the inner core RLOF starts. - # "one_phase_variable_core_definition" for core_definition_H_fraction=0.01 - "metallicity": None, - "record_matching": False - } - - -# common_envelope_option_for_lambda: -# 1) 'default_lambda': using for lambda the constant value of -# common_envelope_lambda_default parameter -# 2) 'lambda_from_grid_final_values': using lambda parameter from MESA history -# which was calulated ni the same way as method (5) below -# 3) 'lambda_from_profile_gravitational': calculating the lambda parameter -# from the donor's profile by using the gravitational binding energy from the -# surface to the core (needing "mass", and "radius" as columns in the profile) -# 4) 'lambda_from_profile_gravitational_plus_internal': as above but taking -# into account a factor of common_envelope_alpha_thermal * internal energy too -# in the binding energy (needing also "energy" as column in the profile) -# 5) 'lambda_from_profile_gravitational_plus_internal_minus_recombination': -# as above but not taking into account the recombination energy in the internal -# energy (needing also "y_mass_fraction_He", "x_mass_fraction_H", -# "neutral_fraction_H", "neutral_fraction_He", and "avg_charge_He" as column -# in the profile) -# the mass fraction of an element which is used as threshold to define a core, class StepCEE(object): - """Compute supernova final remnant mass, fallback fraction & stellar state. - - This consider the nearest neighboor of the He core mass of the star, - previous to the collapse. Considering a set of data for which the He core - mass of the compact object projenitos previous the collapse, the final - remnant mass and final stellar state of the compact object is known. + """Handle common envelope evolution (CEE) for binary systems. + + This class computes the outcome of a common envelope phase for binary + systems containing a giant star. It calculates how much the orbit must + shrink in order to expel the envelope using the alpha-prescription. + If at the required post-CEE separation one of the stars fills its Roche lobe, + the system is considered a merger. Otherwise, the envelope is lost, + leaving a binary system with the core of the donor star (which initiates + the unstable CEE) and the core of the companion. + + If stellar profiles are available, the lambda parameter for the donor + can be calculated directly from the profile. Otherwise, default values + are used. The evolution is computed using a specified prescription + (e.g., alpha-lambda) which determines the final state of the binary + based on energy budget considerations. Parameters ---------- verbose : bool - If True, the messages will be prited in the console. + If True, the messages will be printed in the console. Keyword Arguments ----------------- prescription : str - Prescription to use for computing the prediction of common enevelope + Prescription to use for computing the prediction of common envelope evolution. Available options are: - * 'alpha-lambda' : Considers the the alpha-lambda prescription + * 'alpha-lambda' : Considers the alpha-lambda prescription described in [1]_ and [2]_ to predict the outcome of the common envelope evolution. If the profile of the donor star is available then it is used to compute the value of lambda. + common_envelope_option_for_lambda : str + Method for calculating the lambda parameter. Available options are: + + 1. 'default_lambda' : Use a constant value from the + `common_envelope_lambda_default` parameter. + + 2. 'lambda_from_grid_final_values' : Use the lambda parameter from + MESA history, calculated using the same method as option 5 below. + + 3. 'lambda_from_profile_gravitational' : Calculate lambda from the + donor's profile using the gravitational binding energy from the + surface to the core (requires "mass" and "radius" columns in the + profile). + + 4. 'lambda_from_profile_gravitational_plus_internal' : As above, + but also accounting for a factor of `common_envelope_alpha_thermal` + times the internal energy in the binding energy (requires also + "energy" column in the profile). + + 5. 'lambda_from_profile_gravitational_plus_internal_minus_recombination' : + As above, but excluding the recombination energy from the internal + energy calculation (requires also "y_mass_fraction_He", + "x_mass_fraction_H", "neutral_fraction_H", "neutral_fraction_He", + and "avg_charge_He" columns in the profile). + References ---------- .. [1] Webbink, R. F. (1984). Double white dwarfs as progenitors of R @@ -129,79 +124,42 @@ class StepCEE(object): .. [2] De Kool, M. (1990). Common envelope evolution and double cores of planetary nebulae. The Astrophysical Journal, 358, 189-195. """ - - def __init__( - self, prescription=MODEL['prescription'], - common_envelope_efficiency=MODEL['common_envelope_efficiency'], - common_envelope_lambda_default=MODEL[ - 'common_envelope_lambda_default'], - common_envelope_option_for_lambda=MODEL[ - 'common_envelope_option_for_lambda'], - common_envelope_option_for_HG_star=MODEL[ - 'common_envelope_option_for_HG_star'], - common_envelope_option_after_succ_CEE=MODEL[ - 'common_envelope_option_after_succ_CEE'], - common_envelope_alpha_thermal=MODEL[ - 'common_envelope_alpha_thermal'], - core_definition_H_fraction=MODEL[ - 'core_definition_H_fraction'], - core_definition_He_fraction=MODEL[ - 'core_definition_He_fraction'], - CEE_tolerance_err=MODEL['CEE_tolerance_err'], - mass_loss_during_CEE_merged=MODEL['mass_loss_during_CEE_merged'], - verbose=MODEL['verbose'], - metallicity = MODEL['metallicity'], - record_matching = MODEL['record_matching'], - **kwargs): + DEFAULT_KWARGS = {"prescription": 'alpha-lambda', + "common_envelope_efficiency": 1.0, + "common_envelope_lambda_default": 0.5, + "common_envelope_option_for_lambda": 'lambda_from_grid_final_values', + "common_envelope_option_for_HG_star": "optimistic", + "common_envelope_alpha_thermal": 1.0, + "core_definition_H_fraction": 0.3, # with 0.01 no CE BBHs + "core_definition_He_fraction": 0.1, + "CEE_tolerance_err": 0.001, + "verbose": False, + "common_envelope_option_after_succ_CEE": 'two_phases_stableMT', + "mass_loss_during_CEE_merged": False, + # If False, then no mass loss from this step for a merged star + # If True, then we remove mass according to the alpha-lambda prescription + # assuming a final separation where the inner core RLOF starts. + # "one_phase_variable_core_definition" for core_definition_H_fraction=0.01 + "metallicity": None, + "track_matcher": None + } + + + def __init__(self, **kwargs): """Initialize a StepCEE instance.""" # read kwargs to initialize the class if kwargs: for key in kwargs: - if key not in MODEL: + if key not in self.DEFAULT_KWARGS: raise ValueError(key + " is not a valid parameter name!") - for varname in MODEL: - default_value = MODEL[varname] + for varname in self.DEFAULT_KWARGS: + default_value = self.DEFAULT_KWARGS[varname] setattr(self, varname, kwargs.get(varname, default_value)) else: - self.prescription = prescription - self.common_envelope_efficiency = common_envelope_efficiency - self.common_envelope_lambda_default = \ - common_envelope_lambda_default - self.common_envelope_option_for_lambda = \ - common_envelope_option_for_lambda - self.common_envelope_option_for_HG_star = \ - common_envelope_option_for_HG_star - self.common_envelope_alpha_thermal = common_envelope_alpha_thermal - self.core_definition_H_fraction = core_definition_H_fraction - self.core_definition_He_fraction = core_definition_He_fraction - self.CEE_tolerance_err = CEE_tolerance_err - self.common_envelope_option_after_succ_CEE = \ - common_envelope_option_after_succ_CEE - self.mass_loss_during_CEE_merged = mass_loss_during_CEE_merged - self.metallicity = metallicity - self.record_matching = record_matching - self.verbose = verbose - self.path_to_posydon = PATH_TO_POSYDON - - - list_for_matching_HMS = [ - ["mass", "center_h1", "he_core_mass"], - [20.0, 1.0, 10.0], - ["log_min_max", "min_max", "min_max"], - [0.1, 300], [0.0, None] - ] - self.track_matcher = TrackMatcher(grid_name_Hrich = None, - grid_name_strippedHe = None, - path=PATH_TO_POSYDON_DATA, - metallicity = self.metallicity, - matching_method = "minimize", - matching_tolerance=1e-2, - matching_tolerance_hard=1e-1, - list_for_matching_HMS = list_for_matching_HMS, - list_for_matching_HeStar = None, - list_for_matching_postMS = None, - record_matching = self.record_matching, - verbose = self.verbose) + for varname in self.DEFAULT_KWARGS: + default_value = self.DEFAULT_KWARGS[varname] + setattr(self, varname, default_value) + def __call__(self, binary): """Perform the CEE step for a BinaryStar object.""" # Determine which star is the donor and which is the companion @@ -825,11 +783,12 @@ def CEE_two_phases_windloss(self, donor, mc1_i, rc1_i, donor_type, def CEE_simple_alpha_prescription( self, binary, donor, comp_star, lambda1_CE, mc1_i, rc1_i, donor_type, lambda2_CE, mc2_i, rc2_i, comp_type, double_CE=False, - verbose=False, common_envelope_option_after_succ_CEE=MODEL[ - 'common_envelope_option_after_succ_CEE'], - core_definition_H_fraction=MODEL['core_definition_H_fraction'], - core_definition_He_fraction=MODEL['core_definition_He_fraction'], - mass_loss_during_CEE_merged=MODEL['mass_loss_during_CEE_merged']): + verbose=False, + common_envelope_option_after_succ_CEE=\ + DEFAULT_KWARGS['common_envelope_option_after_succ_CEE'], + core_definition_H_fraction=DEFAULT_KWARGS['core_definition_H_fraction'], + core_definition_He_fraction=DEFAULT_KWARGS['core_definition_He_fraction'], + mass_loss_during_CEE_merged=DEFAULT_KWARGS['mass_loss_during_CEE_merged']): """Apply the alpha-lambda common-envelope prescription. It uses energetics to calculate the shrinakge of the orbit diff --git a/posydon/binary_evol/DT/double_CO.py b/posydon/binary_evol/DT/double_CO.py index 615185b4ee..2aed9ff865 100644 --- a/posydon/binary_evol/DT/double_CO.py +++ b/posydon/binary_evol/DT/double_CO.py @@ -228,13 +228,6 @@ def __init__(self, **kwargs): super().__init__(**kwargs) - # For DCO system, only gravitational radiation is considered - self.do_magnetic_braking = False - self.do_tides = False - self.do_wind_loss = False - self.do_stellar_evolution_and_spin_from_winds = False - self.do_gravitational_radiation = True - def set_stars(self, primary, secondary, t0=0.0): diff --git a/posydon/binary_evol/DT/step_detached.py b/posydon/binary_evol/DT/step_detached.py index de47199265..73907fddf7 100644 --- a/posydon/binary_evol/DT/step_detached.py +++ b/posydon/binary_evol/DT/step_detached.py @@ -25,10 +25,6 @@ from posydon.binary_evol.DT.gravitational_radiation.default_gravrad import ( default_gravrad, ) -from posydon.binary_evol.DT.key_library import ( - DEFAULT_TRANSLATED_KEYS, - DEFAULT_TRANSLATION, -) from posydon.binary_evol.DT.magnetic_braking.prescriptions import ( CARB_braking, G18_braking, @@ -36,7 +32,6 @@ RVJ83_braking, ) from posydon.binary_evol.DT.tides.default_tides import default_tides -from posydon.binary_evol.DT.track_match import TrackMatcher from posydon.binary_evol.DT.winds.default_winds import ( default_sep_from_winds, default_spin_from_winds, @@ -59,6 +54,10 @@ set_binary_to_failed, zero_negative_values, ) +from posydon.utils.key_library import ( + DEFAULT_TRANSLATED_KEYS, + DEFAULT_TRANSLATION, +) from posydon.utils.posydonerror import ( ClassificationError, FlowError, @@ -86,73 +85,6 @@ class detached_step: Parameters ---------- - path : str - Path to the directory that contains POSYDON data HDF5 files. Defaults - to the PATH_TO_POSYDON_DATA environment variable. Used for track - matching. - - metallicity : float - The metallicity of the grid. This should be one of the eight - supported metallicities: - - [2e+00, 1e+00, 4.5e-01, 2e-01, 1e-01, 1e-02, 1e-03, 1e-04] - - and this will be converted to a corresponding string (e.g., - 1e+00 --> "1e+00_Zsun"). Used for track matching. - - matching_method : str - Method to find the best match between a star from a previous step and a - point in a single star evolution track. Options: - - "root": Tries to find a root of two matching quantities. It is - possible to not find one, causing the evolution to fail. - - "minimize": Minimizes the sum of squares of differences of - various quantities between the previous evolution step and - a stellar evolution track. - - Used for track matching. - - grid_name_Hrich : str - Name of the single star H-rich grid h5 file, - including its parent directory. This is set to - (for example): - - grid_name_Hrich = 'single_HMS/1e+00_Zsun.h5' - - by default if not specified. Used for track matching. - - grid_name_strippedHe : str - Name of the single star He-rich grid h5 file. This is - set to (for example): - - grid_name_strippedHe = 'single_HeMS/1e+00_Zsun.h5' - - by default if not specified. Used for track matching. - - list_for_matching_HMS : list - A list of mixed type that specifies properties of the matching - process for HMS stars. Used for track matching. - - list_for_matching_postMS : list - A list of mixed type that specifies properties of the matching - process for postMS stars. Used for track matching. - - list_for_matching_HeStar : list - A list of mixed type that specifies properties of the matching - process for He stars. Used for track matching. - - record_matching : bool - Whether properties of the matched star(s) should be recorded in the - binary evolution history. Used for track matching. - - Attributes - ---------- - KEYS : list[str] - Contains keywords corresponding to MESA data column names - which are used to extract quantities from the single star - evolution grids. - dt : float The timestep size, in years, to be appended to the history of the binary. None means only the final step. Note: do not select very @@ -202,10 +134,6 @@ class detached_step: evolved until RLO commences once again, but without changing the orbit. - translate : dict - Dictionary containing data column name (key) translations between - POSYDON h5 file PSyGrid data names (items) and MESA data names (keys). - track_matcher : TrackMatcher object The TrackMatcher object performs functions related to matching binary stellar evolution components to single star evolution models. @@ -213,81 +141,77 @@ class detached_step: verbose : bool True if we want to print stuff. + Attributes + ---------- + KEYS : list[str] + Contains keywords corresponding to MESA data column names + which are used to extract quantities from the single star + evolution grids. + + translate : dict + Dictionary containing data column name (key) translations between + POSYDON h5 file PSyGrid data names (items) and MESA data names (keys). + + evo : detached_evolution + Handler object responsible for performing the detached binary + evolution. + + evo_kwargs : dict + Keyword arguments used to initialize ``detached_evolution``. + """ - def __init__( - self, - dt=None, - n_o_steps_history=None, - do_wind_loss=True, - do_tides=True, - do_gravitational_radiation=True, - do_magnetic_braking=True, - magnetic_braking_mode="RVJ83", - do_stellar_evolution_and_spin_from_winds=True, - RLO_orbit_at_orbit_with_same_am=False, - record_matching=False, - verbose=False, - grid_name_Hrich=None, - grid_name_strippedHe=None, - metallicity=None, - path=PATH_TO_POSYDON_DATA, - matching_method="minimize", - matching_tolerance=1e-2, - matching_tolerance_hard=1e-1, - list_for_matching_HMS=None, - list_for_matching_postMS=None, - list_for_matching_HeStar=None - ): - """Initialize the step. See class documentation for details.""" - self.dt = dt - self.n_o_steps_history = n_o_steps_history - self.do_wind_loss = do_wind_loss - self.do_tides = do_tides - self.do_gravitational_radiation = do_gravitational_radiation - self.do_magnetic_braking = do_magnetic_braking - self.magnetic_braking_mode = magnetic_braking_mode - self.do_stellar_evolution_and_spin_from_winds = ( - do_stellar_evolution_and_spin_from_winds - ) - self.RLO_orbit_at_orbit_with_same_am = RLO_orbit_at_orbit_with_same_am - self.verbose = verbose + # settings in .ini will override + DEFAULT_KWARGS = {"dt": None, + "n_o_steps_history": None, + "do_wind_loss": True, + "do_tides": True, + "do_gravitational_radiation": True, + "do_magnetic_braking": True, + "magnetic_braking_mode": "RVJ83", + "do_stellar_evolution_and_spin_from_winds": True, + "RLO_orbit_at_orbit_with_same_am": False, + "metallicity": None, + "track_matcher": None, + "RNG": np.random.default_rng(), + "verbose": False} + + def __init__(self, **kwargs): - if self.verbose: - print( - dt, - n_o_steps_history, - matching_method, - do_wind_loss, - do_tides, - do_gravitational_radiation, - do_magnetic_braking, - magnetic_braking_mode, - do_stellar_evolution_and_spin_from_winds) + """Initialize the step. See class documentation for details.""" + # read kwargs to initialize the class + if kwargs: + for key in kwargs: + if key not in self.DEFAULT_KWARGS: + raise ValueError(key + " is not a valid parameter name!") + for varname in self.DEFAULT_KWARGS: + default_value = self.DEFAULT_KWARGS[varname] + setattr(self, varname, kwargs.get(varname, default_value)) + else: + for varname in self.DEFAULT_KWARGS: + default_value = self.DEFAULT_KWARGS[varname] + setattr(self, varname, default_value) self.translate = DEFAULT_TRANSLATION - # these are the KEYS read from POSYDON h5 grid files (after translating # them to the appropriate columns) self.KEYS = DEFAULT_TRANSLATED_KEYS - # creating a track matching object - self.track_matcher = TrackMatcher(grid_name_Hrich = grid_name_Hrich, - grid_name_strippedHe = grid_name_strippedHe, - path=path, metallicity = metallicity, - matching_method = matching_method, - matching_tolerance=matching_tolerance, - matching_tolerance_hard=matching_tolerance_hard, - list_for_matching_HMS = list_for_matching_HMS, - list_for_matching_HeStar = list_for_matching_HeStar, - list_for_matching_postMS = list_for_matching_postMS, - record_matching = record_matching, - verbose = self.verbose) - # create evolution handler object self.init_evo_kwargs() self.evo = detached_evolution(**self.evo_kwargs) + if self.verbose: + print(self.dt, + self.n_o_steps_history, + self.track_matcher.matching_method, + self.do_wind_loss, + self.do_tides, + self.do_gravitational_radiation, + self.do_magnetic_braking, + self.magnetic_braking_mode, + self.do_stellar_evolution_and_spin_from_winds) + return def init_evo_kwargs(self): @@ -301,12 +225,12 @@ def init_evo_kwargs(self): "magnetic_braking_mode": self.magnetic_braking_mode, "do_stellar_evolution_and_spin_from_winds": self.do_stellar_evolution_and_spin_from_winds, "do_gravitational_radiation": self.do_gravitational_radiation, - "verbose": self.verbose, - } + "verbose": self.verbose} def __repr__(self): """Return the name of evolution step.""" - return "Detached Step." + return "detached_step:\n" + \ + "\n".join([f"{key} = {getattr(self, key)}" for key in self.__dict__]) def __call__(self, binary): """ @@ -429,7 +353,7 @@ def __call__(self, binary): elif primary.co: mdot_acc = np.atleast_1d(bondi_hoyle( binary, primary, secondary, slice(-len(t), None), - wind_disk_criteria=True, scheme='Kudritzki+2000')) + wind_disk_criteria=True, RNG=self.RNG, scheme='Kudritzki+2000')) primary.lg_mdot = np.log10(mdot_acc.item(-1)) primary.lg_mdot_history[len(primary.lg_mdot_history) - len(t) + 1:] = np.log10(mdot_acc[:-1]) else: @@ -850,6 +774,102 @@ def update_co_stars(self, t, primary, secondary): getattr(obj, key + "_history").extend(history) class detached_evolution: + """ + ODE system describing the evolution of a detached binary. + + This class defines the differential equations governing the orbital + evolution and stellar spin evolution of a detached binary system. + It is designed to be passed directly to ``scipy.integrate.solve_ivp``, + with ``__call__`` returning the derivatives of the system state. + + The evolution can include contributions from several physical processes: + + - Stellar wind mass loss + - Tidal interactions + - Magnetic braking + - Gravitational wave radiation + - Spin evolution from stellar winds and structural changes + + The stellar properties required for these calculations are obtained + from interpolated single-star evolution tracks associated with the + ``SingleStar`` objects (binary components) and their + PChipInterpolator2 objects. + + Parameters + ---------- + primary : SingleStar, optional + Primary star of the binary (typically the more evolved star). + These must have an ``interp1d`` interpolator to return stellar + properties as a function of time. + + secondary : SingleStar, optional + Secondary star of the binary. + + do_wind_loss : bool, optional + If True, include orbital evolution due to stellar wind mass loss. + + do_tides : bool, optional + If True, include tidal interactions affecting orbital separation, + eccentricity, and stellar spin. + + do_magnetic_braking : bool, optional + If True, include stellar spin evolution due to magnetic braking. + + magnetic_braking_mode : {"RVJ83", "M15", "G18", "CARB"}, optional + Magnetic braking prescription: + + - RVJ83 — Rappaport, Verbunt & Joss (1983) + - M15 — Matt et al. (2015) + - G18 — Garraffo et al. (2018) + - CARB — Van & Ivanova (2019) + + do_stellar_evolution_and_spin_from_winds : bool, optional + If True, include spin evolution caused by stellar structural + evolution and angular momentum loss from winds. + + do_gravitational_radiation : bool, optional + If True, include orbital evolution from gravitational wave emission. + + verbose : bool, optional + If True, print diagnostic information during the integration. + + Attributes + ---------- + primary : SingleStar + Primary star used in the evolution. + + secondary : SingleStar + Secondary star used in the evolution. + + a : float + Current orbital separation (solar radii). + + e : float + Current orbital eccentricity. + + phys_keys : list of str + Names of stellar quantities tracked from the interpolated stellar + evolution models. + + t : float + Current system age during integration. + + Notes + ----- + The system state vector ``y`` evolved by ``solve_ivp`` is defined as:: + + y = [a, e, omega_secondary, omega_primary] + + where + + - ``a`` is the orbital separation (R☉) + - ``e`` is the orbital eccentricity + - ``omega_secondary`` is the spin angular velocity of the secondary (rad/yr) + - ``omega_primary`` is the spin angular velocity of the primary (rad/yr) + + Event functions defined in this class detect important transitions such + as Roche-lobe overflow or reaching the end of a stellar evolution track. + """ def __init__(self, primary=None, secondary=None, do_wind_loss=True, @@ -1159,7 +1179,7 @@ def update_props(self, t, y): y[3] = np.max([y[3], 0]) self.primary.latest["omega"] = y[3] - # store current delta(t)/time + # store current time self.t = t def __call__(self, t, y): diff --git a/posydon/binary_evol/DT/step_disrupted.py b/posydon/binary_evol/DT/step_disrupted.py index c3096b8d0a..99d51e8710 100644 --- a/posydon/binary_evol/DT/step_disrupted.py +++ b/posydon/binary_evol/DT/step_disrupted.py @@ -26,17 +26,9 @@ class DisruptedStep(IsolatedStep): Prepare a runaway star to do an an isolated_step) """ - def __init__(self, - grid_name_Hrich=None, - grid_name_strippedHe=None, - path=PATH_TO_POSYDON_DATA, - *args, **kwargs): - - super().__init__( - grid_name_Hrich=grid_name_Hrich, - grid_name_strippedHe=grid_name_strippedHe, - *args, - **kwargs) + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) def __call__(self,binary): diff --git a/posydon/binary_evol/DT/step_initially_single.py b/posydon/binary_evol/DT/step_initially_single.py index 564f745479..3b37e6cac2 100644 --- a/posydon/binary_evol/DT/step_initially_single.py +++ b/posydon/binary_evol/DT/step_initially_single.py @@ -26,17 +26,9 @@ class InitiallySingleStep(IsolatedStep): Prepare a runaway star to do an an isolated_step) """ - def __init__(self, - grid_name_Hrich=None, - grid_name_strippedHe=None, - path=PATH_TO_POSYDON_DATA, - *args, **kwargs): - - super().__init__( - grid_name_Hrich=grid_name_Hrich, - grid_name_strippedHe=grid_name_strippedHe, - *args, - **kwargs) + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) def __call__(self,binary): diff --git a/posydon/binary_evol/DT/step_isolated.py b/posydon/binary_evol/DT/step_isolated.py index ecb67f7fe9..2afc19717f 100644 --- a/posydon/binary_evol/DT/step_isolated.py +++ b/posydon/binary_evol/DT/step_isolated.py @@ -24,29 +24,9 @@ class IsolatedStep(detached_step): """ - def __init__(self, - grid_name_Hrich=None, - grid_name_strippedHe=None, - path=PATH_TO_POSYDON_DATA, - #dt=None, - #n_o_steps_history=None, - do_wind_loss=False, - do_tides=False, - do_gravitational_radiation=False, - do_magnetic_braking=False, - *args, **kwargs): - super().__init__( - grid_name_Hrich=grid_name_Hrich, - grid_name_strippedHe=grid_name_strippedHe, - path=path, - #dt=dt, - #n_o_steps_history=n_o_steps_history, - do_wind_loss=do_wind_loss, - do_tides=do_tides, - do_gravitational_radiation=do_gravitational_radiation, - do_magnetic_braking=do_magnetic_braking, - *args, - **kwargs) + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) diff --git a/posydon/binary_evol/DT/step_merged.py b/posydon/binary_evol/DT/step_merged.py index d3c32bee3a..55e697daf7 100644 --- a/posydon/binary_evol/DT/step_merged.py +++ b/posydon/binary_evol/DT/step_merged.py @@ -21,7 +21,6 @@ STARPROPERTIES, convert_star_to_massless_remnant, ) -from posydon.config import PATH_TO_POSYDON_DATA from posydon.utils.common_functions import check_state_of_star from posydon.utils.posydonerror import ModelError from posydon.utils.posydonwarning import Pwarn @@ -41,50 +40,16 @@ class MergedStep(IsolatedStep): Prepare a merging star to do an an IsolatedStep """ - def __init__( - self, - grid_name_Hrich=None, - grid_name_strippedHe=None, - path=PATH_TO_POSYDON_DATA, - merger_critical_rot = 0.4, - rel_mass_lost_HMS_HMS = 0.1, - list_for_matching_HMS = [ - ["mass", "center_h1", "he_core_mass"], - [20.0, 1.0, 10.0], - ["log_min_max", "min_max", "min_max"], - #[m_min_H, m_max_H], [0, None] - [None, None], [0, None] - ], - list_for_matching_postMS = [ - ["mass", "center_he4", "he_core_mass"], - [20.0, 1.0, 10.0], - ["log_min_max", "min_max", "min_max"], - #[m_min_H, m_max_H], [0, None] - [None, None], [0, None] - ], - list_for_matching_HeStar = [ - ["he_core_mass", "center_he4"], - [10.0, 1.0], - ["min_max" , "min_max"], - #[[m_min_He, m_max_He], [0, None]], - [None, None], [0, None] - ], - *args, - **kwargs - ): + def __init__(self, + merger_critical_rot = 0.4, + rel_mass_lost_HMS_HMS = 0.1, + *args, + **kwargs): self.merger_critical_rot = merger_critical_rot self.rel_mass_lost_HMS_HMS = rel_mass_lost_HMS_HMS - super().__init__( - grid_name_Hrich=grid_name_Hrich, - grid_name_strippedHe=grid_name_strippedHe, - list_for_matching_HMS = list_for_matching_HMS, - list_for_matching_postMS = list_for_matching_postMS, - list_for_matching_HeStar = list_for_matching_HeStar, - *args, - **kwargs) - + super().__init__(*args, **kwargs) def __call__(self,binary): diff --git a/posydon/binary_evol/MESA/step_mesa.py b/posydon/binary_evol/MESA/step_mesa.py index b6fb65343c..1b324c3b84 100644 --- a/posydon/binary_evol/MESA/step_mesa.py +++ b/posydon/binary_evol/MESA/step_mesa.py @@ -126,22 +126,22 @@ class MesaGridStep: """Superclass for steps using the POSYDON grids.""" - def __init__( - self, - metallicity, - grid_name, - path=PATH_TO_POSYDON_DATA, - interpolation_path=None, - interpolation_filename=None, - interpolation_method="linear3c_kNN", - save_initial_conditions=True, - track_interpolation=False, - stop_method='stop_at_max_time', # "stop_at_end", - stop_star="star_1", - stop_var_name=None, - stop_value=None, - stop_interpolate=True, - verbose=False): + DEFAULT_KWARGS = {'metallicity': None, + 'grid_path': None, + 'interpolation_path': None, + 'interpolation_filename': None, + 'interpolation_method': 'nearest_neighbour', + 'save_initial_conditions': True, + 'track_interpolation': False, + 'stop_method': 'stop_at_max_time', # "stop_at_end" + 'stop_star': 'star_1', + 'stop_var_name': None, + 'stop_value': None, + 'stop_interpolate': True, + 'RNG': np.random.default_rng(), + 'verbose': False} + + def __init__(self, **kwargs): """Evolve a binary object given a MESA grid or interpolation object. Parameters @@ -179,26 +179,34 @@ def __init__( stop_value """ - # class variable - self.path = path - self.interpolation_method = interpolation_method - self.save_initial_conditions = save_initial_conditions - self.track_interpolation = track_interpolation - self.stop_method = stop_method - self.verbose = verbose + # read kwargs to initialize the class + if kwargs: + for key in kwargs: + if key not in self.DEFAULT_KWARGS: + raise ValueError(key + " is not a valid parameter name!") + for varname in self.DEFAULT_KWARGS: + default_value = self.DEFAULT_KWARGS[varname] + setattr(self, varname, kwargs.get(varname, default_value)) + else: + for varname in self.DEFAULT_KWARGS: + default_value = self.DEFAULT_KWARGS[varname] + setattr(self, varname, default_value) if (self.track_interpolation and self.interpolation_method != 'nearest_neighbour'): raise ValueError('Track interpolation is currently supported only ' 'by the nearest neighbour interpolation method!') + z_str = convert_metallicity_to_string(self.metallicity) + self.grid_name = os.path.join(self.grid_path, f"{z_str}_Zsun.h5") + # we load NN any time stop_at_max_time requested - regardless # of interp method if (self.stop_method == 'stop_at_max_time' or self.interpolation_method == 'nearest_neighbour'): - self.load_psyTrackInterp(grid_name) + self.load_psyTrackInterp() - grid_name = grid_name.replace('_%d', '') + self.grid_name = self.grid_name.replace('_%d', '') # Check interpolation method provided self.supported_interp_methods = ['linear_kNN', 'linear3c_kNN', @@ -206,20 +214,19 @@ def __init__( if self.interpolation_method in self.supported_interp_methods: # Set the interpolation path - if interpolation_path is None: - interpolation_path = os.path.join(self.path, - os.path.split(grid_name)[0], + if self.interpolation_path is None: + self.interpolation_path = os.path.join(self.grid_path, 'interpolators/%s' % self.interpolation_method) # Set the interpolation filename - if interpolation_filename is None: - interpolation_filename = os.path.join(interpolation_path, - os.path.split(grid_name)[1].replace('h5', 'pkl')) + if self.interpolation_filename is None: + self.interpolation_filename = os.path.join(self.interpolation_path, + os.path.basename(self.grid_name).replace('h5', 'pkl')) else: - interpolation_filename = os.path.join(interpolation_path, - interpolation_filename) + self.interpolation_filename = os.path.join(self.interpolation_path, + self.interpolation_filename) - self.load_Interp(interpolation_filename) + self.load_Interp(self.interpolation_filename) if (not (hasattr(self, '_psyTrackInterp') or hasattr(self, '_Interp'))): @@ -233,10 +240,6 @@ def __init__( # we drop the history self.flush_history = False self.flush_entries = None - self.stop_star = stop_star - self.stop_var_name = stop_var_name - self.stop_value = stop_value - self.stop_interpolate = stop_interpolate self._find_boundaries() def _find_boundaries(self): @@ -256,17 +259,16 @@ def initial_values_min_max(parameter_name): self.m2_min, self.m2_max = initial_values_min_max('star_2_mass') self.p_min, self.p_max = initial_values_min_max('period_days') - def load_psyTrackInterp(self, grid_name): + def load_psyTrackInterp(self): """Load the interpolator that has been trained on the grid.""" # Check if interpolation files exist - filename = os.path.join(self.path,grid_name) - if not (os.path.exists(filename.replace('%d','0')) or - os.path.exists(filename.replace('_%d',''))): + if not (os.path.exists(self.grid_name.replace('%d','0')) or + os.path.exists(self.grid_name.replace('_%d',''))): data_download() if self.verbose: - print("loading psyTrackInterp: {}".format(filename)) - self._psyTrackInterp = psyTrackInterp(filename, + print("loading psyTrackInterp: {}".format(self.grid_name)) + self._psyTrackInterp = psyTrackInterp(self.grid_name, interp_in_q=self.interp_in_q, verbose=self.verbose) self._psyTrackInterp.train() @@ -696,8 +698,8 @@ def update_properties_NN(self, star_1_CO=False, star_2_CO=False, setattr(binary, 'event', binary_event) setattr(binary, 'mass_transfer_case', MT_case) - culmulative_mt_case = self.termination_flags[1] - setattr(self.binary, f'culmulative_mt_case_{self.grid_type}', culmulative_mt_case) + cumulative_mt_case = self.termination_flags[1] + setattr(self.binary, f'cumulative_mt_case_{self.grid_type}', cumulative_mt_case) setattr(self.binary, f'interp_class_{self.grid_type}', interpolation_class) mt_history = self.termination_flags[2] # mass transfer history (TF12 plot label) setattr(self.binary, f'mt_history_{self.grid_type}', mt_history) @@ -774,7 +776,7 @@ def update_properties_NN(self, star_1_CO=False, star_2_CO=False, key_bh = POSYDON_TO_MESA['star']['lg_mdot']+'_%d' % (k_bh+1) tmp_lg_mdot = np.log10(10**cb_bh[key_bh][-1] + cf.bondi_hoyle( binary, accretor, donor, idx=-1, - wind_disk_criteria=True, scheme='Kudritzki+2000')) + wind_disk_criteria=True, RNG=self.RNG, scheme='Kudritzki+2000')) mdot_edd = cf.eddington_limit(binary, idx=-1)[0] if 10**tmp_lg_mdot > mdot_edd: @@ -787,7 +789,7 @@ def update_properties_NN(self, star_1_CO=False, star_2_CO=False, history_of_attribute = (np.log10( 10**cb_bh[key_bh][0] + cf.bondi_hoyle( binary, accretor, donor, idx=len_binary_hist, - wind_disk_criteria=True, scheme='Kudritzki+2000'))) + wind_disk_criteria=True, RNG=self.RNG, scheme='Kudritzki+2000'))) if 10**history_of_attribute > edd: history_of_attribute = np.log10(edd) accretor.lg_mdot_history.append(history_of_attribute) @@ -799,6 +801,7 @@ def update_properties_NN(self, star_1_CO=False, star_2_CO=False, # hence we loop one back range(-N-1,-1) tmp_h = [cf.bondi_hoyle(binary, accretor, donor, idx=i, wind_disk_criteria=True, + RNG=self.RNG, scheme='Kudritzki+2000') for i in range(-length_hist-1, -1)] tmp_edd = [cf.eddington_limit(binary, idx=i)[0] @@ -930,7 +933,7 @@ def initial_final_interpolation(self, star_1_CO=False, star_2_CO=False): setattr(self.binary, f'mt_history_{self.grid_type}', mt_history) #TODO: add classifier for tf2 - #setattr(self.binary, f'culmulative_mt_case', self.classes['termination_flags_2']) + #setattr(self.binary, f'cumulative_mt_case', self.classes['termination_flags_2']) S1_state_inferred = cf.check_state_of_star(self.binary.star_1, star_CO=star_1_CO) S2_state_inferred = cf.check_state_of_star(self.binary.star_2, @@ -965,7 +968,8 @@ def initial_final_interpolation(self, star_1_CO=False, star_2_CO=False): tmp_lg_mdot = np.log10( 10**fv[key_bh] + cf.bondi_hoyle( binary, accretor, donor, idx=-1, - wind_disk_criteria=True, scheme='Kudritzki+2000')) + wind_disk_criteria=True, + RNG=self.RNG, scheme='Kudritzki+2000')) mdot_edd = cf.eddington_limit(binary, idx=-1)[0] if 10**tmp_lg_mdot > mdot_edd: @@ -1262,16 +1266,11 @@ def interpolate_at_t(self, t, t_before, t_after, v_before, v_after): class MS_MS_step(MesaGridStep): """Class for performing the MESA step for a MS-MS binary.""" - def __init__(self, metallicity=1., grid_name=None, *args, **kwargs): + def __init__(self, *args, **kwargs): """Initialize a MS_MS_step instance.""" self.grid_type = 'HMS_HMS' self.interp_in_q = True - if grid_name is None: - metallicity = convert_metallicity_to_string(metallicity) - grid_name = 'HMS-HMS/' + metallicity + '_Zsun.h5' - super().__init__(metallicity=metallicity, - grid_name=grid_name, - *args, **kwargs) + super().__init__(*args, **kwargs) # special stuff for my step goes here # set mass ratio @@ -1360,20 +1359,20 @@ def __call__(self, binary): '- H-rich_Core_H_burning - * - ZAMS' % (state_1, state_2, event)) + def __repr__(self): + """Return the name of evolution step and settings.""" + return "MS_MS_step:\n" + \ + "\n".join([f"{key} = {getattr(self, key)}" for key in self.__dict__]) + class CO_HMS_RLO_step(MesaGridStep): """Class for performing the MESA step for a CO-HMS_RLO binary.""" - def __init__(self, metallicity=1., grid_name=None, *args, **kwargs): + def __init__(self, *args, **kwargs): """Initialize a CO_HMS_RLO_step instance.""" self.grid_type = 'CO_HMS_RLO' self.interp_in_q = False - if grid_name is None: - metallicity = convert_metallicity_to_string(metallicity) - grid_name = 'CO-HMS_RLO/' + metallicity + '_Zsun.h5' - super().__init__(metallicity=metallicity, - grid_name=grid_name, - *args, **kwargs) + super().__init__(*args, **kwargs) def __call__(self, binary): """Evolve a binary using the MESA step.""" @@ -1475,20 +1474,20 @@ def __call__(self, binary): self.binary.event = "redirect_from_CO_HMS_RLO" return + def __repr__(self): + """Return the name of evolution step and settings.""" + return "CO_HMS_RLO_step:\n" + \ + "\n".join([f"{key} = {getattr(self, key)}" for key in self.__dict__]) + class CO_HeMS_RLO_step(MesaGridStep): """Class for performing the MESA step for a CO-HeMS_RLO binary.""" - def __init__(self, metallicity=1., grid_name=None, *args, **kwargs): + def __init__(self, *args, **kwargs): """Initialize a CO_HeMS_RLO_step instance.""" self.grid_type = 'CO_HeMS_RLO' self.interp_in_q = False - if grid_name is None: - metallicity = convert_metallicity_to_string(metallicity) - grid_name = 'CO-HeMS_RLO/' + metallicity + '_Zsun.h5' - super().__init__(metallicity=metallicity, - grid_name=grid_name, - *args, **kwargs) + super().__init__(*args, **kwargs) def __call__(self, binary): """Evolve a binary using the MESA step.""" @@ -1590,20 +1589,20 @@ def __call__(self, binary): self.binary.event = "redirect_from_CO_HeMS_RLO" return + def __repr__(self): + """Return the name of evolution step and settings.""" + return "CO_HeMS_RLO_step:\n" + \ + "\n".join([f"{key} = {getattr(self, key)}" for key in self.__dict__]) + class CO_HeMS_step(MesaGridStep): """Class for performing the MESA step for a CO-HeMS binary.""" - def __init__(self, metallicity=1., grid_name=None, *args, **kwargs): + def __init__(self, *args, **kwargs): """Initialize a CO_HeMS_step instance.""" self.grid_type = 'CO_HeMS' self.interp_in_q = False - if grid_name is None: - metallicity = convert_metallicity_to_string(metallicity) - grid_name = 'CO-HeMS/' + metallicity + '_Zsun.h5' - super().__init__(metallicity=metallicity, - grid_name=grid_name, - *args, **kwargs) + super().__init__(*args, **kwargs) def __call__(self, binary): """Apply the CO_HeMS step to a BinaryStar object.""" @@ -1686,6 +1685,11 @@ def __call__(self, binary): self.binary.event = 'redirect_from_CO_HeMS' return + def __repr__(self): + """Return the name of evolution step and settings.""" + return "CO_HeMS_step:\n" + \ + "\n".join([f"{key} = {getattr(self, key)}" for key in self.__dict__]) + class HMS_HMS_RLO_step(MesaGridStep): """Class for performing the MESA step for a HMS-HMS RLO binary. @@ -1693,16 +1697,11 @@ class HMS_HMS_RLO_step(MesaGridStep): we evolve them first with step detached and map to the HMS-HMS RLO grid using `initial_eccentricity_flow_chart`.""" - def __init__(self, metallicity=1., grid_name=None, *args, **kwargs): + def __init__(self, *args, **kwargs): """Initialize a HMS_HMS_RLO_step instance.""" self.grid_type = 'HMS_HMS_RLO' self.interp_in_q = True - if grid_name is None: - metallicity = convert_metallicity_to_string(metallicity) - grid_name = 'HMS-HMS_RLO/' + metallicity + '_Zsun.h5' - super().__init__(metallicity=metallicity, - grid_name=grid_name, - *args, **kwargs) + super().__init__(*args, **kwargs) # special stuff for my step goes here # If nothing to do, no init necessary @@ -1835,3 +1834,8 @@ def __call__(self, binary): self.binary.state = "detached" self.binary.event = "redirect_from_HMS_HMS_RLO" return + + def __repr__(self): + """Return the name of evolution step and settings.""" + return "HMS_HMS_RLO_step:\n" + \ + "\n".join([f"{key} = {getattr(self, key)}" for key in self.__dict__]) diff --git a/posydon/binary_evol/SN/step_SN.py b/posydon/binary_evol/SN/step_SN.py index 6881de110f..554aa4525c 100644 --- a/posydon/binary_evol/SN/step_SN.py +++ b/posydon/binary_evol/SN/step_SN.py @@ -20,7 +20,8 @@ "Tassos Fragos ", "Matthias Kruckow ", "Max Briel ", - "Seth Gossage " + "Seth Gossage ", + "Dimitris Souropanis " ] __credits__ = [ @@ -86,24 +87,6 @@ path_to_Couch_datasets = os.path.join(PATH_TO_POSYDON_DATA, "Couch+2020/") -SN_MODEL = { - # kick physics - "kick": True, - "kick_normalisation": 'one_over_mass', - "kick_prescription": 'maxwellian', - "sigma_kick_CCSN_NS": 265.0, - "mean_kick_CCSN_NS": None, - "sigma_kick_CCSN_BH": 265.0, - "mean_kick_CCSN_BH": None, - "sigma_kick_ECSN": 20.0, - "mean_kick_ECSN": None, - # other - "verbose": False, -} -# add core collapse physics -SN_MODEL.update(DEFAULT_SN_MODEL) - - class StepSN(object): """The supernova step in POSYDON. @@ -132,8 +115,11 @@ class StepSN(object): * 'Couch+20-engine': Uses the results from [6]_ to describe the collapse of the star. + * 'Maltsev+25-engine': Uses the results from [8]_ + to describe the collapse of the star + engine : str - Engine used for supernova remnanrt outcome propierties for the + Engine used for supernova remnant outcome propierties for the Sukhbold+16-engineand and Patton&Sukhbold20-engine mechanisms. Available options: @@ -254,22 +240,45 @@ class StepSN(object): Heger, A., and Pfahl, E. 2004, ApJ, 612, 1044. The Effects of Binary Evolution on the Dynamics of Core Collapse and Neutron Star Kicks + .. [8] K. Maltsev, F.R.N. Schneider, I. Mandel, B. Mueller, A. Heger, F.K. Roepke, + E. Laplace, 2025, A&A, 700, A20. Explodability criteria for the neutrino-driven + supernova mechanism """ + DEFAULT_KWARGS = { + # kick physics + "kick": True, + "kick_normalisation": 'one_over_mass', + "kick_prescription": 'maxwellian', + "sigma_kick_CCSN_NS": 265.0, + "mean_kick_CCSN_NS": None, + "sigma_kick_CCSN_BH": 265.0, + "mean_kick_CCSN_BH": None, + "sigma_kick_ECSN": 20.0, + "mean_kick_ECSN": None, + # other + "RNG": None, + "verbose": False + } + # add core collapse physics + DEFAULT_KWARGS.update(DEFAULT_SN_MODEL) + def __init__(self, **kwargs): """Initialize a StepSN instance.""" # read kwargs to initialize the class if kwargs: for key in kwargs: - if key not in SN_MODEL: + if key not in self.DEFAULT_KWARGS: raise ValueError(key + " is not a valid parameter name!") - for varname in SN_MODEL: - default_value = SN_MODEL[varname] - setattr(self, varname, kwargs.get(varname, default_value)) + for varname in self.DEFAULT_KWARGS: + setattr(self, varname, kwargs.get(varname, self.DEFAULT_KWARGS[varname])) + self.RNG = kwargs.get("RNG") + if self.RNG is None: + self.RNG = np.random.default_rng() + else: - for varname in SN_MODEL: - default_value = SN_MODEL[varname] - setattr(self, varname, default_value) + for varname in self.DEFAULT_KWARGS: + setattr(self, varname, self.DEFAULT_KWARGS[varname]) # backward compatibility for kick if (self.kick_normalisation == 'asym_ej' @@ -294,6 +303,9 @@ def __init__(self, **kwargs): self.Sukhbold16_engines = "Sukhbold+16-engine" self.Patton20_engines = "Patton&Sukhbold20-engine" self.Couch20_engines = "Couch+20-engine" + self.Maltsev25_engines = "Maltsev+25-engine" + + self.mechanisms = [ self.Fryer12_rapid, @@ -302,7 +314,8 @@ def __init__(self, **kwargs): self.direct_collapse_hecore, self.Sukhbold16_engines, self.Patton20_engines, - self.Couch20_engines + self.Couch20_engines, + self.Maltsev25_engines ] if self.mechanism in self.mechanisms: @@ -332,7 +345,7 @@ def __init__(self, **kwargs): path_engine_dataset=self.path_to_Couch_datasets, verbose=self.verbose) - elif self.mechanism == self.Patton20_engines: + elif self.mechanism in (self.Patton20_engines, self.Maltsev25_engines): self.path_to_Patton_datasets = path_to_Patton_datasets def format_data_Patton20(file_name): @@ -382,11 +395,15 @@ def format_data_Patton20(file_name): return CO_core_params, target if self.verbose: - print('Loading the train dataset for engine mu4 and M4...') + print('Loading the train dataset for engine mu4, M4, Xi, and sc ...') CO_core_params_mu4, mu4_target = format_data_Patton20( 'Kepler_mu4_table.dat') CO_core_params_M4, M4_target = format_data_Patton20( 'Kepler_M4_table.dat') + CO_core_params_Xi, Xi_target = format_data_Patton20( + 'Kepler_Xi_table.dat') + CO_core_params_sc, sc_target = format_data_Patton20( + 'Kepler_sc_table.dat') n_neighbors = 5 @@ -399,6 +416,14 @@ def format_data_Patton20(file_name): self.mu4_interpolator = neighbors.KNeighborsRegressor( n_neighbors, weights='distance') self.mu4_interpolator.fit(CO_core_params_mu4, mu4_target) + + self.Xi_interpolator = neighbors.KNeighborsRegressor( + n_neighbors, weights='distance') + self.Xi_interpolator.fit(CO_core_params_Xi, Xi_target) + + self.sc_interpolator = neighbors.KNeighborsRegressor( + n_neighbors, weights='distance') + self.sc_interpolator.fit(CO_core_params_sc, sc_target) if self.verbose: print('Done') else: @@ -407,7 +432,7 @@ def format_data_Patton20(file_name): def __repr__(self): """Get the string representation of the class and any parameters.""" return "StepSN:\n" + \ - "\n".join([f"{key} = {getattr(self, key)}" for key in SN_MODEL]) + "\n".join([f"{key} = {getattr(self, key)}" for key in self.__dict__]) def _reset_other_star_properties(self, star): @@ -790,7 +815,8 @@ def collapse_star(self, star): elif self.mechanism in [self.Sukhbold16_engines, self.Patton20_engines, - self.Couch20_engines]: + self.Couch20_engines, + self.Maltsev25_engines]: # The final remnant mass and and state # is computed by the selected mechanism @@ -1393,6 +1419,21 @@ def compute_m_rembar(self, star, m_PISN): m_rembar, f_fb, state = self.Patton20_corecollapse(star, self.engine, self.conserve_hydrogen_envelope) + + elif self.mechanism == self.Maltsev25_engines: + if star.SN_type == "ECSN": + if self.ECSN == 'Podsiadlowski+04': + m_proto = 1.38 + else: + m_proto = m_core + f_fb = 0.0 + m_fb = 0.0 + m_rembar = m_proto + m_fb + state = 'NS' + else: + m_rembar, f_fb, state = self.Maltsev25_corecollapse(star, + self.engine, + self.conserve_hydrogen_envelope) else: raise ValueError("Mechanism %s not supported." % self.mechanism) @@ -1541,13 +1582,13 @@ def orbital_kick(self, binary): if not binary.star_1.natal_kick_azimuthal_angle is None: phi = binary.star_1.natal_kick_azimuthal_angle else: - phi = np.random.uniform(0, 2 * np.pi) + phi = self.RNG.uniform(0, 2 * np.pi) binary.star_1.natal_kick_azimuthal_angle = phi if not binary.star_1.natal_kick_polar_angle is None: cos_theta = np.cos(binary.star_1.natal_kick_polar_angle) else: - cos_theta = np.random.uniform(-1, 1) + cos_theta = self.RNG.uniform(-1, 1) binary.star_1.natal_kick_polar_angle = np.arccos(cos_theta) # generate random point in the orbit where the kick happens @@ -1558,7 +1599,7 @@ def orbital_kick(self, binary): raise ValueError("mean_anomaly must be a single float value." f"\n mean_anomaly = {mean_anomaly}") else: - mean_anomaly = np.random.uniform(0, 2 * np.pi) + mean_anomaly = self.RNG.uniform(0, 2 * np.pi) binary.star_1.natal_kick_mean_anomaly = mean_anomaly elif binary.event == "CC2": @@ -1642,13 +1683,13 @@ def orbital_kick(self, binary): if not binary.star_2.natal_kick_azimuthal_angle is None: phi = binary.star_2.natal_kick_azimuthal_angle else: - phi = np.random.uniform(0, 2 * np.pi) + phi = self.RNG.uniform(0, 2 * np.pi) binary.star_2.natal_kick_azimuthal_angle = phi if not binary.star_2.natal_kick_polar_angle is None: cos_theta = np.cos(binary.star_2.natal_kick_polar_angle) else: - cos_theta = np.random.uniform(-1, 1) + cos_theta = self.RNG.uniform(-1, 1) binary.star_2.natal_kick_polar_angle = np.arccos(cos_theta) # generate random point in the orbit where the kick happens @@ -1658,7 +1699,7 @@ def orbital_kick(self, binary): if not isinstance(mean_anomaly, float): raise ValueError("mean_anomaly must be a single float value.") else: - mean_anomaly = np.random.uniform(0, 2 * np.pi) + mean_anomaly = self.RNG.uniform(0, 2 * np.pi) binary.star_2.natal_kick_mean_anomaly = mean_anomaly # update the orbit @@ -1760,32 +1801,34 @@ def orbital_kick(self, binary): # extended to Eq 13, in Wong, T.-W., Valsecchi, F., Fragos, T., & Kalogera, V. 2012, ApJ, 747, 111 # get the orbital separation post SN # Eq from conservation of energy - Apost = ((2.0 / rpre) - - (((Vkick ** 2) + (Vr ** 2) + (2 * (Vkick * cos_theta) * Vr)) / (G * Mtot_post)) - ) ** -1 + # Note: Suppress overflow warnings for extreme kick scenarios that lead to + # disrupted binaries. + with np.errstate(over='ignore', divide='ignore', invalid='ignore'): + Apost = ((2.0 / rpre) + - (((Vkick ** 2) + (Vr ** 2) + (2 * (Vkick * cos_theta) * Vr)) / (G * Mtot_post)) + ) ** -1 - # get kicks componets in the coordinate system - Vkx = Vkick * (sin_theta * np.sin(phi) * sin_psi + cos_theta * cos_psi) - Vky = Vkick * (-sin_theta * np.sin(phi) * cos_psi + cos_theta * sin_psi) - Vkz = Vkick * sin_theta * np.cos(phi) + # get kicks componets in the coordinate system + Vkx = Vkick * (sin_theta * np.sin(phi) * sin_psi + cos_theta * cos_psi) + Vky = Vkick * (-sin_theta * np.sin(phi) * cos_psi + cos_theta * sin_psi) + Vkz = Vkick * sin_theta * np.cos(phi) - # Eq 4, in Kalogera, V. 1996, ApJ, 471, 352 - # extended to Eq 14 in Wong, T.-W., Valsecchi, F., Fragos, T., & Kalogera, V. 2012, ApJ, 747, 111 - # get the eccentricity post SN - # Eq from setting specific angular momentum r X Vr = sqrt(G*M*A*(1-e**2)) + # Eq 4, in Kalogera, V. 1996, ApJ, 471, 352 + # extended to Eq 14 in Wong, T.-W., Valsecchi, F., Fragos, T., & Kalogera, V. 2012, ApJ, 747, 111 + # get the eccentricity post SN + # Eq from setting specific angular momentum r X Vr = sqrt(G*M*A*(1-e**2)) + x = ((Vkz ** 2 + (Vky + Vr * sin_psi)** 2) + * rpre ** 2 + / (G * Mtot_post * Apost)) - x = ((Vkz ** 2 + (Vky + Vr * sin_psi)** 2) - * rpre ** 2 - / (G * Mtot_post * Apost)) - - # catch negative values, i.e. disrupted binaries - if 1.-x < 0.: - epost = np.nan - else: - epost = np.sqrt(1 - x) + # catch negative values, i.e. disrupted binaries + if 1.-x < 0.: + epost = np.nan + else: + epost = np.sqrt(1 - x) # Compute COM velocity, VS, post SN # VS_pre in COM frame is 0. So VS_post in COM frame is @@ -1813,7 +1856,9 @@ def orbital_kick(self, binary): # cos(tilt) = Lpre dot Lpost / ||Lpre||||Lpost|| # For epre=0 (sin_psi=1), reduces to Eq 4, in Kalogera, V. 1996, ApJ, 471, 352 - tilt = np.arccos((Vky + Vr * sin_psi) / np.sqrt( Vkz ** 2 + (Vky + Vr * sin_psi) ** 2 )) + # Suppress overflow warnings for extreme values + with np.errstate(over='ignore', invalid='ignore'): + tilt = np.arccos((Vky + Vr * sin_psi) / np.sqrt( Vkz ** 2 + (Vky + Vr * sin_psi) ** 2 )) # Track direction of tilt if Vkz < 0: tilt *= -1 @@ -1887,11 +1932,14 @@ def SNCheck( # (see, e.g., Kalogera, V. & Lorimer, D.R. 2000, ApJ, 530, 890) # The derivation in the papers above assume a circular pre SN # orbit. Hence, need a correction for eccentric pre SN orbits: - eccentric_orbit_correction = Vr**2 * rpre / (G * Mtot_pre) - tmp1 = 2 - Mtot_pre / Mtot_post * (Vkick / Vr - 1) ** 2\ - * eccentric_orbit_correction - tmp2 = 2 - Mtot_pre / Mtot_post * (Vkick / Vr + 1) ** 2\ - * eccentric_orbit_correction + # Suppress divide by zero warnings for edge cases + with np.errstate(divide='ignore', over='ignore', invalid='ignore'): + eccentric_orbit_correction = Vr**2 * rpre / (G * Mtot_pre) + tmp1 = 2 - Mtot_pre / Mtot_post * (Vkick / Vr - 1) ** 2\ + * eccentric_orbit_correction + tmp2 = 2 - Mtot_pre / Mtot_post * (Vkick / Vr + 1) ** 2\ + * eccentric_orbit_correction + SNflag2 = ((rpre / Apost - tmp1 < err) and (err > tmp2 - rpre / Apost)) @@ -2084,7 +2132,7 @@ def _get_kick_velocity(self, star, sigma=None, mean=None): # this is a fallback if sigma is None: sigma = 265.0 - Vkick_ej = sp.stats.maxwell.rvs(loc=0., scale=sigma, size=1)[0] + Vkick_ej = sp.stats.maxwell.rvs(loc=0., scale=sigma, size=1, random_state=self.RNG)[0] elif self.kick_prescription == "log_normal": # sigma==None should never be reached, since in that case Vkick=0 @@ -2094,7 +2142,7 @@ def _get_kick_velocity(self, star, sigma=None, mean=None): sigma = 0.68 if mean is None: mean = np.exp(5.60) - Vkick_ej = sp.stats.lognorm.rvs(s=sigma, scale=mean, size=1)[0] + Vkick_ej = sp.stats.lognorm.rvs(s=sigma, scale=mean, size=1, random_state=self.RNG)[0] elif self.kick_prescription == "asym_ej": f_kin = 0.1 # Fraction of SN explosion energy that is kinetic energy of the gas @@ -2252,10 +2300,13 @@ def get_CO_core_params(self, star, approximation=False): def get_M4_mu4_Patton20(self, CO_core_mass, C_core_abundance): """Get the M4 and mu4 using Patton+20.""" + M4 = self.M4_interpolator.predict([[C_core_abundance, CO_core_mass]]) mu4 = self.mu4_interpolator.predict([[C_core_abundance, CO_core_mass]]) + Xi = self.Xi_interpolator.predict([[C_core_abundance, CO_core_mass]]) + sc = self.sc_interpolator.predict([[C_core_abundance, CO_core_mass]]) - return M4, mu4 + return M4, mu4, Xi, sc def Patton20_corecollapse(self, star, engine, conserve_hydrogen_envelope=False): """Compute supernova final remnant mass and fallback fraction. @@ -2303,7 +2354,7 @@ def Patton20_corecollapse(self, star, engine, conserve_hydrogen_envelope=False): CO_core_mass, C_core_abundance = self.get_CO_core_params( star, self.approx_at_he_depletion) - M4, mu4 = self.get_M4_mu4_Patton20(CO_core_mass, C_core_abundance) + M4, mu4, Xi, sc = self.get_M4_mu4_Patton20(CO_core_mass, C_core_abundance) M4 = M4[0] mu4 = mu4[0] star.M4 = M4 @@ -2343,6 +2394,167 @@ def Patton20_corecollapse(self, star, engine, conserve_hydrogen_envelope=False): return m_rem, f_fb, state + def Maltsev25_corecollapse(self, star, engine, conserve_hydrogen_envelope=False): + """Compute supernova final remnant mass and fallback fraction. + + It uses the results from [8]_. The prediction for the core-collapse + outcome is performed using the C core mass and its C abundance. + The criterion by [8]_ is used to determine the final outcome. + + Parameters + ---------- + star : obj + Star object of a collapsing star containing the MESA profile. + engine : str + Engine to use for the core-collapse prescription + Possible options are: 'M16' + conserve_hydrogen_envelope : bool + Whether to assume that the hydrogen envelope is conserved in direct collapse to a BH. + + Returns + ------- + m_rem : double + Remnant mass of the compact object in M_sun. + f_fb : double + Fallback mass of the compact object in M_sun. + state : str + 'NS' if the remnant is a neutron star, 'BH' if the remnant is a black hole + + References + ---------- + .. [8] K. Maltsev, F.R.N. Schneider, I. Mandel, B. Mueller, A. Heger, + F.K. Roepke, E. Laplace, 2025, A&A, 700, A20. Explodability + criteria for the neutrino-driven supernova mechanism + + """ + Muller_k_parameters = { + 'M16': [0.005, 0.420] # Section 3.1.1. of [8]_ + } + + if engine not in Muller_k_parameters.keys(): + raise ValueError("Engine " + engine + " is not avaiable for the " + "Maltsev+25 core-collapse prescription, " + "please choose one of the following engines to " + "compute the collapse: \n" + "\n".join( + list(Muller_k_parameters.keys()))) + else: + + CO_core_mass, C_core_abundance = self.get_CO_core_params( + star, self.approx_at_he_depletion) + M4, mu4, Xi, sc = self.get_M4_mu4_Patton20(CO_core_mass, C_core_abundance) + M4 = M4[0] + mu4 = mu4[0] + Xi = Xi[0] + sc = sc[0] + mu4M4 = mu4*M4 + star.M4 = M4 + star.mu4 = mu4 + star.Xi = Xi + star.sc = sc + + + k1 = Muller_k_parameters[engine][0] + k2 = Muller_k_parameters[engine][1] + + if CO_core_mass <= 2.5: + m_rem = 1.25 + f_fb = 0.0 + state = 'NS' + + # In the Maltsev prescription, stars with CO core masses above 10 are allowed to explode. + # However, since this outcome depends on the mass-transfer (MT) history, we handle it + # in post-processing (for now). For all CO core masses above 10, we assume a failed supernova + # with fallback = 1 at this stage. + elif CO_core_mass >= 10.0: + # Assuming BH formation by direct collapse + if conserve_hydrogen_envelope: + m_rem = star.mass + else: + m_rem = star.he_core_mass + f_fb = 1.0 + state = 'BH' + + elif (CO_core_mass > 2.5) and (CO_core_mass < 10.0): + successful_SN = self.explod_crit(Xi, sc, mu4M4, mu4, k1, k2) + + if successful_SN: + rem = self.NS_vs_fallbackBH(Xi, CO_core_mass, M4, mu4M4) + if rem == 'NS': # successful SN with NS + m_rem = M4 + f_fb = 0.0 + state = 'NS' + + else: # successful SN but with fallback BH + if conserve_hydrogen_envelope: + m_rem = star.mass + else: + m_rem = star.he_core_mass + + f_fb = 0.99 + state = 'BH' + + else: + if conserve_hydrogen_envelope: + m_rem = star.mass + else: + m_rem = star.he_core_mass + + f_fb = 1.0 + state = 'BH' + + return m_rem, f_fb, state + + def NS_vs_fallbackBH(self, comp_val, mco_val, M4_val, mu4M4_val): + a, b = 1.75, -0.044 # eq. (8) of [8]_ + # conditions for guaranteed NS formation (eq. 7) + if comp_val <= 0.04 or (comp_val < a*mu4M4_val + b and comp_val <= 0.4) or M4_val/mco_val > 0.6: + rem = 'NS' + else: + # stochastic determination of the remnant type (NS versus fallback-BH) + rand_number = self.RNG.uniform(0,1) + if rand_number <= 0.15: # probability for fallback = 0.15 in Section 3.1.2. + rem = 'fallback_BH' + else: + rem = 'NS' + return rem + + # implemented from Maltsev+25 + def explod_crit(self, comp_val, sc_val, mu4M4_val, mu4_val, k1, k2): + ff1, ff2 = [], [] + unclassified = True + comp_crit1, comp_crit2 = 0.314, 0.544 # compactness + sc_crit1, sc_crit2 = 0.988, 1.169 # central specific entropy + mu4M4_crit1, mu4M4_crit2 = 0.247, 0.421 # product of M4 and mu4 + + # check whether criterion for failed SN is fulfilled + if comp_val > comp_crit2 or sc_val > sc_crit2: + ff2.append(0) + ff = False + unclassified = False + + # check whether criterion for successful SN is fulfilled + if comp_val < comp_crit1 or sc_val < sc_crit1: + ff1.append(1) + ff = True + unclassified = False + + # if there is contradiction or if the progenitor is unclassified based on comp & s_c + if (len(ff1) > 0 and len(ff2) > 0) or unclassified: + + # final fate classification based on mu4M4 + if mu4M4_val > mu4M4_crit2: + ff = False + elif mu4M4_val < mu4M4_crit1: + ff = True + # final fate classification based on reversed Ertl criterion + elif k1 + k2*mu4M4_val - mu4_val > 0: + ff = False + else: + ff = True + return ff + + + class Sukhbold16_corecollapse(object): """Compute supernova final remnant mass, fallback fraction and CO type. diff --git a/posydon/binary_evol/binarystar.py b/posydon/binary_evol/binarystar.py index efad6b2067..d89f12d3ab 100644 --- a/posydon/binary_evol/binarystar.py +++ b/posydon/binary_evol/binarystar.py @@ -222,8 +222,8 @@ def __init__(self, star_1=None, star_2=None, index=None, properties=None, setattr(self, f'interp_class_{grid_type}', None) if not hasattr(self, f'mt_history_{grid_type}'): setattr(self, f'mt_history_{grid_type}', None) - if not hasattr(self, f'culmulative_mt_case_{grid_type}'): - setattr(self, f'culmulative_mt_case_{grid_type}', None) + if not hasattr(self, f'cumulative_mt_case_{grid_type}'): + setattr(self, f'cumulative_mt_case_{grid_type}', None) # SimulationProperties object - parameters & parameterizations if isinstance(properties, SimulationProperties): diff --git a/posydon/binary_evol/simulationproperties.py b/posydon/binary_evol/simulationproperties.py index 43495e81f9..d24b168afc 100644 --- a/posydon/binary_evol/simulationproperties.py +++ b/posydon/binary_evol/simulationproperties.py @@ -17,8 +17,15 @@ import os import time +import numpy as np + +from posydon.binary_evol.track_match import TrackMatcher +from posydon.config import PATH_TO_POSYDON_DATA +from posydon.interpolation.interpolation import GRIDInterpolator from posydon.popsyn.io import simprop_kwargs_from_ini +from posydon.utils.common_functions import convert_metallicity_to_string from posydon.utils.constants import age_of_universe +from posydon.utils.posydonerror import GridError from posydon.utils.posydonwarning import Pwarn @@ -29,6 +36,17 @@ class NullStep: class SimulationProperties: """Class describing the properties of a population synthesis simulation.""" + # each value in this dict represents the expected path for the respective grid. + # A user may specify their own full path to a custom grid in the [grid_paths] + # section of their .ini file. I.e., HMS-HMS_path = 'path/to/my_own_grid/' to + # override these defaults. + default_grid_paths = {"single_HMS_path": os.path.join(PATH_TO_POSYDON_DATA, "single_HMS"), + "single_HeMS_path": os.path.join(PATH_TO_POSYDON_DATA, "single_HeMS"), + "HMS_HMS_path": os.path.join(PATH_TO_POSYDON_DATA, "HMS-HMS"), + "CO_HMS_RLO_path": os.path.join(PATH_TO_POSYDON_DATA, "CO-HMS_RLO"), + "CO_HeMS_path": os.path.join(PATH_TO_POSYDON_DATA, "CO-HeMS"), + "CO_HeMS_RLO_path": os.path.join(PATH_TO_POSYDON_DATA, "CO-HeMS_RLO")} + def __init__(self, flow=({}, {}), step_HMS_HMS = (NullStep(), {}), step_CO_HeMS = (NullStep(), {}), @@ -163,12 +181,6 @@ def __init__(self, flow=({}, {}), "(i) a class deriving from EvolveHooks and a kwargs dict, " "or (ii) the name of the extra function and the callable.") - # Binary parameters and parameterizations - self.initial_rotation = 0.0 - self.mass_transfer_efficiency = 1.0 - - self.common_envelope_efficiency = 1.0 - # Limits on simulation if not hasattr(self, 'max_simulation_time'): self.max_simulation_time = age_of_universe @@ -192,6 +204,83 @@ def __init__(self, flow=({}, {}), self.preload_imports() + # To hold TrackMatcher objects per step, if needed. + # maybe get rid of this + self.track_matchers = {} + + for grid_name in self.default_grid_paths: + try: + self.set_path(grid_name, self.kwargs[grid_name]) + except KeyError as e: + Pwarn(f"{grid_name} is not set in the kwargs passed to SimulationProperties. " + f"Falling back to the default: {self.default_grid_paths[grid_name]}", + "ReplaceValueWarning") + + self.set_path(grid_name, self.default_grid_paths[grid_name]) + + # These hold GRIDInterpolator objects + # and associated grid names for ea. metallicity + # (intended keys are metallicities): + self.grids_Hrich = {} + self.grids_strippedHe = {} + + def set_path(self, path_name, path_str): + """ + Set and normalize a grid path attribute that points to one of the + MESA grids needed for binary evolution. By default, these are the + grids inside of the directory name held in $PATH_TO_POSYDON_DATA. + + For example, for the step_HMS_HMS, the grid would be + + $PATH_TO_POSYDON_DATA/HMS-HMS/_Zsun.h5 + + by default. The grid HDF5 file names themselves are expected to + follow formats like so: 1e+00_Zsun.h5, 1e-04_Zsun.h5, etc. + + If ``path_str`` is ``None``, a default path is assigned based on + ``path_name`` using ``self.default_grid_paths``. If ``path_name`` is not + recognized, a ``GridError`` is raised listing the valid options. + + The resulting path is converted to an absolute path before being stored + as an attribute of the instance. + + Parameters + ---------- + path_name : str + Name of the grid path attribute to set. Must be a key in + ``self.default_grid_paths`` if ``path_str`` is ``None``. + + path_str : str or None + Path to assign. If ``None``, a default path corresponding to + ``path_name`` is used. + + Raises + ------ + GridError + If ``path_name`` is not recognized and no default path can be assigned. + + Notes + ----- + The path is not validated for existence here; only normalization to an + absolute path is performed. + """ + + # construct path to *_Zsun.h5 files if not specified + if path_str is None: + if path_name in self.default_grid_paths: + path_str = self.default_grid_paths[path_name] + else: + valid_names = "\n".join(f"{k} = " for k in self.default_grid_paths) + raise GridError(f'Trying to assign a grid path for "{path_name}".\n' + "This is an unrecognized path name. Please check " + "the [grid_paths] section of your .ini file.\n\n" + "Valid path variable names are:\n" + f"{valid_names}\n") + + path_str = os.path.abspath(path_str) + + setattr(self, path_name, path_str) + def preload_imports(self): """ Preload the imports of detached_step and MesaGridStep to avoid @@ -202,14 +291,17 @@ def preload_imports(self): failure occurs, hence the need for something like this. """ + from posydon.binary_evol.CE.step_CEE import StepCEE from posydon.binary_evol.DT.step_detached import detached_step from posydon.binary_evol.MESA.step_mesa import MesaGridStep self._detached_step = detached_step + self._step_CE = StepCEE self._MesaGridStep = MesaGridStep @classmethod - def from_ini(cls, path, metallicity = None, load_steps=False, verbose=False, **override_sim_kwargs): + def from_ini(cls, path, metallicity = None, load_steps=False, RNG=np.random.default_rng(), + verbose=False, **override_sim_kwargs): """Create a SimulationProperties instance from an inifile. Parameters @@ -226,9 +318,19 @@ def from_ini(cls, path, metallicity = None, load_steps=False, verbose=False, **o load_steps : bool Whether or not evolution steps should be automatically loaded. + RNG : numpy.random.Generator, optional + Random number generator used for any stochastic components of + the simulation. Defaults to a new NumPy Generator instance + created via ``np.random.default_rng()``. + verbose : bool Print useful info. + **override_sim_kwargs + Additional keyword arguments that override values specified + in the .ini file when constructing the SimulationProperties + instance. + Returns ------- SimulationProperties @@ -244,11 +346,12 @@ def from_ini(cls, path, metallicity = None, load_steps=False, verbose=False, **o if load_steps: # Load the steps and required data new_instance.load_steps(metallicity=metallicity, + RNG=RNG, verbose=verbose) return new_instance - def load_steps(self, metallicity=None, verbose=False): + def load_steps(self, metallicity=None, RNG=np.random.default_rng(), verbose=False): """Instantiate all step classes and set as instance attributes. Parameters @@ -266,99 +369,235 @@ def load_steps(self, metallicity=None, verbose=False): ------- None """ - if verbose: - print('STEP NAME'.ljust(20) + 'STEP FUNCTION'.ljust(25) + 'KWARGS') # for every other step, give it a metallicity and load each step for name, tup in self.kwargs.items(): if isinstance(tup, tuple): step_kwargs = tup[1] metallicity = step_kwargs.get('metallicity', metallicity) - self.load_a_step(name, tup, metallicity=metallicity, verbose=verbose) + self.load_a_step(name, tup, metallicity=metallicity, RNG=RNG, verbose=verbose) + + if verbose: + if self.steps_loaded: + print("All steps loaded successfully.") + else: + print("Not all steps were loaded successfully. Check warnings for details.") - # track that all steps have been loaded - self.steps_loaded = True + def load_a_step(self, step_name, step_tup=(NullStep, {}), metallicity=None, + RNG=np.random.default_rng(), from_ini='', verbose=False): + """ + Instantiate and attach a simulation step to this object. - def load_a_step(self, step_name, step_tup=(NullStep, {}), metallicity=None, from_ini='', verbose=False): - """Instantiate one step class and set as instance attribute. + This method creates an instance of a step class and assigns it as an + attribute of SimulationProperties using ``step_name`` as the attribute + name. Step keyword arguments may be provided directly via ``step_tup`` + or loaded from an `.ini` configuration file. Before instantiation, + step arguments are validated and augmented (e.g., assigning metallicity + and creating a TrackMatcher if required). Parameters ---------- step_name : str + Name of the evolution step. The created step instance will be + attached to the object as ``self.``. See + ``SimulationProperties.__init__`` for the standard set of steps. - This string is the name of the evolution step. See - SimulationProperties.__init__ for the full standard set. + step_tup : tuple, optional + Tuple of the form ``(step_class, kwargs_dict)`` where: - step_tup : tuple - A tuple whose first element is the step class and whose - second is a dictionary representing the step's kwargs. + - ``step_class`` is the class representing the step. + - ``kwargs_dict`` is a dictionary of keyword arguments used to + initialize the step. - metallicity : float - A metallicity (Z) may be provided to automatically assign - to the step as it is loaded. Should be one of e.g., 2.0, 1.0, - 4.5e-1, 2e-1, 1e-1, 1e-2, 1e-3, 1e-4, corresponding to - metallicities available in your POSYDON_DATA grids. + Default is ``(NullStep, {})``. - from_ini : str - Path to a .ini file to read step options from. + metallicity : float, optional + Metallicity (Z) to assign to the step if required and not already + specified in the step keyword arguments. Default supported values + are: 2.0, 1.0, 4.5e-1, 2e-1, 1e-1, 1e-2, 1e-3, 1e-4. - verbose : bool - Print extra information. + from_ini : str, optional + Path to an `.ini` file containing step configuration. If provided + and the file exists, the step class and keyword arguments for + ``step_name`` are loaded from this file and override ``step_tup``. + + verbose : bool, optional + If True, print detailed information about step loading and the + keyword arguments used to instantiate the step. Returns ------- None + + Notes + ----- + - Step keyword arguments are processed by ``self.check_step`` before + instantiation. This may assign a metallicity and/or attach a + ``TrackMatcher`` if required for the step. + - The instantiated step is stored as an attribute of SimulationProperties. + - After loading, ``self.steps_loaded`` is updated to indicate whether + all configured steps have been successfully attached. """ - # these steps and the flow do not require a metallicity - ignore_for_met = ["flow", "step_SN", "step_end"] + if verbose: + print(f"Loading {step_name}...") # grab kwargs from ini file for given step if os.path.isfile(from_ini): step_tup = simprop_kwargs_from_ini(from_ini, only=step_name)[step_name] - if (metallicity is None) and (step_name not in ignore_for_met): - step_kwargs = step_tup[1] - metallicity = step_kwargs.get('metallicity', metallicity) - if metallicity is not None: - pass - # if still None: - else: - Pwarn(f"{step_name} not assigned a metallicity. Defaulting to Z = Zsun (solar).", - "MissingValueWarning") - metallicity = 1.0 - - # This if should never trigger after __init__, unless the step is - # entirely new and non-standard - if step_name not in self.kwargs.keys(): - self.kwargs[step_name] = step_tup - - # give step a metallicity and load it as a class attribute - if step_name not in ignore_for_met: - step_tup[1].update({'metallicity':float(metallicity)}) - if verbose: - print(step_name, step_tup, end='\n') + if step_name != "flow": + # check to make sure the step has a... + # 1) metallicity assigned (if needed) + # 2) TrackMatcher assigned (if needed) + step_tup = self.check_step(metallicity, RNG, step_name, + step_tup, verbose) - step_func, kwargs = step_tup + step_func, step_kwargs = step_tup - # steps like step_end do not take kwargs, so try loading with - # kwargs first, then without if that fails. This mostly matters - # if a user has re-mapped a step to one that does not take kwargs. + # Try to load the step try: - setattr(self, step_name, step_func(**kwargs)) + setattr(self, step_name, step_func(**step_kwargs)) + if verbose: + print(f"Class: {step_func}") + if step_kwargs: + print("step_kwargs: ") + kw_list = [f"\t{key}: {val}" for key, val in step_kwargs.items()] + print("\n".join(kw_list)) + print(f"{step_name} loaded successfully.\n") except TypeError as e: - Pwarn(f"Error loading step {step_name}: {e}", "StepWarning") + Pwarn(f"Error loading {step_name}: {e}", "StepWarning") print(f"Loading {step_name} without arguments.") setattr(self, step_name, step_func()) # check if all steps have been loaded - for name, tup in self.kwargs.items(): - if isinstance(tup, tuple): - if hasattr(self, name): - self.steps_loaded = True - else: - self.steps_loaded = False + self.steps_loaded = all(hasattr(self, name) + for name, tup in self.kwargs.items() + if isinstance(tup, tuple)) + + def check_step(self, metallicity, RNG, step_name, step_tup, verbose=False): + """ + Validate and update configuration for an evolution step. + + This method ensures that a valid metallicity is assigned to the step + (unless the step is excluded from metallicity handling) and that a + corresponding TrackMatcher exists if the step requires track matching. + If a TrackMatcher for the `(metallicity, step_name)` combination does + not yet exist, it is created and stored. + + Parameters + ---------- + metallicity : float or None + Default metallicity value to use for the step if not explicitly + provided in ``step_kwargs``. + step_name : str + Name of the pipeline step being checked. + step_kwargs : dict + Keyword arguments for the step. This dictionary may be modified + in-place to include validated metallicity and/or a TrackMatcher + instance. + verbose : bool, optional + If True, print the keyword arguments used to construct the + TrackMatcher. + + Returns + ------- + dict + The updated ``step_kwargs`` dictionary, containing a validated + ``metallicity`` entry and potentially a ``track_matcher`` object. + + Notes + ----- + - If metallicity is not provided for a step that requires it, a warning + is issued and a default value of ``Z = 1.0`` (solar metallicity) is used. + - TrackMatcher objects are stored in ``self.track_matchers`` and reused + for repeated `(metallicity, step_name)` combinations. + """ + step_func, step_kwargs = step_tup + + # check/assign metallicity for the step + if "metallicity" in step_func.DEFAULT_KWARGS: + metallicity = step_kwargs.get('metallicity', metallicity) + if metallicity is None: + Pwarn(f"{step_name} not assigned a metallicity. " + "Defaulting to Z = Zsun (solar).", + "ReplaceValueWarning") + metallicity = 1.0 + step_kwargs['metallicity'] = float(metallicity) + + # These steps need these grids: + step_grid_map = {"step_HMS_HMS": self.HMS_HMS_path, + "step_CO_HMS_RLO": self.CO_HMS_RLO_path, + "step_CO_HeMS": self.CO_HeMS_path, + "step_CO_HeMS_RLO": self.CO_HeMS_RLO_path} + if step_name in step_grid_map: + step_kwargs['grid_path'] = step_grid_map[step_name] + + # each metallicity/step combo could require + # a unique TrackMatcher, so check for that + matcher_key = (metallicity, step_name) + if "track_matcher" in step_func.DEFAULT_KWARGS: + matcher_needed = matcher_key not in self.track_matchers + if matcher_needed: + # create TrackMatcher if needed + step_kwargs, matcher_kwargs = TrackMatcher.separate_kwargs(step_kwargs) + self.create_track_matcher(metallicity, step_name, matcher_kwargs) + + if verbose: + kw_list = [f"\t{key}: {val}" for key, val in matcher_kwargs.items()] + print(f"matcher_kwargs: \n" + "\n".join(kw_list)) + step_kwargs['track_matcher'] = self.track_matchers[matcher_key] + + if "RNG" in step_func.DEFAULT_KWARGS: + step_kwargs['RNG'] = RNG + + return step_tup + + def create_track_matcher(self, metallicity, step_name, matcher_kwargs): + """ + Create and store a TrackMatcher for a given metallicity and step. + + This method ensures that the required stellar evolution grids + (H-rich and stripped-He) are loaded for the specified metallicity. + If the corresponding GRIDInterpolator objects do not yet exist, + they are created and cached. The interpolators are then passed to + a TrackMatcher instance, which is stored internally. + + Parameters + ---------- + metallicity : float + Stellar metallicity used to select the appropriate grid files. + step_name : str + Identifier for the evolutionary step associated with this + TrackMatcher. + matcher_kwargs : dict + Keyword arguments used to initialize the TrackMatcher. This + dictionary will be updated in-place with the following keys: + 'grid_Hrich' and 'grid_strippedHe'. + + Notes + ----- + - GRIDInterpolator objects are created only once per metallicity + and reused for subsequent TrackMatcher creations. + - The created TrackMatcher is stored in ``self.track_matchers`` + using the key ``(metallicity, step_name)``. + """ + + z_str = convert_metallicity_to_string(metallicity) + # set up GRIDInterpolator objects (for HMS and HeMS) + # (only if one hasn't been created already for a given metallicity) + if metallicity not in self.grids_Hrich: + grid_path_Hrich = os.path.join(self.single_HMS_path, f"{z_str}_Zsun.h5") + self.grids_Hrich[metallicity] = GRIDInterpolator(grid_path_Hrich) + if metallicity not in self.grids_strippedHe: + grid_path_strippedHe = os.path.join(self.single_HeMS_path, f"{z_str}_Zsun.h5") + self.grids_strippedHe[metallicity] = GRIDInterpolator(grid_path_strippedHe) + + # Create TrackMatcher object as needed, passing GRIDInterpolator references + matcher_kwargs['grid_Hrich'] = self.grids_Hrich[metallicity] + matcher_kwargs['grid_strippedHe'] = self.grids_strippedHe[metallicity] + self.track_matchers[(metallicity, step_name)] = TrackMatcher(**matcher_kwargs) def close(self): """Close hdf5 files before exiting.""" @@ -368,9 +607,11 @@ def close(self): for step_func in all_step_funcs: if isinstance(step_func, self._MesaGridStep): step_func.close() - elif isinstance(step_func, self._detached_step): - for grid_interpolator in [step_func.track_matcher.grid_Hrich, step_func.track_matcher.grid_strippedHe]: - grid_interpolator.close() + + for metallicity in self.grids_Hrich: + self.grids_Hrich[metallicity].close() + for metallicity in self.grids_strippedHe: + self.grids_strippedHe[metallicity].close() def pre_evolve(self, binary): """Functions called before a binary evolves. diff --git a/posydon/binary_evol/singlestar.py b/posydon/binary_evol/singlestar.py index 8b194c8574..9c84997a03 100644 --- a/posydon/binary_evol/singlestar.py +++ b/posydon/binary_evol/singlestar.py @@ -40,7 +40,7 @@ STARPROPERTIES = [ 'state', # the evolutionary state of the star. For more info see # `posydon.utils.common_functions.check_state_of_star` - 'metallicity', # initial mass fraction of metals + 'metallicity', # Z/Z_sun, ratio to solar metallicity (1.0 for solar) 'mass', # mass (solar units) 'log_R', # log10 of radius (solar units) 'log_L', # log10 luminosity (solar units) @@ -354,6 +354,10 @@ def __init__(self, **kwargs): self.M4 = None if not hasattr(self, 'mu4'): self.mu4 = None + if not hasattr(self, 'Xi'): + self.Xi = None + if not hasattr(self, 'sc'): + self.sc = None if not hasattr(self, 'interp1d'): self.interp1d = None diff --git a/posydon/binary_evol/step_end.py b/posydon/binary_evol/step_end.py index 0b5aef7cc2..fff9f0dd3e 100644 --- a/posydon/binary_evol/step_end.py +++ b/posydon/binary_evol/step_end.py @@ -9,6 +9,8 @@ class step_end: """Default end step.""" + DEFAULT_KWARGS = {} + def __call__(self, binary): """Change the event of the binary to 'end'.""" if binary.state == "disrupted": diff --git a/posydon/binary_evol/DT/track_match.py b/posydon/binary_evol/track_match.py similarity index 93% rename from posydon/binary_evol/DT/track_match.py rename to posydon/binary_evol/track_match.py index 4450163647..031cc7a090 100644 --- a/posydon/binary_evol/DT/track_match.py +++ b/posydon/binary_evol/track_match.py @@ -22,11 +22,6 @@ from scipy.optimize import minimize, root import posydon.utils.constants as const -from posydon.binary_evol.DT.key_library import ( - DEFAULT_PROFILE_KEYS, - DEFAULT_TRANSLATED_KEYS, - KEYS_POSITIVE, -) from posydon.binary_evol.flow_chart import ( STAR_STATES_CO, STAR_STATES_FOR_HMS_MATCHING, @@ -38,18 +33,22 @@ ) from posydon.config import PATH_TO_POSYDON_DATA from posydon.interpolation.data_scaling import DataScaler -from posydon.interpolation.interpolation import GRIDInterpolator from posydon.utils.common_functions import ( convert_metallicity_to_string, set_binary_to_failed, ) from posydon.utils.interpolators import SingleStarInterpolator +from posydon.utils.key_library import ( + DEFAULT_FINAL_KEYS, + DEFAULT_PROFILE_KEYS, + DEFAULT_TRANSLATED_KEYS, + KEYS_POSITIVE, +) from posydon.utils.posydonerror import MatchingError, NumericalError, POSYDONError from posydon.utils.posydonwarning import Pwarn MATCHING_WITH_RELATIVE_DIFFERENCE = ["center_he4"] - val_names = [" ", "mass", "log_R", "center_h1", "surface_h1", "he_core_mass", "center_he4", "surface_he4", "center_c12", "co_core_mass"] @@ -253,58 +252,52 @@ class TrackMatcher: """ - def __init__( - self, - grid_name_Hrich, - grid_name_strippedHe, - path=PATH_TO_POSYDON_DATA, - metallicity=None, - matching_method="minimize", - matching_tolerance=1e-2, - matching_tolerance_hard=1e-1, - list_for_matching_HMS=None, - list_for_matching_postMS=None, - list_for_matching_HeStar=None, - list_for_matching_postHeMS=None, - record_matching=False, - verbose=False - ): + DEFAULT_KWARGS = {"grid_Hrich":None, + "grid_strippedHe":None, + "path":PATH_TO_POSYDON_DATA, + "metallicity":None, + "matching_method":"minimize", + "matching_tolerance":1e-2, + "matching_tolerance_hard":1e-1, + "list_for_matching_HMS":None, + "list_for_matching_HeStar":None, + "list_for_matching_postMS":None, + "list_for_matching_postHeMS":None, + "record_matching":False, + "verbose":False} + + def __init__(self, **kwargs): # MESA history column names used as matching metrics # TODO: should this be singlestar.STARPROPERTIES? An # error is thrown when (possibly user defined) # matching metrics don't exist in this array. # That's not very flexible... - self.root_keys = np.array( - [ - "age", - "mass", - "he_core_mass", - "center_h1", - "center_he4", - "surface_he4", - "surface_h1", - "log_R", - "center_c12", - "co_core_mass" - ] - ) + self.root_keys = np.array(["age", "mass", "he_core_mass", + "co_core_mass", + "center_h1", "center_he4", + "surface_he4", "surface_h1", + "center_c12", "log_R"]) # ===================================================================== + if kwargs: + for key in kwargs: + if key not in self.DEFAULT_KWARGS: + raise POSYDONError(f"Unexpected keyword argument {key} " + "passed to TrackMatcher. Expected " + f"kwargs: {self.DEFAULT_KWARGS.keys()}") + for varname in self.DEFAULT_KWARGS: + default_value = self.DEFAULT_KWARGS[varname] + setattr(self, varname, kwargs.get(varname, default_value)) + else: + for varname in self.DEFAULT_KWARGS: + default_value = self.DEFAULT_KWARGS[varname] + setattr(self, varname, default_value) - self.metallicity = convert_metallicity_to_string(metallicity) - self.matching_method = matching_method - self.matching_tolerance = matching_tolerance # DEFAULT: 1e-2 - self.matching_tolerance_hard = matching_tolerance_hard # DEFAULT: 1e-1 + self.metallicity = convert_metallicity_to_string(self.metallicity) self.initial_mass = None self.rootm = None - self.verbose = verbose - - self.list_for_matching_HMS = list_for_matching_HMS - self.list_for_matching_postMS = list_for_matching_postMS - self.list_for_matching_HeStar = list_for_matching_HeStar - self.list_for_matching_postHeMS = list_for_matching_postHeMS # mapping a combination of (key, htrack, method) to a pre-trained # DataScaler instance, created the first time it is requested @@ -312,43 +305,15 @@ def __init__( # these are the KEYS read from POSYDON h5 grid files (after translating # them to the appropriate columns) - self.KEYS = DEFAULT_TRANSLATED_KEYS #KEYS #DEFAULT_TRANSLATED_KEYS + self.KEYS = DEFAULT_TRANSLATED_KEYS self.KEYS_POSITIVE = KEYS_POSITIVE - # keys for the final value interpolation - self.final_keys = ( - 'avg_c_in_c_core_at_He_depletion', - 'co_core_mass_at_He_depletion', - 'm_core_CE_1cent', - 'm_core_CE_10cent', - 'm_core_CE_30cent', - 'm_core_CE_pure_He_star_10cent', - 'r_core_CE_1cent', - 'r_core_CE_10cent', - 'r_core_CE_30cent', - 'r_core_CE_pure_He_star_10cent' - ) - + self.final_keys = DEFAULT_FINAL_KEYS # keys for the star profile interpolation self.profile_keys = DEFAULT_PROFILE_KEYS - - # should grids just get passed to this? - if grid_name_Hrich is None: - grid_name_Hrich = os.path.join('single_HMS', - self.metallicity+'_Zsun.h5') - grid_path_Hrich = os.path.join(path, grid_name_Hrich) - self.grid_Hrich = GRIDInterpolator(grid_path_Hrich) - - if grid_name_strippedHe is None: - grid_name_strippedHe = os.path.join('single_HeMS', - self.metallicity+'_Zsun.h5') - grid_path_strippedHe = os.path.join(path, grid_name_strippedHe) - self.grid_strippedHe = GRIDInterpolator(grid_path_strippedHe) - # ===================================================================== # Initialize the matching lists: - # min/max ranges of initial masses for each grid m_min_H = np.min(self.grid_Hrich.grid_mass) m_max_H = np.max(self.grid_Hrich.grid_mass) @@ -423,7 +388,125 @@ def __init__( [m_min_He, m_max_He], [t_min_He, t_max_He] ] - self.record_matching = record_matching + # create and train scalers + self.create_root0_h() + self.create_root0_he() + self.train_scalers() + + @classmethod + def separate_kwargs(cls, step_kwargs): + + matcher_kwargs = cls.DEFAULT_KWARGS.copy() + for key, val in step_kwargs.items(): + if key in matcher_kwargs: + matcher_kwargs.update({key: val}) + # peel off TrackMatcher kwargs from step_kwargs + except_keys = ["metallicity", "verbose"] + for key in matcher_kwargs: + if key in except_keys: + continue + _ = step_kwargs.pop(key, None) + + return step_kwargs, matcher_kwargs + + def train_scalers(self): + + # ...if not, fit a new scaler, and store it for later use + + lists_for_matching = [self.list_for_matching_HMS, + self.list_for_matching_HeStar, + self.list_for_matching_HMS_alternative, + self.list_for_matching_HeStar_alternative, + self.list_for_matching_postHeMS, + self.list_for_matching_postHeMS_alternative, + self.list_for_matching_postMS, + self.list_for_matching_postMS_alternative] + + for list_for_matching in lists_for_matching: + + match_attr_names = list_for_matching[0] + rescale_facs = list_for_matching[1] + scaler_methods = list_for_matching[2] + bnds = list_for_matching[3:] + + if self.verbose: + print("Matching parameters and their normalizations:\n", + match_attr_names, rescale_facs) + for htrack in [True, False]: + grid = self.grid_Hrich if htrack else self.grid_strippedHe + self.initial_mass = grid.grid_mass + + # get (or train and get) scalers for attributes + # attributes are scaled to range (0, 1) + for attr_name, method in zip(match_attr_names, scaler_methods): + all_attributes = [] + # check that attributes are allowed as matching attributes + if attr_name not in self.root_keys: + raise AttributeError("Expected matching attribute " + f"{attr_name} not " + "added in root_keys list: " + f"{self.root_keys}") + + scaler_options = (attr_name, htrack, method) + + for mass in self.initial_mass: + for i in grid.get(attr_name, mass): + all_attributes.append(i) + + all_attributes = np.array(all_attributes) + scaler = DataScaler() + scaler.fit(all_attributes, method=method, lower=0.0, upper=1.0) + self.stored_scalers[scaler_options] = scaler + + def create_root0_h(self): + + # set which grid to search based on htrack condition + grid = self.grid_Hrich + + # initial masses within grid (defined but never used? used in scale()) + self.initial_mass = grid.grid_mass + + # search across all initial masses and get max track length + max_track_length = 0 + for mass in grid.grid_mass: + track_length = len(grid.get("age", mass)) + max_track_length = max(max_track_length, track_length) + + # intialize root matrix + # (DIM = [N(Mi), N(max_track_length), N(root_keys)]) + self.rootm_h = np.inf * np.ones((len(grid.grid_mass), + max_track_length, len(self.root_keys))) + + # for each mass, get matching metrics and store in matrix + for i, mass in enumerate(grid.grid_mass): + for j, key in enumerate(self.root_keys): + track = grid.get(key, mass) + self.rootm_h[i, : len(track), j] = track + + def create_root0_he(self): + + # set which grid to search based on htrack condition + grid = self.grid_strippedHe + + # initial masses within grid (defined but never used? used in scale()) + self.initial_mass = grid.grid_mass + + # search across all initial masses and get max track length + max_track_length = 0 + for mass in grid.grid_mass: + track_length = len(grid.get("age", mass)) + max_track_length = max(max_track_length, track_length) + + # intialize root matrix + # (DIM = [N(Mi), N(max_track_length), N(root_keys)]) + self.rootm_he = np.inf * np.ones((len(grid.grid_mass), + max_track_length, len(self.root_keys))) + + # for each mass, get matching metrics and store in matrix + for i, mass in enumerate(grid.grid_mass): + for j, key in enumerate(self.root_keys): + track = grid.get(key, mass) + self.rootm_he[i, : len(track), j] = track def get_root0(self, attr_names, attr_vals, htrack, rescale_facs=None): """ @@ -463,29 +546,10 @@ def get_root0(self, attr_names, attr_vals, htrack, rescale_facs=None): """ + rootm = self.rootm_h if htrack else self.rootm_he # set which grid to search based on htrack condition grid = self.grid_Hrich if htrack else self.grid_strippedHe - # initial masses within grid (defined but never used? used in scale()) - self.initial_mass = grid.grid_mass - - # search across all initial masses and get max track length - max_track_length = 0 - for mass in grid.grid_mass: - track_length = len(grid.get("age", mass)) - max_track_length = max(max_track_length, track_length) - - # intialize root matrix - # (DIM = [N(Mi), N(max_track_length), N(root_keys)]) - self.rootm = np.inf * np.ones((len(grid.grid_mass), - max_track_length, len(self.root_keys))) - - # for each mass, get matching metrics and store in matrix - for i, mass in enumerate(grid.grid_mass): - for j, key in enumerate(self.root_keys): - track = grid.get(key, mass) - self.rootm[i, : len(track), j] = track - # rescaling factors if rescale_facs is None: rescale_facs = np.ones_like(attr_names) @@ -501,7 +565,7 @@ def get_root0(self, attr_names, attr_vals, htrack, rescale_facs=None): # Slice out just the matching metric data for all stellar tracks # grid_attr_vals now has shape # (N(Mi), N(max_track_len), N(matching_metrics)) - grid_attr_vals = self.rootm[:, :, idx] + grid_attr_vals = rootm[:, :, idx] # For all stellar tracks in grid: # Take difference btwn. grid track and given star values... @@ -522,7 +586,7 @@ def get_root0(self, attr_names, attr_vals, htrack, rescale_facs=None): # time and initial mass corresp. to track w/ minimum difference m0 = grid.grid_mass[mass_i] - t0 = self.rootm[mass_i][age_i][np.argmax("age" == self.root_keys)] + t0 = rootm[mass_i][age_i][np.argmax("age" == self.root_keys)] return m0, t0 @@ -608,22 +672,6 @@ def scale(self, attr_name, htrack, scaler_method): # find if the scaler has already been fitted and return it if so... scaler = self.stored_scalers.get(scaler_options, None) - if scaler is not None: - return scaler - - # ...if not, fit a new scaler, and store it for later use - grid = self.grid_Hrich if htrack else self.grid_strippedHe - self.initial_mass = grid.grid_mass - all_attributes = [] - - for mass in self.initial_mass: - for i in grid.get(attr_name, mass): - all_attributes.append(i) - - all_attributes = np.array(all_attributes) - scaler = DataScaler() - scaler.fit(all_attributes, method=scaler_method, lower=0.0, upper=1.0) - self.stored_scalers[scaler_options] = scaler return scaler diff --git a/posydon/grids/SN_MODELS.py b/posydon/grids/SN_MODELS.py index 647c494a09..65f32ec79f 100644 --- a/posydon/grids/SN_MODELS.py +++ b/posydon/grids/SN_MODELS.py @@ -447,6 +447,76 @@ # "use_profiles": True, "use_core_masses": False, # "allow_spin_None" : False, +# "approx_at_he_depletion": False, + }, + "SN_MODEL_v2_25": { + "mechanism": "Maltsev+25-engine", + "engine": "M16", +# "PISN": "Hendriks+23", +# "PISN_CO_shift": 0.0, +# "PPI_extra_mass_loss": -20.0, +# "ECSN": "Tauris+15", +# "conserve_hydrogen_envelope" : False, +# "conserve_hydrogen_PPI" : False, +# "max_neutrino_mass_loss": NEUTRINO_MASS_LOSS_UPPER_LIMIT, +# "max_NS_mass": STATE_NS_STARMASS_UPPER_LIMIT, + "use_interp_values": False, +# "use_profiles": True, + "use_core_masses": False, +# "allow_spin_None" : False, +# "approx_at_he_depletion": False, + }, + "SN_MODEL_v2_26": { + "mechanism": "Maltsev+25-engine", + "engine": "M16", +# "mechanism": "Fryer+12-delayed", +# "engine": "", +# "PISN": "Hendriks+23", +# "PISN_CO_shift": 0.0, +# "PPI_extra_mass_loss": -20.0, +# "ECSN": "Tauris+15", + "conserve_hydrogen_envelope" : True, +# "conserve_hydrogen_PPI" : False, +# "max_neutrino_mass_loss": NEUTRINO_MASS_LOSS_UPPER_LIMIT, +# "max_NS_mass": STATE_NS_STARMASS_UPPER_LIMIT, + "use_interp_values": False, +# "use_profiles": True, + "use_core_masses": False, +# "allow_spin_None" : False, +# "approx_at_he_depletion": False, + }, + "SN_MODEL_v2_27": { + "mechanism": "Maltsev+25-engine", + "engine": "M16", +# "PISN": "Hendriks+23", +# "PISN_CO_shift": 0.0, + "PPI_extra_mass_loss": 0.0, +# "ECSN": "Tauris+15", +# "conserve_hydrogen_envelope" : False, +# "conserve_hydrogen_PPI" : False, +# "max_neutrino_mass_loss": NEUTRINO_MASS_LOSS_UPPER_LIMIT, +# "max_NS_mass": STATE_NS_STARMASS_UPPER_LIMIT, + "use_interp_values": False, +# "use_profiles": True, + "use_core_masses": False, +# "allow_spin_None" : False, +# "approx_at_he_depletion": False, + }, + "SN_MODEL_v2_28": { + "mechanism": "Maltsev+25-engine", + "engine": "M16", +# "PISN": "Hendriks+23", +# "PISN_CO_shift": 0.0, + "PPI_extra_mass_loss": 0.0, +# "ECSN": "Tauris+15", + "conserve_hydrogen_envelope" : True, +# "conserve_hydrogen_PPI" : False, +# "max_neutrino_mass_loss": NEUTRINO_MASS_LOSS_UPPER_LIMIT, +# "max_NS_mass": STATE_NS_STARMASS_UPPER_LIMIT, + "use_interp_values": False, +# "use_profiles": True, + "use_core_masses": False, +# "allow_spin_None" : False, # "approx_at_he_depletion": False, }, } diff --git a/posydon/grids/lazy_hdf.py b/posydon/grids/lazy_hdf.py new file mode 100644 index 0000000000..28600b29b1 --- /dev/null +++ b/posydon/grids/lazy_hdf.py @@ -0,0 +1,92 @@ +__authors__ = [ + "Seth Gossage " +] + +import numpy as np +import pandas as pd + + +class LazyHDF5: + """ + Lazy wrapper around an HDF5 dataset with optional dtype conversion. + + This class provides a lightweight interface for accessing data from an + HDF5 dataset without immediately loading the entire dataset into memory. + Data are retrieved lazily when indexed. Optionally, a set of dtype + conversions can be applied when data are accessed. + + If dtype mappings are provided, retrieved data are cast to the specified + dtypes either per-field (for structured arrays) or for the selected field + when accessed by name. + + Assignments (via __setitem__) trigger full materialization of the dataset + in memory, after which the internal storage is replaced by the in-memory + array. + + Parameters + ---------- + dataset : h5py.Dataset or array-like + The underlying dataset providing the data. Typically an HDF5 dataset + object supporting NumPy-style indexing. + dtype_set : dict, optional + Mapping of field names to NumPy dtypes used to cast the returned data. + This is typically used for structured arrays where individual fields + require specific dtype conversions. + + Notes + ----- + - Data are only read from the dataset when accessed via ``__getitem__`` or + when converted to a NumPy array. + - Writing via ``__setitem__`` loads the entire dataset into memory before + modifying it. + - The ``dtype`` property reflects the converted dtype if ``dtype_set`` is + provided. + """ + def __init__(self, dataset, dtype_set=None): + self._dataset = dataset + self._dtype_set = dtype_set + if self._dtype_set is not None: + self._dtype_list = list(self._dtype_set.items()) + + def __getitem__(self, idx): + data = self._dataset[idx] + if self._dtype_set is not None: + if isinstance(idx, str): + data = data.astype(self._dtype_set[idx]) + else: + data = data.astype(self._dtype_list) + return data + + def __setitem__(self, idx, value): + # materialize full array in memory + arr = self.__array__() + # write new value + arr[idx] = value + + self._dataset = arr + + + def __array__(self): + data = self._dataset[()] + if self._dtype_set is not None: + data = data.astype(self._dtype_list) + return data + + def astype(self, dtype): # pragma: no cover + return LazyHDF5(np.asarray(self).astype(dtype)) + + @property + def dtype(self): + if self._dtype_set is not None: + return np.dtype(self._dtype_list) + return self._dataset.dtype + + @property + def shape(self): # pragma: no cover + return self._dataset.shape + + def __len__(self): # pragma: no cover + return len(self._dataset) + + def to_df(self): # pragma: no cover + return pd.DataFrame(self.__array__()) diff --git a/posydon/grids/psygrid.py b/posydon/grids/psygrid.py index 8e57a26d1a..5c725004cf 100644 --- a/posydon/grids/psygrid.py +++ b/posydon/grids/psygrid.py @@ -176,6 +176,7 @@ "Devina Misra ", "Kyle Akira Rocha ", "Matthias Kruckow ", + "Seth Gossage '] import numpy as np diff --git a/posydon/popsyn/IMFs.py b/posydon/popsyn/IMFs.py index d144470a75..bfaace0213 100644 --- a/posydon/popsyn/IMFs.py +++ b/posydon/popsyn/IMFs.py @@ -9,6 +9,7 @@ import numpy as np from scipy.integrate import quad +from posydon.utils.common_functions import inverse_sampler from posydon.utils.posydonwarning import Pwarn @@ -101,6 +102,11 @@ def imf(self, m): # pragma: no cover ''' pass + @abstractmethod + def rvs(self, size=1, rng=None): # pragma: no cover + pass + + class Salpeter(IMFBase): """ Initial Mass Function based on Salpeter (1955), which is defined as: @@ -166,6 +172,137 @@ def imf(self, m): valid = self._check_valid(m) return m ** (-self.alpha) + def rvs(self, size=1, rng=None): + """Draw random samples from the Salpeter IMF. + + Uses analytical inverse transform sampling for efficiency. + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + rng : numpy.random.Generator, optional + Random number generator. If None, uses np.random.default_rng(). + + Returns + ------- + ndarray + Random mass samples in solar masses. + """ + if rng is None: + rng = np.random.default_rng() + + # Analytical inverse transform sampling + # For power law IMF: m^(-alpha), the inverse CDF is: + # m = (u * (m_max^(1-alpha) - m_min^(1-alpha)) + m_min^(1-alpha))^(1/(1-alpha)) + normalization_constant = (1.0 - self.alpha) / ( + self.m_max**(1.0 - self.alpha) - self.m_min**(1.0 - self.alpha) + ) + u = rng.uniform(size=size) + masses = (u * (1.0 - self.alpha) / normalization_constant + + self.m_min**(1.0 - self.alpha))**(1.0 / (1.0 - self.alpha)) + + return masses + + +class Kroupa1993(IMFBase): + """ + Initial Mass Function based on Kroupa et al. (1993), which is defined as: + + dN/dM = m^-2.7 + + References + ---------- + Kroupa P., Tout C. A., Gilmore G., 1993, MNRAS, 262, 545 + https://ui.adsabs.harvard.edu/abs/1993MNRAS.262..545K/abstract + + Parameters + ---------- + alpha : float, optional + The power-law index of the IMF (default is 2.7). + m_min : float, optional + The minimum allowable mass (default is 0.01) [Msun]. + m_max : float, optional + The maximum allowable mass (default is 200.0) [Msun]. + + Attributes + ---------- + alpha : float + Power-law index used in the IMF calculation. + m_min : float + Minimum stellar mass for the IMF [Msun]. + m_max : float + Maximum stellar mass for the IMF [Msun]. + """ + + def __init__(self, alpha=2.7, m_min=0.01, m_max=200.0): + self.alpha = alpha + super().__init__(m_min, m_max) + + def __repr__(self): + return (f"Kroupa1993(alpha={self.alpha}, " + f"m_min={self.m_min}, " + f"m_max={self.m_max})") + + def _repr_html_(self): + return (f"

Kroupa (1993) IMF

" + f"

alpha = {self.alpha}

" + f"

m_min = {self.m_min}

" + f"

m_max = {self.m_max}

") + + def imf(self, m): + '''Computes the IMF value for a given mass or array of masses 'm'. + Raises a ValueError if any value in 'm' is less than or equal to zero. + + Parameters + ---------- + m : float or array_like + Stellar mass or array of stellar masses [Msun]. + + Returns + ------- + float or ndarray + The IMF value for the given mass or masses. + ''' + m = np.asarray(m) + if np.any(m <= 0): + raise ValueError("Mass must be positive.") + valid = self._check_valid(m) + return m ** (-self.alpha) + + def rvs(self, size=1, rng=None): + """Draw random samples from the Kroupa1993 IMF. + + Uses analytical inverse transform sampling for efficiency. + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + rng : numpy.random.Generator, optional + Random number generator. If None, uses np.random.default_rng(). + + Returns + ------- + ndarray + Random mass samples in solar masses. + """ + if rng is None: + rng = np.random.default_rng() + + # Analytical inverse transform sampling + # For power law IMF: m^(-alpha), the inverse CDF is: + # m = (u * (m_max^(1-alpha) - m_min^(1-alpha)) + m_min^(1-alpha))^(1/(1-alpha)) + normalization_constant = (1.0 - self.alpha) / ( + self.m_max**(1.0 - self.alpha) - self.m_min**(1.0 - self.alpha) + ) + u = rng.uniform(size=size) + masses = (u * (1.0 - self.alpha) / normalization_constant + + self.m_min**(1.0 - self.alpha))**(1.0 / (1.0 - self.alpha)) + + return masses + + class Kroupa2001(IMFBase): """Initial Mass Function based on Kroupa (2001), which is defined as a broken power-law: @@ -275,6 +412,39 @@ def imf(self, m): out[mask3] = const2 * (m[mask3] / self.m2break) ** (-self.alpha3) return out + def rvs(self, size=1, rng=None): + """Draw random samples from the Kroupa2001 IMF. + + Uses inverse transform sampling with discretized PDF for the + broken power-law distribution. + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + rng : numpy.random.Generator, optional + Random number generator. If None, uses np.random.default_rng(). + + Returns + ------- + ndarray + Random mass samples in solar masses. + """ + if rng is None: + rng = np.random.default_rng() + + # Create discretized PDF for inverse sampling + # Use more points near the breaks for better accuracy + n_points = 2000 + m_grid = np.linspace(self.m_min, self.m_max, n_points) + pdf_values = self.imf(m_grid) + + # Sample using inverse transform method + masses = inverse_sampler(m_grid, pdf_values, size=size, rng=rng) + + return masses + + class Chabrier2003(IMFBase): """Chabrier2003 Initial Mass Function (IMF), which is defined as a lognormal distribution for low-mass stars and a power-law distribution @@ -362,3 +532,35 @@ def imf(self, m): C = (1.0 / (self.m_break * sqrt_2pi_sigma)) * np.exp(-log_term_break) powerlaw = C * (m / self.m_break) ** (-self.alpha) return np.where(m < self.m_break, lognormal, powerlaw) + + def rvs(self, size=1, rng=None): + """Draw random samples from the Chabrier2003 IMF. + + Uses inverse transform sampling with discretized PDF for the + combined lognormal and power-law distribution. + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + rng : numpy.random.Generator, optional + Random number generator. If None, uses np.random.default_rng(). + + Returns + ------- + ndarray + Random mass samples in solar masses. + """ + if rng is None: + rng = np.random.default_rng() + + # Create discretized PDF for inverse sampling + # Use more points near the transition for better accuracy + n_points = 2000 + m_grid = np.linspace(self.m_min, self.m_max, n_points) + pdf_values = self.imf(m_grid) + + # Sample using inverse transform method + masses = inverse_sampler(m_grid, pdf_values, size=size, rng=rng) + + return masses diff --git a/posydon/popsyn/Moes_distributions.py b/posydon/popsyn/Moes_distributions.py index 43be3a9dd7..df82196a78 100644 --- a/posydon/popsyn/Moes_distributions.py +++ b/posydon/popsyn/Moes_distributions.py @@ -367,15 +367,15 @@ def __init__(self, n_M1=101, n_logP=158, n_q=91, n_e=200, # save to grid self.cumPbindist[:,i] = mycumPbindist - def __repr__(self): + def __repr__(self): # pragma: no cover return ("Moe and Di Stefano 2017 distributions on a grid of " - f"n_M1={self.n_M1}, n_logP={self.n_logP}, n_q={self.n_q}, and " - f"n_e={self.n_e}") + f"n_M1={self.numM1}, n_logP={self.numlogP}, n_q={self.numq}, and " + f"n_e={self.nume}") - def _repr_html_(self): + def _repr_html_(self): # pragma: no cover return ("

Moe and Di Stefano 2017 distributions on a grid of

" - f"

n_M1={self.n_M1}

n_logP={self.n_logP}

" - f"

n_q={self.n_q}

n_e={self.n_e}

") + f"

n_M1={self.numM1}

n_logP={self.numlogP}

" + f"

n_q={self.numq}

n_e={self.nume}

") def __call__(self, M1, M_min=0.08, M_max=150.0, all_binaries=True): """Initializing the class. diff --git a/posydon/popsyn/analysis.py b/posydon/popsyn/analysis.py index 2f7717b593..e2f373c6d4 100644 --- a/posydon/popsyn/analysis.py +++ b/posydon/popsyn/analysis.py @@ -1,5 +1,7 @@ """Module for analyzing binary population simulation results.""" +# This code is currently not being actively used, and is not covered by unit tests. +# To test this file, remove it from the 'omit' blocks in POSYDON/setup.cfg. __authors__ = [ "Konstantinos Kovlakas ", diff --git a/posydon/popsyn/binarypopulation.py b/posydon/popsyn/binarypopulation.py index eb0e0601f2..e3a460068b 100644 --- a/posydon/popsyn/binarypopulation.py +++ b/posydon/popsyn/binarypopulation.py @@ -38,6 +38,7 @@ import psutil from tqdm import tqdm +import posydon from posydon.binary_evol.binarystar import BinaryStar from posydon.binary_evol.simulationproperties import SimulationProperties from posydon.binary_evol.singlestar import SingleStar, properties_massless_remnant @@ -77,7 +78,8 @@ 'orbital_separation_scheme', 'orbital_separation_min', 'orbital_separation_max', - 'eccentricity_scheme'] + 'eccentricity_scheme', + 'posydon_version'] HISTORY_MIN_ITEMSIZE = {'state': 30, 'event': 25, 'step_names': 21, @@ -95,8 +97,8 @@ 'interp_class_CO_HMS_RLO' : 15, 'interp_class_CO_HeMS_RLO' : 15, 'mt_history_HMS_HMS' : 40, 'mt_history_CO_HeMS' : 40, 'mt_history_CO_HMS_RLO' : 40, 'mt_history_CO_HeMS_RLO' : 40, - 'culmulative_mt_case_HMS_HMS': 40, 'culmulative_mt_case_CO_HeMS': 40, - 'culmulative_mt_case_CO_HMS_RLO': 40, 'culmulative_mt_case_CO_HeMS_RLO': 40, + 'cumulative_mt_case_HMS_HMS': 40, 'cumulative_mt_case_CO_HeMS': 40, + 'cumulative_mt_case_CO_HMS_RLO': 40, 'cumulative_mt_case_CO_HeMS_RLO': 40, } # BinaryPopulation will enforce a constant metallicity accross all steps that @@ -234,7 +236,8 @@ def from_ini(cls, path, metallicity_index=0, verbose=False): return cls(**pop_kwargs) - def evolve(self, **kwargs): + def evolve(self, **kwargs): # pragma: no cover + # wrapper for _safe_evolve """Evolve a binary population. Parameters @@ -288,7 +291,8 @@ def evolve(self, **kwargs): self.kwargs.update(params) self._safe_evolve(**self.kwargs) - def _safe_evolve(self, **kwargs): + def _safe_evolve(self, **kwargs): # pragma: no cover + # needs more complex test than unit test """Evolve binaries in a population, catching warnings/exceptions.""" if not self.population_properties.steps_loaded: # Enforce the same metallicity for all grid steps @@ -302,7 +306,7 @@ def _safe_evolve(self, **kwargs): modified_tup = (step_function, step_kwargs) self.population_properties.kwargs[step_name] = modified_tup - self.population_properties.load_steps() + self.population_properties.load_steps(RNG=self.RNG) indices = kwargs.get('indices', list(range(self.number_of_binaries))) @@ -484,7 +488,8 @@ def _safe_evolve(self, **kwargs): f"evolution.combined.{self.rank}.h5"), mode='w', **kwargs) - def save(self, save_path, **kwargs): + def save(self, save_path, **kwargs): # pragma: no cover + # dependent on full evolution """Save BinaryPopulation to hdf file.""" optimize_ram = self.kwargs['optimize_ram'] temp_directory = self.kwargs['temp_directory'] @@ -512,13 +517,13 @@ def save(self, save_path, **kwargs): self.combine_saved_files(absolute_filepath, tmp_files, **kwargs) - def make_temp_fname(self): + def make_temp_fname(self): # pragma: no cover """Get a valid filename for the temporary file.""" temp_directory = self.kwargs['temp_directory'] return os.path.join(temp_directory, f"evolution.combined.{self.rank}.h5") # return os.path.join(dir_name, '.tmp{}_'.format(rank) + file_name) - def combine_saved_files(self, absolute_filepath, file_names, **kwargs): + def combine_saved_files(self, absolute_filepath, file_names, **kwargs): # pragma: no cover """Combine various temporary files in a given folder. Parameters @@ -587,7 +592,10 @@ def combine_saved_files(self, absolute_filepath, file_names, **kwargs): # store population metadata tmp_df = pd.DataFrame() for c in saved_ini_parameters: - tmp_df[c] = [self.kwargs[c]] + if c == 'posydon_version': + tmp_df[c] = [posydon.__version__] + else: + tmp_df[c] = [self.kwargs[c]] store.append('ini_parameters', tmp_df) tmp_df = pd.DataFrame( @@ -618,19 +626,19 @@ def __getstate__(self): prop.close() return d - def __iter__(self): + def __iter__(self): # pragma: no cover """Iterate the binaries.""" return iter(self.manager) - def __getitem__(self, key): + def __getitem__(self, key): # pragma: no cover """Get the k-th binary.""" return self.manager[key] - def __len__(self): + def __len__(self): # pragma: no cover """Get the number of binaries in the population.""" return len(self.manager) - def __repr__(self): + def __repr__(self): # pragma: no cover """Report key properties of the object.""" s = "<{}.{} at {}>\n".format( self.__class__.__module__, self.__class__.__name__, hex(id(self)) @@ -733,7 +741,7 @@ def to_df(self, selection_function=None, **kwargs): and selection_function(binary)): holder.append(binary.to_df(**kwargs)) - elif len(self.history_dfs) > 0: + elif len(self.history_dfs) > 0: # pragma: no branch holder.extend(self.history_dfs) if len(holder) > 0: @@ -754,7 +762,7 @@ def to_oneline_df(self, selection_function=None, **kwargs): and selection_function(binary)): holder.append(binary.to_oneline_df(**kwargs)) - elif len(self.oneline_dfs) > 0: + elif len(self.oneline_dfs) > 0: # pragma: no branch holder.extend(self.oneline_dfs) if len(holder) > 0: @@ -781,7 +789,7 @@ def generate(self, **kwargs): self.append(binary) return binary - def from_hdf(self, indices=None, where=None, restore=False): + def from_hdf(self, indices=None, where=None, restore=False): # pragma: no cover """Load a BinaryStar instance from an hdf file of a saved population. Parameters @@ -851,7 +859,7 @@ def from_hdf(self, indices=None, where=None, restore=False): return binary_holder - def save(self, fname, **kwargs): + def save(self, fname, **kwargs): # pragma: no cover """Save binaries to an hdf file using pandas HDFStore. Any object dtype columns not parsed by infer_objects() is converted to @@ -938,7 +946,10 @@ def save(self, fname, **kwargs): # store population metadata tmp_df = pd.DataFrame() for c in saved_ini_parameters: - tmp_df[c] = [self.kwargs[c]] + if c == 'posydon_version': + tmp_df[c] = [posydon.__version__] + else: + tmp_df[c] = [self.kwargs[c]] store.append('ini_parameters', tmp_df) tmp_df = pd.DataFrame( @@ -953,19 +964,19 @@ def save(self, fname, **kwargs): return - def __getitem__(self, key): + def __getitem__(self, key): # pragma: no cover """Return the key-th binary.""" return self.binaries[key] - def __iter__(self): + def __iter__(self): # pragma: no cover """Iterate the binaries in the population.""" return iter(self.binaries) - def __len__(self): + def __len__(self): # pragma: no cover """Return the number of binaries in the population.""" return len(self.binaries) - def __bool__(self): + def __bool__(self): # pragma: no cover """Evaluate as True if binaries have been appended.""" return len(self) > 0 @@ -1021,7 +1032,7 @@ def draw_initial_samples(self, orbital_scheme='separation', **kwargs): if orbital_scheme == 'separation': separation, eccentricity, m1, m2 = sampler_output orbital_period = orbital_period_from_separation(separation, m1, m2) - elif orbital_scheme == 'period': + elif orbital_scheme == 'period': # pragma: no branch orbital_period, eccentricity, m1, m2 = sampler_output separation = orbital_separation_from_period(orbital_period, m1, m2) else: @@ -1168,7 +1179,7 @@ def draw_initial_binary(self, **kwargs): star_2=SingleStar(**star2_params)) return binary - def __repr__(self,): + def __repr__(self,): # pragma: no cover """Report key properties of the BinaryGenerator instance.""" s = "<{}.{} at {}>\n".format( self.__class__.__module__, self.__class__.__name__, hex(id(self)) diff --git a/posydon/popsyn/distributions.py b/posydon/popsyn/distributions.py index 36247403ab..67a9f6481b 100644 --- a/posydon/popsyn/distributions.py +++ b/posydon/popsyn/distributions.py @@ -5,6 +5,7 @@ ] import numpy as np from scipy.integrate import quad +from scipy.stats import truncnorm class FlatMassRatio: @@ -12,12 +13,12 @@ class FlatMassRatio: A uniform distribution for mass ratios q = m2/m1 within specified bounds. This distribution assigns equal probability to all mass ratios within the - given range [q_min, q_max]. + given range (q_min, q_max], exclusive bottom, inclusive top. Parameters ---------- q_min : float, optional - Minimum mass ratio (default: 0.05). Must be in (0, 1]. + Minimum mass ratio (default: 0.05). Must be in [0, 1). q_max : float, optional Maximum mass ratio (default: 1.0). Must be in (0, 1]. @@ -38,19 +39,19 @@ def __init__(self, q_min=0.05, q_max=1): Parameters ---------- q_min : float, optional - Minimum mass ratio (default: 0.05). Must be in (0, 1]. + Minimum mass ratio (default: 0.05). Must be in [0, 1). q_max : float, optional Maximum mass ratio (default: 1.0). Must be in (0, 1]. Raises ------ ValueError - If q_min or q_max are not in (0, 1], or if q_min >= q_max. + If q_min or q_max are not in valid range, or if q_min >= q_max. """ - if not (0 < q_min <= 1): - raise ValueError("q_min must be in (0, 1)") + if not (0 <= q_min < 1): + raise ValueError("q_min must be in [0, 1)") if not (0 < q_max <= 1): - raise ValueError("q_max must be in (0, 1)") + raise ValueError("q_max must be in (0, 1]") if q_min >= q_max: raise ValueError("q_min must be less than q_max") @@ -124,11 +125,174 @@ def pdf(self, q): Probability density at mass ratio q. """ q = np.asarray(q) - valid = (q >= self.q_min) & (q <= self.q_max) + valid = (q > self.q_min) & (q <= self.q_max) pdf_values = np.zeros_like(q, dtype=float) pdf_values[valid] = self.flat_mass_ratio(q[valid]) * self.norm return pdf_values + def rvs(self, size=1, rng=None): + """Draw random samples from the flat mass ratio distribution. + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + rng : numpy.random.Generator, optional + Random number generator. If None, uses np.random.default_rng(). + + Returns + ------- + float or ndarray + Random samples from the distribution. + """ + if rng is None: + rng = np.random.default_rng() + + return rng.uniform(self.q_min, self.q_max, size=size) + + +class PowerLawMassRatio: + """Power law mass ratio distribution for binary star systems. + + A distribution where the PDF follows q^alpha within specified bounds + (q_min, q_max], exclusive bottom, inclusive top. + + Parameters + ---------- + alpha : float, optional + Power law exponent (default: 0.0, i.e. flat). Can be any real number. + q_min : float, optional + Minimum mass ratio (default: 0.05). Must be in [0, 1). + q_max : float, optional + Maximum mass ratio (default: 1.0). Must be in (0, 1]. + + Raises + ------ + ValueError + If q_min or q_max are not in valid range, or if q_min >= q_max. + + Examples + -------- + >>> dist = PowerLawMassRatio(alpha=-1.0, q_min=0.1, q_max=1.0) + >>> pdf_value = dist.pdf(0.5) + """ + + def __init__(self, alpha=0.0, q_min=0.05, q_max=1.0): + """Initialize the power law mass ratio distribution. + + Parameters + ---------- + alpha : float, optional + Power law exponent (default: 0.0). + q_min : float, optional + Minimum mass ratio (default: 0.05). Must be in [0, 1). + q_max : float, optional + Maximum mass ratio (default: 1.0). Must be in (0, 1]. + + Raises + ------ + ValueError + If q_min or q_max are not in valid range, or if q_min >= q_max. + """ + if not (0 <= q_min < 1): + raise ValueError("q_min must be in [0, 1)") + if not (0 < q_max <= 1): + raise ValueError("q_max must be in (0, 1]") + if q_min >= q_max: + raise ValueError("q_min must be less than q_max") + if alpha <= -1 and q_min == 0: + raise ValueError("q_min must be > 0 for alpha <= -1 " + "to avoid divergent integral") + + self.alpha = alpha + self.q_min = q_min + self.q_max = q_max + self.norm = self._calculate_normalization() + + def __repr__(self): + """Return string representation of the distribution.""" + return (f"PowerLawMassRatio(alpha={self.alpha}, " + f"q_min={self.q_min}, q_max={self.q_max})") + + def _repr_html_(self): + """Return HTML representation for Jupyter notebooks.""" + return (f"

Power Law Mass Ratio Distribution

" + f"

alpha = {self.alpha}

" + f"

q_min = {self.q_min}

" + f"

q_max = {self.q_max}

") + + def _calculate_normalization(self): + """Calculate the normalization constant for the power law mass ratio + distribution. + + Returns + ------- + float + The normalization constant ensuring the PDF integrates to 1. + """ + integral, _ = quad(self.power_law_mass_ratio, self.q_min, self.q_max) + if integral == 0: # pragma: no cover + raise ValueError("Normalization integral is zero. " + "Check mass ratio parameters.") + return 1.0 / integral + + def power_law_mass_ratio(self, q): + """Compute the power law mass ratio distribution value. + + Parameters + ---------- + q : float or array_like + Mass ratio. + + Returns + ------- + float or ndarray + Distribution value q^alpha. + """ + return np.asarray(q, dtype=float)**self.alpha + + def pdf(self, q): + """Probability density function of the power law mass ratio distribution. + + Parameters + ---------- + q : float or array_like + Mass ratio(s). + + Returns + ------- + float or ndarray + Probability density at mass ratio q. + """ + q = np.asarray(q) + valid = (q > self.q_min) & (q <= self.q_max) + pdf_values = np.zeros_like(q, dtype=float) + pdf_values[valid] = self.power_law_mass_ratio(q[valid]) * self.norm + return pdf_values + + def rvs(self, size=1, n_points=1000, rng=None): + """Draw random samples from the power law mass ratio distribution. + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + rng : numpy.random.Generator, optional + Random number generator. If None, uses np.random.default_rng(). + + Returns + ------- + ndarray + Random samples from the distribution. + """ + if rng is None: + rng = np.random.default_rng() + + from posydon.utils.common_functions import inverse_sampler + q_grid = np.linspace(self.q_min, self.q_max, n_points) + pdf_values = self.power_law_mass_ratio(q_grid) + return inverse_sampler(q_grid, pdf_values, size=size, rng=rng) + class Sana12Period(): """Period distribution from Sana et al. (2012). @@ -310,6 +474,73 @@ def pdf(self, p, m1): * self.norm(m1[valid])) return pdf_values + def rvs(self, size=1, m1=None, rng=None): + """Draw random samples from the Sana12 period distribution. + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + m1 : float or array_like, optional + Primary mass(es). If array, must have length equal to size. + rng : numpy.random.Generator, optional + Random number generator. If None, uses np.random.default_rng(). + + Returns + ------- + ndarray + Random period samples in days. + + Raises + ------ + ValueError + If m1 is None or has incorrect size. + """ + if rng is None: + rng = np.random.default_rng() + + if m1 is None: + raise ValueError("m1 (primary mass) must be provided for Sana12Period sampling") + + m1 = np.atleast_1d(m1) + if m1.size == 1: + m1 = np.full(size, m1[0]) + elif m1.size != size: + raise ValueError(f"m1 must be a single value or have size={size}" + f"\n m1 = {m1}\n m1.size = {m1.size}") + + # Import here to avoid circular dependency + from posydon.utils.common_functions import rejection_sampler + + periods = np.zeros(size) + + # For low mass stars (m1 <= 15), use log-uniform distribution + low_mass_mask = m1 <= self.mbreak + n_low = np.sum(low_mass_mask) + if n_low > 0: + periods[low_mass_mask] = 10**rng.uniform( + np.log10(self.p_min), + np.log10(self.p_max), + size=n_low + ) + + # For high mass stars (m1 > 15), use rejection sampling + high_mass_mask = ~low_mass_mask + n_high = np.sum(high_mass_mask) + if n_high > 0: + # Create PDF function for rejection sampler + def pdf_high_mass(logp): + return self.sana12_period(logp, self.mbreak + 1) + + periods[high_mass_mask] = 10**rejection_sampler( + size=n_high, + x_lim=[np.log10(self.p_min), np.log10(self.p_max)], + pdf=pdf_high_mass, + rng=rng + ) + + return periods + class PowerLawPeriod(): '''Power law period distribution with slope pi and boundaries m_min, m_max. @@ -438,6 +669,37 @@ def pdf(self, p): return pdf_values + def rvs(self, size=1, n_points=1000, rng=None): + """Draw random samples from the power law period distribution. + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + rng : numpy.random.Generator, optional + Random number generator. If None, uses np.random.default_rng(). + + Returns + ------- + ndarray + Random period samples in days. + """ + if rng is None: + rng = np.random.default_rng() + + # Import here to avoid circular dependency + from posydon.utils.common_functions import inverse_sampler + + # Create discretized PDF for inverse sampling + logp_grid = np.linspace(np.log10(self.p_min), np.log10(self.p_max), n_points) + pdf_values = self.power_law_period(logp_grid) + + # Sample in log space + logp_samples = inverse_sampler(logp_grid, pdf_values, size=size, rng=rng) + + # Convert back to linear space + return 10**logp_samples + class LogUniform(): """Log-uniform distribution between specified minimum and maximum values. @@ -468,6 +730,28 @@ def __init__(self, min=5.0, max=1e5): self.norm = self._calculate_normalization() + def __repr__(self): + """Return string representation of the distribution. + + Returns + ------- + str + String representation showing the distribution parameters. + """ + return f"LogUniform(min={self.min}, max={self.max})" + + def _repr_html_(self): + """Return HTML representation for Jupyter notebooks. + + Returns + ------- + str + HTML string for rich display in notebooks. + """ + return (f"

Log-Uniform Distribution

" + f"

min = {self.min}

" + f"

max = {self.max}

") + def _calculate_normalization(self): """ Calculate the normalization constant for the log-uniform distribution. @@ -497,6 +781,465 @@ def pdf(self, x): valid = (x > 0) & (x >= self.min) & (x <= self.max) pdf_values = np.zeros_like(x, dtype=float) + pdf_values[valid] = self.norm / x[valid] # PDF is constant in log space, so divide by x for linear space + + return pdf_values + + def rvs(self, size=1, rng=None): + """Draw random samples from the log-uniform distribution. + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + rng : numpy.random.Generator, optional + Random number generator. If None, uses np.random.default_rng(). + + Returns + ------- + ndarray + Random samples from the distribution. + """ + if rng is None: + rng = np.random.default_rng() + + # Sample uniformly in log space + log_samples = rng.uniform(np.log10(self.min), np.log10(self.max), size=size) + + # Convert back to linear space + return 10**log_samples + + +class ThermalEccentricity: + """Thermal eccentricity distribution for binary star systems. + + The thermal distribution follows pdf(e) = 2*e, which is the + distribution expected for binaries that have undergone significant + dynamical interactions or thermal relaxation. + + Parameters + ---------- + e_min : float, optional + Minimum eccentricity (default: 0.0). Must be in [0, 1). + e_max : float, optional + Maximum eccentricity (default: 1.0). Must be in (0, 1]. + + Raises + ------ + ValueError + If e_min or e_max are not in [0, 1], or if e_min >= e_max. + + Examples + -------- + >>> dist = ThermalEccentricity() + >>> e_samples = dist.rvs(size=1000) + """ + + def __init__(self, e_min=0.0, e_max=1.0): + """Initialize the thermal eccentricity distribution. + + Parameters + ---------- + e_min : float, optional + Minimum eccentricity (default: 0.0). Must be in [0, 1). + e_max : float, optional + Maximum eccentricity (default: 1.0). Must be in (0, 1]. + + Raises + ------ + ValueError + If e_min or e_max are not in [0, 1], or if e_min >= e_max. + """ + if not (0 <= e_min < 1): + raise ValueError("e_min must be in [0, 1)") + if not (0 < e_max <= 1): + raise ValueError("e_max must be in (0, 1]") + if e_min >= e_max: + raise ValueError("e_min must be less than e_max") + + self.e_min = e_min + self.e_max = e_max + self.norm = self._calculate_normalization() + + def __repr__(self): + """Return string representation of the distribution.""" + return f"ThermalEccentricity(e_min={self.e_min}, e_max={self.e_max})" + + def _repr_html_(self): + """Return HTML representation for Jupyter notebooks.""" + return (f"

Thermal Eccentricity Distribution

" + f"

e_min = {self.e_min}

" + f"

e_max = {self.e_max}

") + + def _calculate_normalization(self): + """Calculate the normalization constant for the thermal eccentricity + distribution. + + Returns + ------- + float + The normalization constant ensuring the PDF integrates to 1. + """ + # Integral of 2*e from e_min to e_max is e_max^2 - e_min^2 + integral = self.e_max**2 - self.e_min**2 + if integral == 0: # pragma: no cover + raise ValueError("Cannot normalize distribution: e_min == e_max") + return 1.0 / integral + + def thermal_eccentricity(self, e): + """Compute the thermal eccentricity distribution value. + + Parameters + ---------- + e : float or array_like + Eccentricity value(s). + + Returns + ------- + float or ndarray + Distribution value (2*e). + """ + return 2.0 * np.asarray(e) + + def pdf(self, e): + """Probability density function of the thermal eccentricity distribution. + + Parameters + ---------- + e : float or array_like + Eccentricity value(s). + + Returns + ------- + float or ndarray + Probability density at eccentricity e. + """ + e = np.asarray(e) + valid = (e >= self.e_min) & (e <= self.e_max) + pdf_values = np.zeros_like(e, dtype=float) + pdf_values[valid] = self.thermal_eccentricity(e[valid]) * self.norm + return pdf_values + + def rvs(self, size=1, rng=None): + """Draw random samples from the thermal eccentricity distribution. + + Uses the analytical inverse CDF: e = sqrt(u * (e_max^2 - e_min^2) + e_min^2) + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + rng : numpy.random.Generator, optional + Random number generator. If None, uses np.random.default_rng(). + + Returns + ------- + ndarray + Random eccentricity samples in [e_min, e_max]. + """ + if rng is None: + rng = np.random.default_rng() + + # Inverse CDF: e = sqrt(u * (e_max^2 - e_min^2) + e_min^2) + u = rng.uniform(size=size) + return np.sqrt(u * (self.e_max**2 - self.e_min**2) + self.e_min**2) + + +class UniformEccentricity: + """Uniform eccentricity distribution for binary star systems. + + A flat distribution over eccentricities between e_min and e_max. + + Parameters + ---------- + e_min : float, optional + Minimum eccentricity (default: 0.0). Must be in [0, 1). + e_max : float, optional + Maximum eccentricity (default: 1.0). Must be in (0, 1]. + + Raises + ------ + ValueError + If e_min or e_max are not in [0, 1], or if e_min >= e_max. + + Examples + -------- + >>> dist = UniformEccentricity() + >>> e_samples = dist.rvs(size=1000) + """ + + def __init__(self, e_min=0.0, e_max=1.0): + """Initialize the uniform eccentricity distribution. + + Parameters + ---------- + e_min : float, optional + Minimum eccentricity (default: 0.0). Must be in [0, 1). + e_max : float, optional + Maximum eccentricity (default: 1.0). Must be in (0, 1]. + + Raises + ------ + ValueError + If e_min or e_max are not in [0, 1], or if e_min >= e_max. + """ + if not (0 <= e_min < 1): + raise ValueError("e_min must be in [0, 1)") + if not (0 < e_max <= 1): + raise ValueError("e_max must be in (0, 1]") + if e_min >= e_max: + raise ValueError("e_min must be less than e_max") + + self.e_min = e_min + self.e_max = e_max + self.norm = 1.0 / (e_max - e_min) + + def __repr__(self): + """Return string representation of the distribution.""" + return f"UniformEccentricity(e_min={self.e_min}, e_max={self.e_max})" + + def _repr_html_(self): + """Return HTML representation for Jupyter notebooks.""" + return (f"

Uniform Eccentricity Distribution

" + f"

e_min = {self.e_min}

" + f"

e_max = {self.e_max}

") + + def pdf(self, e): + """Probability density function of the uniform eccentricity distribution. + + Parameters + ---------- + e : float or array_like + Eccentricity value(s). + + Returns + ------- + float or ndarray + Probability density at eccentricity e. + """ + e = np.asarray(e) + valid = (e >= self.e_min) & (e <= self.e_max) + pdf_values = np.zeros_like(e, dtype=float) pdf_values[valid] = self.norm + return pdf_values + + def rvs(self, size=1, rng=None): + """Draw random samples from the uniform eccentricity distribution. + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + rng : numpy.random.Generator, optional + Random number generator. If None, uses np.random.default_rng(). + + Returns + ------- + ndarray + Random eccentricity samples in [e_min, e_max]. + """ + if rng is None: + rng = np.random.default_rng() + + return rng.uniform(self.e_min, self.e_max, size=size) + + +class ZeroEccentricity: + """Zero eccentricity distribution for circular binary orbits. + + All samples are exactly zero (circular orbits). The PDF is a Dirac delta + at e=0. + + Examples + -------- + >>> dist = ZeroEccentricity() + >>> e_samples = dist.rvs(size=1000) # All zeros + """ + + def __init__(self): + """Initialize the zero eccentricity distribution.""" + pass + + def __repr__(self): + """Return string representation of the distribution.""" + return "ZeroEccentricity()" + + def _repr_html_(self): + """Return HTML representation for Jupyter notebooks.""" + return "

Zero Eccentricity Distribution

e = 0 (circular orbits)

" + + def pdf(self, e): + """Probability density function of the zero eccentricity distribution. + + This is formally a Dirac delta at e=0. Returns 1.0 at e=0, 0.0 elsewhere. + + Parameters + ---------- + e : float or array_like + Eccentricity value(s). + + Returns + ------- + float or ndarray + 1.0 where e==0, 0.0 elsewhere. + """ + e = np.asarray(e) + return np.where(e == 0, 1.0, 0.0) + + def rvs(self, size=1, rng=None): + """Draw random samples from the zero eccentricity distribution. + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + rng : numpy.random.Generator, optional + Random number generator (unused, included for API consistency). + + Returns + ------- + ndarray + Array of zeros with shape (size,). + """ + return np.zeros(size) + + +class LogNormalSeparation: + """Log-normal orbital separation distribution for binary star systems. + + Orbital separations are drawn from a log-normal distribution in log10 space, + truncated between specified minimum and maximum values. + + Parameters + ---------- + mean : float, optional + Mean of the log10 distribution (default: 0.85, corresponding to ~7.08 Rsun). + sigma : float, optional + Standard deviation of the log10 distribution (default: 0.37). + min : float, optional + Minimum orbital separation in solar radii (default: 5.0). + max : float, optional + Maximum orbital separation in solar radii (default: 1e5). + + Raises + ------ + ValueError + If min is not positive, max <= min, or sigma <= 0. + + Examples + -------- + >>> dist = LogNormalSeparation(mean=3.0, sigma=1.5, min=10, max=1e6) + >>> separations = dist.rvs(size=1000) + + Notes + ----- + Uses scipy.stats.truncnorm for efficient sampling from the truncated + normal distribution in log10 space. + """ + + def __init__(self, mean=0.85, sigma=0.37, min=5.0, max=1e5): + """Initialize the log-normal separation distribution. + + Parameters + ---------- + mean : float, optional + Mean of the log10 distribution (default: 0.85). + sigma : float, optional + Standard deviation of the log10 distribution (default: 0.37). + min : float, optional + Minimum orbital separation in solar radii (default: 5.0). + max : float, optional + Maximum orbital separation in solar radii (default: 1e5). + + Raises + ------ + ValueError + If min is not positive, max <= min, or sigma <= 0. + """ + if min <= 0: + raise ValueError("min must be positive") + if max <= min: + raise ValueError("max must be greater than min") + if sigma <= 0: + raise ValueError("sigma must be positive") + + self.mean = mean + self.sigma = sigma + self.min = min + self.max = max + + # Compute truncation bounds for scipy.stats.truncnorm + self.a_low = (np.log10(min) - mean) / sigma + self.a_high = (np.log10(max) - mean) / sigma + + def __repr__(self): + """Return string representation of the distribution.""" + return (f"LogNormalSeparation(mean={self.mean}, sigma={self.sigma}, " + f"min={self.min}, max={self.max})") + + def _repr_html_(self): + """Return HTML representation for Jupyter notebooks.""" + return (f"

Log-Normal Separation Distribution

" + f"

mean (log10) = {self.mean}

" + f"

sigma (log10) = {self.sigma}

" + f"

min = {self.min} R☉

" + f"

max = {self.max} R☉

") + + def pdf(self, a): + """Probability density function of the log-normal separation distribution. + + Parameters + ---------- + a : float or array_like + Orbital separation(s) in solar radii. + + Returns + ------- + float or ndarray + Probability density at separation a. + """ + a = np.asarray(a) + valid = (a > 0) & (a >= self.min) & (a <= self.max) + pdf_values = np.zeros_like(a, dtype=float) + + log_a = np.log10(a[valid]) + # PDF of truncnorm in log space, transformed to linear space + # pdf(a) = pdf_log(log a) * |d(log a)/da| = pdf_log(log a) / (a * ln(10)) + pdf_values[valid] = truncnorm.pdf( + log_a, + self.a_low, self.a_high, + loc=self.mean, + scale=self.sigma + ) / (a[valid] * np.log(10)) return pdf_values + + def rvs(self, size=1, rng=None): + """Draw random samples from the log-normal separation distribution. + + Parameters + ---------- + size : int, optional + Number of samples to draw (default: 1). + rng : numpy.random.Generator, optional + Random number generator. If None, uses np.random.default_rng(). + + Returns + ------- + ndarray + Random orbital separation samples in solar radii. + """ + if rng is None: + rng = np.random.default_rng() + + # Sample from truncated normal in log10 space + log_separations = truncnorm.rvs( + self.a_low, self.a_high, + loc=self.mean, + scale=self.sigma, + size=size, + random_state=rng + ) + + # Convert back to linear space + return 10**log_separations diff --git a/posydon/popsyn/independent_sample.py b/posydon/popsyn/independent_sample.py index 7bb0ee0f11..16dd4f1cfa 100644 --- a/posydon/popsyn/independent_sample.py +++ b/posydon/popsyn/independent_sample.py @@ -16,6 +16,7 @@ import numpy as np from scipy.stats import truncnorm +from posydon.popsyn import IMFs, distributions from posydon.popsyn.Moes_distributions import Moe_17_PsandQs from posydon.utils.common_functions import rejection_sampler @@ -49,7 +50,8 @@ def generate_independent_samples(orbital_scheme='period', **kwargs): # Generate primary masses m1_set = generate_primary_masses(**kwargs) - if use_Moe_17_PsandQs(orbital_scheme=orbital_scheme, **kwargs): + if use_Moe_17_PsandQs(orbital_scheme=orbital_scheme, **kwargs): # pragma: no cover + # this requires an integration test with actual external datafiles. No unit test # initialize generator for Moe+17-PsandQs if _gen_Moe_17_PsandQs is None: _gen_Moe_17_PsandQs = Moe_17_PsandQs(**kwargs) @@ -123,53 +125,18 @@ def generate_orbital_periods(primary_masses, orbital_period_max=10**3.5, orbital_period_scheme='Sana+12_period_extended', **kwargs): - """Randomaly generate orbital periods for a sample of binaries.""" + """Randomly generate orbital periods for a sample of binaries.""" RNG = kwargs.get('RNG', np.random.default_rng()) # Check inputs # Sana H., et al., 2012, Science, 337, 444 if orbital_period_scheme == 'Sana+12_period_extended': - # compute periods as if all M1 <= 15Msun (where pi = 0.0) - orbital_periods_M_lt_15 = 10**RNG.uniform( - low=np.log10(orbital_period_min), - high=np.log10(orbital_period_max), - size=number_of_binaries) - - # compute periods as if all M1 > 15Msun - def pdf(logp): - pi = 0.55 - beta = 1 - pi - A = np.log10(10**0.15)**(-pi)*(np.log10(10**0.15) - - np.log10(orbital_period_min)) - B = 1./beta*(np.log10(orbital_period_max)**beta - - np.log10(10**0.15)**beta) - C = 1./(A + B) - pdf = np.zeros(len(logp)) - - for j, logp_j in enumerate(logp): - # for logP<=0.15 days, the pdf is uniform - if np.log10(orbital_period_min) <= logp_j and logp_j < 0.15: - pdf[j] = C*0.15**(-pi) - - # original Sana H., et al., 2012, Science, 337, 444 - elif 0.15 <= logp_j and logp_j < np.log10(orbital_period_max): - pdf[j] = C*logp_j**(-pi) - - else: - pdf[j] = 0. - - return pdf - - orbital_periods_M_gt_15 = 10**(rejection_sampler( - size=number_of_binaries, - x_lim=[np.log10(orbital_period_min), np.log10(orbital_period_max)], - pdf=pdf, - rng=RNG)) - - orbital_periods = np.where(primary_masses <= 15.0, - orbital_periods_M_lt_15, - orbital_periods_M_gt_15) + period_dist = distributions.Sana12Period( + p_min=orbital_period_min, + p_max=orbital_period_max + ) + orbital_periods = period_dist.rvs(size=number_of_binaries, m1=primary_masses, rng=RNG) else: raise ValueError("You must provide an allowed orbital period scheme.") @@ -217,10 +184,11 @@ def generate_orbital_separations(number_of_binaries=1, "orbital separation scheme.") if orbital_separation_scheme == 'log_uniform': - orbital_separations = 10**RNG.uniform( - low=np.log10(orbital_separation_min), - high=np.log10(orbital_separation_max), - size=number_of_binaries) + sep_dist = distributions.LogUniform( + min=orbital_separation_min, + max=orbital_separation_max + ) + orbital_separations = sep_dist.rvs(size=number_of_binaries, rng=RNG) if orbital_separation_max < orbital_separation_min: raise ValueError("`orbital_separation_max` must be " @@ -234,22 +202,15 @@ def generate_orbital_separations(number_of_binaries=1, "`log_orbital_separation_mean`, " "`log_orbital_separation_sigma`.") - # Set limits for truncated normal distribution - a_low = (np.log10(orbital_separation_min) - - log_orbital_separation_mean) / log_orbital_separation_sigma - a_high = (np.log10(orbital_separation_max) - - log_orbital_separation_mean) / log_orbital_separation_sigma - - # generate orbital separations from a truncted normal distribution - log_orbital_separations = truncnorm.rvs( - a_low, a_high, - loc=log_orbital_separation_mean, - scale=log_orbital_separation_sigma, - size=number_of_binaries, - random_state=RNG) - orbital_separations = 10**log_orbital_separations + sep_dist = distributions.LogNormalSeparation( + mean=log_orbital_separation_mean, + sigma=log_orbital_separation_sigma, + min=orbital_separation_min, + max=orbital_separation_max + ) + orbital_separations = sep_dist.rvs(size=number_of_binaries, rng=RNG) - else: + else: # pragma: no cover pass return orbital_separations @@ -285,12 +246,15 @@ def generate_eccentricities(number_of_binaries=1, raise ValueError("You must provide an allowed eccentricity scheme.") if eccentricity_scheme == 'thermal': - eccentricities = np.sqrt(RNG.uniform(size=number_of_binaries)) + ecc_dist = distributions.ThermalEccentricity() + eccentricities = ecc_dist.rvs(size=number_of_binaries, rng=RNG) elif eccentricity_scheme == 'uniform': - eccentricities = RNG.uniform(size=number_of_binaries) + ecc_dist = distributions.UniformEccentricity() + eccentricities = ecc_dist.rvs(size=number_of_binaries, rng=RNG) elif eccentricity_scheme == 'zero': - eccentricities = np.zeros(number_of_binaries) - else: + ecc_dist = distributions.ZeroEccentricity() + eccentricities = ecc_dist.rvs(size=number_of_binaries, rng=RNG) + else: # pragma: no cover # This should never be reached pass @@ -332,31 +296,19 @@ def generate_primary_masses(number_of_binaries=1, # Salpeter E. E., 1955, ApJ, 121, 161 if primary_mass_scheme == 'Salpeter': - alpha = 2.35 - normalization_constant = (1.0-alpha) / (primary_mass_max**(1-alpha) - - primary_mass_min**(1-alpha)) - random_variable = RNG.uniform(size=number_of_binaries) - primary_masses = (random_variable*(1.0-alpha)/normalization_constant - + primary_mass_min**(1.0-alpha))**(1.0/(1.0-alpha)) + imf = IMFs.Salpeter(alpha=2.35, m_min=primary_mass_min, m_max=primary_mass_max) + primary_masses = imf.rvs(size=number_of_binaries, rng=RNG) # Kroupa P., Tout C. A., Gilmore G., 1993, MNRAS, 262, 545 elif primary_mass_scheme == 'Kroupa1993': - alpha = 2.7 - normalization_constant = (1.0-alpha) / (primary_mass_max**(1-alpha) - - primary_mass_min**(1-alpha)) - random_variable = RNG.uniform(size=number_of_binaries) - primary_masses = (random_variable*(1.0-alpha)/normalization_constant - + primary_mass_min**(1.0-alpha))**(1.0/(1.0-alpha)) + imf = IMFs.Kroupa1993(alpha=2.7, m_min=primary_mass_min, m_max=primary_mass_max) + primary_masses = imf.rvs(size=number_of_binaries, rng=RNG) # Kroupa P., 2001, MNRAS, 322, 231 elif primary_mass_scheme == 'Kroupa2001': - alpha = 2.3 - normalization_constant = (1.0-alpha) / (primary_mass_max**(1-alpha) - - primary_mass_min**(1-alpha)) - random_variable = RNG.uniform(size=number_of_binaries) - primary_masses = (random_variable*(1.0-alpha)/normalization_constant - + primary_mass_min**(1.0-alpha))**(1.0/(1.0-alpha)) - else: + imf = IMFs.Kroupa2001(m_min=primary_mass_min, m_max=primary_mass_max) + primary_masses = imf.rvs(size=number_of_binaries, rng=RNG) + else: # pragma: no cover pass return primary_masses @@ -405,13 +357,18 @@ def generate_secondary_masses(primary_masses, # Generate secondary masses if secondary_mass_scheme == 'flat_mass_ratio': - mass_ratio_min = np.max([secondary_mass_min / primary_masses, - np.ones(len(primary_masses))*0.05], axis=0) - mass_ratio_max = np.min([secondary_mass_max / primary_masses, - np.ones(len(primary_masses))], axis=0) - secondary_masses = ( - (mass_ratio_max - mass_ratio_min) * RNG.uniform( - size=number_of_binaries) + mass_ratio_min) * primary_masses + # Calculate mass ratio bounds for each primary mass + q_min = np.maximum(secondary_mass_min / primary_masses, 0.0) + q_max = np.minimum(secondary_mass_max / primary_masses, 1.0) + + # Sample mass ratios using the distribution class + # For mass-dependent bounds, we need to sample individually + mass_ratios = np.zeros(number_of_binaries) + for i in range(number_of_binaries): + q_dist = distributions.FlatMassRatio(q_min=q_min[i], q_max=q_max[i]) + mass_ratios[i] = q_dist.rvs(size=1, rng=RNG)[0] + + secondary_masses = primary_masses * mass_ratios if secondary_mass_scheme == 'q=1': secondary_masses = primary_masses @@ -445,6 +402,7 @@ def generate_binary_fraction(m1=None, binary_fraction_const=1, elif not isinstance(m1,np.ndarray): m1 = np.asarray(m1) binary_fraction = np.zeros_like(m1, dtype=float) + # Input parameter checks if binary_fraction_scheme not in binary_fraction_scheme_options: raise ValueError("You must provide an allowed binary fraction scheme.") @@ -459,7 +417,7 @@ def generate_binary_fraction(m1=None, binary_fraction_const=1, binary_fraction[(m1 <= 5) & (m1 > 2)] = 0.59 binary_fraction[(m1 <= 2)] = 0.4 - else: + else: # pragma: no cover pass return binary_fraction diff --git a/posydon/popsyn/io.py b/posydon/popsyn/io.py index 9468bccb2b..810b6ccb3e 100644 --- a/posydon/popsyn/io.py +++ b/posydon/popsyn/io.py @@ -60,7 +60,7 @@ STARPROPERTIES_DTYPES = { 'state': 'string', # the evolutionary state of the star. For more info see # `posydon.utils.common_functions.check_state_of_star` - 'metallicity': 'float64', # initial mass fraction of metals + 'metallicity': 'float64', # Z/Z_sun, ratio to solar metallicity (1.0 for solar) 'mass': 'float64', # mass (solar units) 'log_R': 'float64', # log10 of radius (solar units) 'log_L': 'float64', # log10 luminosity (solar units) @@ -192,11 +192,11 @@ def clean_binary_history_df(binary_df, extra_binary_dtypes_user=None, assert isinstance( binary_df, pd.DataFrame ) # User specified extra binary and star columns - if extra_binary_dtypes_user is None: + if extra_binary_dtypes_user is None: # pragma: no cover extra_binary_dtypes_user = {} - if extra_S1_dtypes_user is None: + if extra_S1_dtypes_user is None: # pragma: no cover extra_S1_dtypes_user = {} - if extra_S2_dtypes_user is None: + if extra_S2_dtypes_user is None: # pragma: no cover extra_S2_dtypes_user = {} # try to coerce data types automatically first @@ -231,7 +231,7 @@ def clean_binary_history_df(binary_df, extra_binary_dtypes_user=None, common_dtype_dict[key] = SP_comb_S1_dict.get( key.replace('S1_', '') ) elif key in S2_keys: common_dtype_dict[key] = SP_comb_S2_dict.get( key.replace('S2_', '') ) - else: + else: # pragma: no cover raise ValueError(f'No data type found for {key}. Dtypes must be explicity declared.') # set dtypes binary_df = binary_df.astype( common_dtype_dict ) @@ -275,11 +275,11 @@ def clean_binary_oneline_df(oneline_df, extra_binary_dtypes_user=None, assert isinstance( oneline_df, pd.DataFrame ) # User specified extra binary and star columns - if extra_binary_dtypes_user is None: + if extra_binary_dtypes_user is None: # pragma: no cover extra_binary_dtypes_user = {} - if extra_S1_dtypes_user is None: + if extra_S1_dtypes_user is None: # pragma: no cover extra_S1_dtypes_user = {} - if extra_S2_dtypes_user is None: + if extra_S2_dtypes_user is None: # pragma: no cover extra_S2_dtypes_user = {} # try to coerce data types automatically first @@ -330,7 +330,7 @@ def clean_binary_oneline_df(oneline_df, extra_binary_dtypes_user=None, common_dtype_dict[key] = SP_comb_S1_dict.get( strip_prefix_and_suffix(key) ) elif key in S2_keys: common_dtype_dict[key] = SP_comb_S2_dict.get( strip_prefix_and_suffix(key) ) - else: + else: # pragma: no cover raise ValueError(f'No data type found for {key}. Dtypes must be explicity declared.') # set dtypes oneline_df = oneline_df.astype( common_dtype_dict ) @@ -369,7 +369,7 @@ def parse_inifile(path, verbose=False): if isinstance(path, str): path = os.path.abspath(path) - if verbose: + if verbose: # pragma: no cover print('Reading inifile: \n\t{}'.format(path)) if not os.path.exists(path): raise FileNotFoundError( @@ -377,7 +377,7 @@ def parse_inifile(path, verbose=False): elif isinstance(path, (list, np.ndarray)): path = [os.path.abspath(f) for f in path] - if verbose: + if verbose: # pragma: no cover print('Reading inifiles: \n{}'.format(pprint.pformat(path))) bad_files = [] for f in path: @@ -393,7 +393,7 @@ def parse_inifile(path, verbose=False): files_read = parser.read(path) # Catch silent errors from configparser.read - if len(files_read) == 0: + if len(files_read) == 0: # pragma: no cover raise ValueError("No files were read successfully. Given {}.". format(path)) return parser @@ -425,7 +425,7 @@ def simprop_kwargs_from_ini(path, only=None, verbose=False): parser_dict = {} for section in parser: # skip default section - if section == 'DEFAULT': + if section == 'DEFAULT': # pragma: no cover continue if only is not None: if section != only: @@ -494,6 +494,10 @@ def simprop_kwargs_from_ini(path, only=None, verbose=False): parser_dict[section] = hooks_list + if section == "grid_paths": + + parser_dict.update(sect_dict) + return parser_dict @@ -534,7 +538,7 @@ def binarypop_kwargs_from_ini(path, verbose=False): if pop_kwargs['use_MPI'] == True and JOB_ID is not None: raise ValueError('MPI must be turned off for job arrays.') exit() - elif pop_kwargs['use_MPI'] == True: + elif pop_kwargs['use_MPI'] == True: # pragma: no cover from mpi4py import MPI pop_kwargs['comm'] = MPI.COMM_WORLD # MPI needs to be turned off for job arrays @@ -542,7 +546,7 @@ def binarypop_kwargs_from_ini(path, verbose=False): pop_kwargs['comm'] = None # Check if we are running as a job array - if JOB_ID is not None and pop_kwargs['use_MPI'] is True: + if JOB_ID is not None and pop_kwargs['use_MPI'] is True: # pragma: no cover raise ValueError('MPI must be turned off for job arrays.') elif JOB_ID is not None: pop_kwargs['JOB_ID'] = np.int64(os.environ['SLURM_ARRAY_JOB_ID']) @@ -570,7 +574,7 @@ def binarypop_kwargs_from_ini(path, verbose=False): if pop_kwargs['include_S1']: pop_kwargs['S1_kwargs'] = S1_kwargs - elif section == 'SingleStar_2_output': + elif section == 'SingleStar_2_output': # pragma: no branch S2_kwargs = dict() for key, val in parser[section].items(): S2_kwargs[key] = ast.literal_eval(val) diff --git a/posydon/popsyn/norm_pop.py b/posydon/popsyn/norm_pop.py index 4f2946504e..915a95ae67 100644 --- a/posydon/popsyn/norm_pop.py +++ b/posydon/popsyn/norm_pop.py @@ -12,6 +12,7 @@ from posydon.popsyn.distributions import ( FlatMassRatio, LogUniform, + PowerLawMassRatio, PowerLawPeriod, Sana12Period, ) @@ -75,6 +76,11 @@ def get_mass_ratio_pdf(kwargs): Requires the following parameters: - `secondary_mass_min` - `secondary_mass_max` + - `power_law_mass_ratio` for `secondary_mass_scheme` + Requires the following parameters: + - `mass_ratio_slope`: exponent alpha in q^alpha + - `q_min` (optional, default 0.05) + - `q_max` (optional, default 1.0) Parameters ---------- @@ -93,30 +99,41 @@ def get_mass_ratio_pdf(kwargs): def get_pdf_for_m1(m1): m1 = np.atleast_1d(m1) minimum = np.max( - [kwargs['secondary_mass_min'] / m1, np.ones(len(m1))*0.05], + [kwargs['secondary_mass_min'] / m1, np.zeros(len(m1))], axis=0) maximum = np.min( [kwargs['secondary_mass_max'] / m1, np.ones(len(m1))], axis=0) - q_dist = lambda q: np.where((q >= minimum) & (q <= maximum), + # Use FlatMassRatio distribution class + q_dist = lambda q: np.where((q > minimum) & (q <= maximum), 1/(maximum - minimum), 0) return q_dist q_pdf = lambda q, m1: get_pdf_for_m1(m1)(q) elif kwargs['secondary_mass_scheme'] == 'flat_mass_ratio': # flat mass ratio, where bounds are given - q_pdf = lambda q, m1=None: np.where( - (q > kwargs['q_min']) & (q <= kwargs['q_max']), - 1/(kwargs['q_max'] - kwargs['q_min']), - 0) + from posydon.popsyn.distributions import FlatMassRatio + q_dist = FlatMassRatio(q_min=kwargs['q_min'], q_max=kwargs['q_max']) + q_pdf = lambda q, m1=None: q_dist.pdf(q) + + elif kwargs['secondary_mass_scheme'] == 'power_law_mass_ratio': + from posydon.popsyn.distributions import PowerLawMassRatio + q_dist = PowerLawMassRatio( + alpha=kwargs['mass_ratio_slope'], + q_min=kwargs.get('q_min', 0.05), + q_max=kwargs.get('q_max', 1.0), + ) + q_pdf = lambda q, m1=None: q_dist.pdf(q) else: # default to a flat distribution Pwarn("The secondary_mass_scheme is not defined use a flat mass ratio " "distribution in (0,1].", "UnsupportedModelWarning") - q_pdf = lambda q, m1=None: np.where((q > 0.0) & (q<=1.0), 1, 0) + from posydon.popsyn.distributions import FlatMassRatio + q_dist = FlatMassRatio(q_min=0.0, q_max=1.0) + q_pdf = lambda q, m1=None: q_dist.pdf(q) return q_pdf def get_binary_fraction_pdf(kwargs): @@ -305,7 +322,68 @@ def calculate_model_weights(pop_data, M_sim, simulation_parameters, population_parameters): - '''reweight each model in the simulation to the requested population''' + """Reweight each model in the simulation to the requested population + + Uses the PDF of the simulation and the PDF of the requested population to calculate + the weights for each model in the simulation to match the requested population. + + Parameters + ---------- + pop_data : dict + Dictionary containing the population data. + This needs to contain the following keys: + - `S1_mass_i`: initial mass of the primary + - `S2_mass_i`: initial mass of the secondary + - `orbital_period_i`: initial orbital period + - `state_i`: initial state of the system (e.g. 'initially_single_star' for single stars) + These are used to calculate the PDF for each model in the simulation. + + M_sim : float + Mass of the simulation + simulation_parameters : dict + Dictionary containing the simulation parameters. + This is used to calculate the PDF of the simulation. + The parameters in this dictionary are the initial conditions of the population. + The following parameters are required to be present in the dictionary: + - `primary_mass_scheme` + - `primary_mass_min` + - `primary_mass_max` + - `secondary_mass_scheme` + - `secondary_mass_min` + - `secondary_mass_max` + - `binary_fraction_scheme` + - `binary_fraction_const` + - `orbital_scheme` + - `orbital_period_scheme` or `orbital_separation_scheme` depending on the `orbital_scheme` + - `orbital_period_min` and `orbital_period_max` or `orbital_separation_min` and `orbital_separation_max` depending on the `orbital_scheme` + - `power_law_slope` if `orbital_period_scheme` is `power_law` + - `q_min` and `q_max` if `secondary_mass_scheme` is `flat_mass_ratio` + + population_parameters : dict + Dictionary containing the population parameters, which is the requested population to which we want to reweight the simulation. This is used to calculate the PDF of the requested population. + The parameters in this dictionary are the initial conditions of the population you want to reweight to. + The following parameters are required to be present in the dictionary: + - `primary_mass_scheme` + - `primary_mass_min` + - `primary_mass_max` + - `secondary_mass_scheme` + - `secondary_mass_min` + - `secondary_mass_max` + - `binary_fraction_scheme` + - `binary_fraction_const` + - `orbital_scheme` + - `orbital_period_scheme` or `orbital_separation_scheme` depending on the `orbital_scheme` + - `orbital_period_min` and `orbital_period_max` or `orbital_separation_min` and `orbital_separation_max` depending on the `orbital_scheme` + - `power_law_slope` if `orbital_period_scheme` is `power_law` + - `q_min` and `q_max` if `secondary_mass_scheme` is `flat_mass_ratio` + + Returns + ------- + output : ndarray of floats + Weights for each model in the simulation to match the requested population + This has the units of likelihood of the systems per unit mass (Msun^-1). + + """ f_b_sim = simulation_parameters['binary_fraction_const'] f_b_pop = population_parameters['binary_fraction_const'] diff --git a/posydon/popsyn/population_params_default.ini b/posydon/popsyn/population_params_default.ini index 2e06a9b3c4..53cec0cf75 100644 --- a/posydon/popsyn/population_params_default.ini +++ b/posydon/popsyn/population_params_default.ini @@ -170,6 +170,18 @@ absolute_import = None # if given, use an absolute filepath to user defined step: # ['', ''] + do_wind_loss = False + # True, False + do_tides = False + # True, False + do_gravitational_radiation = False + # True, False + do_magnetic_braking = False + # True, False + magnetic_braking_mode = 'RVJ83' + # 'RVJ83', 'M15', 'G18', 'CARB' + do_stellar_evolution_and_spin_from_winds = True + # True, False matching_method = 'minimize' # 'minimize', 'root' matching_tolerance = 1e-2 @@ -187,6 +199,38 @@ absolute_import = None # if given, use an absolute filepath to user defined step: # ['', ''] + do_wind_loss = False + # True, False + do_tides = False + # True, False + do_gravitational_radiation = False + # True, False + do_magnetic_braking = False + # True, False + magnetic_braking_mode = 'RVJ83' + # 'RVJ83', 'M15', 'G18', 'CARB' + do_stellar_evolution_and_spin_from_winds = True + # True, False + list_for_matching_HMS = [["mass", "center_h1", "he_core_mass"], + [20.0, 1.0, 10.0], + ["log_min_max", "min_max", "min_max"], + [None, None], [0, None]] + # A list of mixed type that specifies properties of the matching + # process for HMS stars. This list has the following structure: + # list_for_matching = [[matching attr. names], [rescale_factors], + # [scaling method], [mass_bnds], [age_bnds]] + list_for_matching_postMS = [["mass", "center_he4", "he_core_mass"], + [20.0, 1.0, 10.0], + ["log_min_max", "min_max", "min_max"], + [None, None], [0, None]] + # As above, a list that specifies properties of the matching + # process for post-MS stars. + list_for_matching_HeStar = [["he_core_mass", "center_he4"], + [10.0, 1.0], + ["min_max" , "min_max"], + [None, None], [0, None]] + # As above, a list that specifies properties of the matching + # process for HeMS stars. record_matching = False # True, False verbose = False @@ -199,6 +243,18 @@ absolute_import = None # if given, use an absolute filepath to user defined step: # ['', ''] + do_wind_loss = False + # True, False + do_tides = False + # True, False + do_gravitational_radiation = False + # True, False + do_magnetic_braking = False + # True, False + magnetic_braking_mode = 'RVJ83' + # 'RVJ83', 'M15', 'G18', 'CARB' + do_stellar_evolution_and_spin_from_winds = True + # True, False matching_method = 'minimize' # 'minimize', 'root' matching_tolerance = 1e-2 @@ -216,6 +272,22 @@ absolute_import = None # if given, use an absolute filepath to user defined step: # ['', ''] + matching_method = 'minimize' + # 'minimize', 'root' + matching_tolerance = 1e-2 + # float, DEF: 1e-2 + matching_tolerance_hard = 1e-1 + # float, DEF: 1e-1 + list_for_matching_HMS = [["mass", "center_h1", "he_core_mass"], + [20.0, 1.0, 10.0], + ["log_min_max", "min_max", "min_max"], + [0.1, 300], [0.0, None]] + # A list of mixed type that specifies properties of the matching + # process for HMS stars. This list has the following structure: + # list_for_matching = [[matching attr. names], [rescale_factors], + # [scaling method], [mass_bnds], [age_bnds]] + record_matching = False + # True, False prescription = 'alpha-lambda' # 'alpha-lambda' common_envelope_efficiency = 1.0 @@ -246,8 +318,6 @@ common_envelope_option_after_succ_CEE = 'two_phases_stableMT' # 'two_phases_stableMT' 'one_phase_variable_core_definition' # 'two_phases_windloss' - record_matching = False - # True, False verbose = False # True, False @@ -325,6 +395,18 @@ absolute_import = None # if given, use an absolute filepath to user defined step: # ['', ''] + do_wind_loss = False + # True, False + do_tides = False + # True, False + do_gravitational_radiation = True + # True, False + do_magnetic_braking = False + # True, False + magnetic_braking_mode = 'RVJ83' + # 'RVJ83', 'M15', 'G18', 'CARB' + do_stellar_evolution_and_spin_from_winds = False + # True, False n_o_steps_history = None # None or int (0, inf) @@ -352,6 +434,31 @@ kwargs_2 = {} # dict +[grid_paths] + # You shouldn't need to edit these paths unless you want to specify paths to + # your own custom MESA grids. Leaving these as None will tell POSYDON to use + # the default paths inside of your $PATH_TO_POSYDON_DATA environment variable. + + HMS_HMS_path = None + # A string representing the path to your grid HDF5 files for HMS-HMS evolution. + # If None, the default $PATH_TO_POSYDON_DATA/HMS-HMS/*_Zsun.h5 files will be used + CO_HMS_RLO_path = None + # A string representing the path to your grid HDF5 files for CO-HMS evolution. + # If None, the default $PATH_TO_POSYDON_DATA/CO-HMS_RLO/*_Zsun.h5 files will be used + CO_HeMS_path = None + # A string representing the path to your grid HDF5 files for CO-HeMS evolution. + # If None, the default $PATH_TO_POSYDON_DATA/CO-HeMS/*_Zsun.h5 files will be used + CO_HeMS_RLO_path = None + # A string representing the path to your grid HDF5 files for CO-HeMS evolution. + # If None, the default $PATH_TO_POSYDON_DATA/CO-HeMS_RLO/*_Zsun.h5 files will be used + single_HMS_path = None + # A string representing the path to your grid HDF5 files for single star HMS evolution. + # If None, the default $PATH_TO_POSYDON_DATA/single_HMS/*_Zsun.h5 files will be used. + # These are used by detached binaries and single stars in your population. + single_HeMS_path = None + # A string representing the path to your grid HDF5 files for single star HeMS evolution. + # If None, the default $PATH_TO_POSYDON_DATA/single_HeMS/*_Zsun.h5 files will be used. + ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;;;;;;;; BinaryPopulation ;;;;;;;;;; @@ -567,6 +674,8 @@ 'he4_mass_ej', #'M4', #'mu4', + #'Xi', + #'sc', 'avg_c_in_c_core_at_He_depletion', 'co_core_mass_at_He_depletion', #'m_core_CE_1cent', @@ -655,6 +764,8 @@ 'he4_mass_ej', #'M4', #'mu4', + #'Xi', + #'sc', 'avg_c_in_c_core_at_He_depletion', 'co_core_mass_at_He_depletion', #'m_core_CE_1cent', diff --git a/posydon/popsyn/rate_calculation.py b/posydon/popsyn/rate_calculation.py index 4b57935d0b..de98212cb7 100644 --- a/posydon/popsyn/rate_calculation.py +++ b/posydon/popsyn/rate_calculation.py @@ -1,4 +1,4 @@ -__author__ = [ +__authors__ = [ "Simone Bavera ", "Max Briel ", ] @@ -204,7 +204,7 @@ def get_redshift_bin_centers(delta_t): # compute the redshift z_birth = [] for i in range(n_redshift_bin_centers + 1): - # z_at_value is from astopy.cosmology + # z_at_value is from astropy.cosmology z_birth.append(z_at_value(cosmology.age, t_birth[i] * u.Gyr)) z_birth = np.array(z_birth) diff --git a/posydon/popsyn/synthetic_population.py b/posydon/popsyn/synthetic_population.py index 9c89c8971c..3fc21ecfa7 100644 --- a/posydon/popsyn/synthetic_population.py +++ b/posydon/popsyn/synthetic_population.py @@ -155,12 +155,12 @@ def evolve(self, overwrite=False): if os.path.exists(pop.kwargs["temp_directory"]) and not overwrite: raise FileExistsError(f"The {pop.kwargs['temp_directory']} directory already exists! Please remove it or rename it before running the population.") elif os.path.exists(pop.kwargs["temp_directory"]) and overwrite: - if self.verbose: + if self.verbose: # pragma: no cover print(f"Removing pre-existing {pop.kwargs['temp_directory']} directory...") shutil.rmtree(pop.kwargs["temp_directory"]) pop.evolve(optimize_ram=True) - if pop.comm is None: + if pop.comm is None: # pragma: no cover self.merge_parallel_runs(pop, overwrite) def merge_parallel_runs(self, pop, overwrite=False): @@ -179,7 +179,7 @@ def merge_parallel_runs(self, pop, overwrite=False): f"{Zstr}_Zsun_population.h5 already exists!\n" +"Files were not merged. You can use PopulationRunner.merge_parallel_runs() to merge the files manually." ) - elif os.path.exists(fname) and overwrite: + elif os.path.exists(fname) and overwrite: # pragma: no cover if self.verbose: print(f"Removing pre-exisiting {fname}...") os.remove(fname) @@ -191,12 +191,12 @@ def merge_parallel_runs(self, pop, overwrite=False): for f in os.listdir(path_to_batch) if os.path.isfile(os.path.join(path_to_batch, f)) ] - if self.verbose: + if self.verbose: # pragma: no cover print(f"Merging {len(tmp_files)} files...") pop.combine_saved_files(fname, tmp_files) - if self.verbose: + if self.verbose: # pragma: no cover print("Files merged!") print(f"Saved merged files to {fname}...") print(f"Removing files in {path_to_batch}...") @@ -377,7 +377,7 @@ def __init__(self, filename, verbose=False, chunksize=100000): if "/history_lengths" in store.keys(): self.lengths = store["history_lengths"] else: - if self.verbose: + if self.verbose: # pragma: no cover print( "history_lengths not found in population file. Calculating history lengths..." ) @@ -388,7 +388,7 @@ def __init__(self, filename, verbose=False, chunksize=100000): tmp_df.rename(columns={"index": "length"}, inplace=True) self.lengths = tmp_df del tmp_df - if self.verbose: + if self.verbose: # pragma: no cover print("Storing history lengths in population file!") store.put("history_lengths", pd.DataFrame(self.lengths), format="table") del history_events @@ -725,7 +725,7 @@ def __getitem__(self, key): else: raise ValueError("Invalid key type!") - def __len__(self): + def __len__(self): # pragma: no cover """ Get the number of systems in the oneline table. @@ -736,7 +736,7 @@ def __len__(self): """ return self.number_of_systems - def head(self, n=10): + def head(self, n=10): # pragma: no cover """Get the first n rows of the oneline table. Parameters @@ -751,7 +751,7 @@ def head(self, n=10): """ return super().head("oneline", n) - def tail(self, n=10): + def tail(self, n=10): # pragma: no cover """ Get the last n rows of the oneline table. @@ -767,7 +767,7 @@ def tail(self, n=10): """ return super().tail("oneline", n) - def __repr__(self): + def __repr__(self): # pragma: no cover """ Get a string representation of the oneline table. @@ -778,7 +778,7 @@ def __repr__(self): """ return super().get_repr("oneline") - def _repr_html_(self): + def _repr_html_(self): # pragma: no cover """ Get an HTML representation of the oneline table. @@ -789,7 +789,7 @@ def _repr_html_(self): """ return super().get_html_repr("oneline") - def select(self, where=None, start=None, stop=None, columns=None): + def select(self, where=None, start=None, stop=None, columns=None): # pragma: no cover """Select a subset of the oneline table based on the given conditions. This method allows you to filter and extract a subset of rows from the oneline table stored in an HDF file. @@ -882,7 +882,7 @@ def _save_mass_per_metallicity(self, filename): """ with pd.HDFStore(filename, mode="a") as store: store.put("mass_per_metallicity", self.mass_per_metallicity) - if self.verbose: + if self.verbose: # pragma: no cover print("mass_per_metallicity table written to population file!") def _load_mass_per_metallicity(self, filename): @@ -896,7 +896,7 @@ def _load_mass_per_metallicity(self, filename): """ with pd.HDFStore(filename, mode="r") as store: self.mass_per_metallicity = store["mass_per_metallicity"] - if self.verbose: + if self.verbose: # pragma: no cover print("mass_per_metallicity table read from population file!") def _save_ini_params(self, filename): @@ -1061,7 +1061,7 @@ def __init__( # check if formation channels are present if "/formation_channels" not in keys: - if self.verbose: + if self.verbose: # pragma: no cover print(f"{filename} does not contain formation channels!") self._formation_channels = None else: @@ -1070,7 +1070,7 @@ def __init__( ) # if an ini file is given, read the parameters from the ini file - if ini_file is not None: + if ini_file is not None: # pragma: no cover self.ini_params = binarypop_kwargs_from_ini(ini_file) self._save_ini_params(filename) self._load_ini_params(filename) @@ -1122,7 +1122,7 @@ def __init__( self.solar_metallicities = self.mass_per_metallicity.index.to_numpy() self.metallicities = self.solar_metallicities * Zsun - elif metallicity is not None and ini_file is None: + elif metallicity is not None and ini_file is None: # pragma: no cover raise ValueError( f"{filename} does not contain a mass_per_metallicity table and no ini file was given!" ) @@ -1132,7 +1132,7 @@ def __init__( self.number_of_systems = self.oneline.number_of_systems self.indices = self.history.indices - def __repr__(self): + def __repr__(self): # pragma: no cover """Return a string representation of the object. Returns @@ -1246,19 +1246,19 @@ def export_selection(self, selection, filename, overwrite=False, append=False, h if "/oneline" in store.keys(): last_index_in_file = np.sort(store["oneline"].index)[-1] - elif "/history" in store.keys(): + elif "/history" in store.keys(): # pragma: no cover last_index_in_file = np.sort(store["history"].index)[-1] - if "/history" in store.keys() and self.verbose: + if "/history" in store.keys() and self.verbose: # pragma: no cover print("history in file. Appending to file") - if "/oneline" in store.keys() and self.verbose: + if "/oneline" in store.keys() and self.verbose: # pragma: no cover print("oneline in file. Appending to file") - if "/formation_channels" in store.keys() and self.verbose: + if "/formation_channels" in store.keys() and self.verbose: # pragma: no cover print("formation_channels in file. Appending to file") - if "/history_lengths" in store.keys() and self.verbose: + if "/history_lengths" in store.keys() and self.verbose: # pragma: no cover print("history_lengths in file. Appending to file") # TODO: I need to shift the indices of the binaries or should I reindex them? @@ -1289,7 +1289,7 @@ def export_selection(self, selection, filename, overwrite=False, append=False, h "The population file contains multiple metallicities. Please add a metallicity column to the oneline dataframe!" ) - if self.verbose: + if self.verbose: # pragma: no cover print("Writing selected systems to population file...") # write oneline of selected systems @@ -1313,7 +1313,7 @@ def export_selection(self, selection, filename, overwrite=False, append=False, h index=False, ) - if self.verbose: + if self.verbose: # pragma: no cover print("Oneline: Done") # write history of selected systems @@ -1332,7 +1332,7 @@ def export_selection(self, selection, filename, overwrite=False, append=False, h index=False, ) - if self.verbose: + if self.verbose: # pragma: no cover print("History: Done") # write formation channels of selected systems @@ -1396,7 +1396,7 @@ def formation_channels(self): self.filename, key="formation_channels" ) else: - if self.verbose: + if self.verbose: # pragma: no cover print("No formation channels in the population file!") self._formation_channels = None @@ -1420,7 +1420,7 @@ def calculate_formation_channels(self, mt_history=True): If the mt_history_HMS_HMS column is not present in the oneline dataframe. """ - if self.verbose: + if self.verbose: # pragma: no cover print("Calculating formation channels...") # load the HMS-HMS interp class @@ -1539,7 +1539,7 @@ def get_mt_history(row): self._write_formation_channels(self.filename, df) del df - if self.verbose: + if self.verbose: # pragma: no cover print("formation_channels written to population file!") def _write_formation_channels(self, filename, df): @@ -1567,7 +1567,7 @@ def _write_formation_channels(self, filename, df): min_itemsize={"channel_debug": str_length, "channel": str_length}, ) - def __len__(self): + def __len__(self): # pragma: no cover """Get the number of systems in the population. Returns @@ -1579,7 +1579,7 @@ def __len__(self): return self.number_of_systems @property - def columns(self): + def columns(self): # pragma: no cover """ Returns a dictionary containing the column names of the history and oneline dataframes. @@ -1733,7 +1733,7 @@ def create_transient_population( ) return synth_pop - def plot_binary_evolution(self, index): + def plot_binary_evolution(self, index): # pragma: no cover """Plot the binary evolution of a system This method is not currently implemented. @@ -1805,7 +1805,7 @@ def __init__(self, filename, transient_name, verbose=False, chunksize=100000): self.transient_name = transient_name @property - def population(self): + def population(self): # pragma: no cover """Returns the entire transient population as a pandas DataFrame. This method retrieves the transient population data from a file and returns it as a pandas DataFrame. @@ -1819,7 +1819,7 @@ def population(self): return pd.read_hdf(self.filename, key="transients/" + self.transient_name) @property - def columns(self): + def columns(self): # pragma: no cover """Return the columns of the transient population. Returns: @@ -1882,6 +1882,9 @@ def calculate_model_weights(self, model_weights_identifier, population_parameter This method calculates the model weights of each event in the transient population based on the provided model parameters. It performs various calculations and stores the results in an HDF5 file at the location '/transients/{transient_name}/weights/{model_weights_identifier}'. + The calculated model weights represent the probability of an event per Msun + formed. Thus, it's units are Msun^{-1}. + Parameters ---------- model_weights_identifier : str @@ -1891,6 +1894,10 @@ def calculate_model_weights(self, model_weights_identifier, population_parameter population_parameters : dict, optional Dictionary containing the population parameters. If None, the default population parameters will be used. + Returns + ------- + pd.DataFrame + The model weights of the transient population and have units of Msun^-1. """ if population_parameters is None: population_parameters = {'number_of_binaries': 1000000, @@ -1915,12 +1922,12 @@ def calculate_model_weights(self, model_weights_identifier, population_parameter else: # check for different parameters for key in simulation_parameters.keys(): - if key not in self.ini_params: + if key not in self.ini_params: # pragma: no branch Pwarn((f"Parameter {key} not found in the population" " parameters! Make sure this is intended"), "POSYDONWarning") - if self.verbose: + if self.verbose: # pragma: no cover print("Simulation parameters:") print(simulation_parameters) print("Population parameters:") @@ -1938,7 +1945,7 @@ def calculate_model_weights(self, model_weights_identifier, population_parameter met_indices = tmp_data.index[met_mask] met_indices =np.unique(met_indices) M_sim = self.mass_per_metallicity['simulated_mass'].iloc[i] - if len(met_indices) == 0: + if len(met_indices) == 0: # pragma: no cover continue pop_data = self.oneline.select(where='index in '+str(met_indices.tolist()), columns=['S1_mass_i', 'S2_mass_i', 'orbital_period_i', 'eccentricity_i', 'state_i']) @@ -1972,7 +1979,7 @@ def model_weights(self, model_weights_identifier=None): """Retrieve the model weights of the transient population. This method retrieves the model weights of the transient population based on the provided model weights identifier. - The model weights are stored in an HDF5 file + The model weights are stored in an HDF5 file and have units of Msun^-1. Parameters ---------- @@ -1982,7 +1989,7 @@ def model_weights(self, model_weights_identifier=None): Returns ------- pd.DataFrame - The model weights of the transient population. + The model weights of the transient population in units of Msun^-1. """ if model_weights_identifier is None: @@ -2061,13 +2068,13 @@ def calculate_cosmic_weights(self, SFH_identifier, model_weights, MODEL_in=None) with pd.HDFStore(self.filename, mode="a") as store: if path_in_file + "MODEL" in store.keys(): store.remove(path_in_file + "MODEL") - if self.verbose: + if self.verbose: # pragma: no cover print("Cosmic weights already computed! Overwriting them!") - if path_in_file + "weights" in store.keys(): + if path_in_file + "weights" in store.keys(): # pragma: no branch store.remove(path_in_file + "weights") - if path_in_file + "z_events" in store.keys(): + if path_in_file + "z_events" in store.keys(): # pragma: no branch store.remove(path_in_file + "z_events") - if path_in_file + "birth" in store.keys(): + if path_in_file + "birth" in store.keys(): # pragma: no branch store.remove(path_in_file + "birth") self._write_MODEL_data(self.filename, path_in_file, MODEL) @@ -2117,7 +2124,7 @@ def calculate_cosmic_weights(self, SFH_identifier, model_weights, MODEL_in=None) .index.to_numpy() .flatten() ) - if len(selected_indices) == 0: + if len(selected_indices) == 0: # pragma: no cover continue delay_time = ( @@ -2174,7 +2181,7 @@ def calculate_cosmic_weights(self, SFH_identifier, model_weights, MODEL_in=None) ) return rates - def plot_efficiency_over_metallicity(self, model_weight_identifier, channels=False, **kwargs): + def plot_efficiency_over_metallicity(self, model_weight_identifier, channels=False, **kwargs): # pragma: no cover """ Plot the efficiency over metallicity. @@ -2195,7 +2202,7 @@ def plot_efficiency_over_metallicity(self, model_weight_identifier, channels=Fal efficiency.index.to_numpy() * Zsun, efficiency, channels=channels, **kwargs ) - def plot_delay_time_distribution( + def plot_delay_time_distribution( # pragma: no cover self, model_weights_identifier, metallicity=None, ax=None, bins=100, color="black" ): """ @@ -2273,7 +2280,7 @@ def plot_delay_time_distribution( ax.set_xlabel("Time [yr]") ax.set_ylabel("Number of events/Msun/yr") - def plot_popsyn_over_grid_slice(self, grid_type, met_Zsun, **kwargs): + def plot_popsyn_over_grid_slice(self, grid_type, met_Zsun, **kwargs): # pragma: no cover """ Plot the transients over the grid slice. @@ -2292,7 +2299,7 @@ def plot_popsyn_over_grid_slice(self, grid_type, met_Zsun, **kwargs): pop=self, grid_type=grid_type, met_Zsun=met_Zsun, **kwargs ) - def _write_MODEL_data(self, filename, path_in_file, MODEL): + def _write_MODEL_data(self, filename, path_in_file, MODEL): # pragma: no cover """ Write the MODEL data to the HDFStore file. @@ -2311,7 +2318,7 @@ def _write_MODEL_data(self, filename, path_in_file, MODEL): store.put(path_in_file + "MODEL", pd.DataFrame(MODEL)) else: store.put(path_in_file + "MODEL", pd.DataFrame(MODEL, index=[0])) - if self.verbose: + if self.verbose: # pragma: no cover print("MODEL written to population file!") def efficiency(self, model_weights_identifier, channels=False): @@ -2446,11 +2453,11 @@ def _read_MODEL_data(self, filename): else: self.MODEL = tmp_df.iloc[0].to_dict() - if self.verbose: + if self.verbose: # pragma: no cover print("MODEL read from population file!") @property - def weights(self): + def weights(self): # pragma: no cover """ Retrieves the weights from the HDFStore. @@ -2620,7 +2627,7 @@ def calculate_observable_population(self, observable_func, observable_name): + observable_name in store.keys() ): - if self.verbose: + if self.verbose: # pragma: no cover print("Overwriting observable population!") del store[ "transients/" @@ -2732,7 +2739,7 @@ def intrinsic_rate_density(self): def plot_hist_properties( self, prop, intrinsic=True, observable=None, bins=50, channel=None, **kwargs - ): + ): # pragma: no cover """Plot a histogram of a given property available in the transient population. This method plots a histogram of a given property available in the transient population. @@ -2812,7 +2819,7 @@ def plot_hist_properties( # plot the histogram using plot_pop.plot_hist_properties plot_pop.plot_hist_properties(df, bins=bins, **kwargs) - def plot_intrinsic_rate(self, channels=False, **kwargs): + def plot_intrinsic_rate(self, channels=False, **kwargs): # pragma: no cover """Plot the intrinsic rate density of the transient population.""" plot_pop.plot_rate_density(self.intrinsic_rate_density, channels=channels, **kwargs) @@ -2836,7 +2843,7 @@ def edges_metallicity_bins(self): bin_met[-1] = met_val[-1] + (met_val[-1] - met_val[-2]) / 2.0 bin_met[1:-1] = met_val[:-1] + (met_val[1:] - met_val[:-1]) / 2.0 # one metallicty bin - elif len(met_val) == 1: + elif len(met_val) == 1: # pragma: no branch if self.MODEL["dlogZ"] is None: bin_met[0] = -9 bin_met[-1] = 0 @@ -2845,7 +2852,7 @@ def edges_metallicity_bins(self): bin_met[-1] = met_val[0] + self.MODEL["dlogZ"] / 2.0 elif isinstance(self.MODEL["dlogZ"], list) or isinstance( self.MODEL["dlogZ"], np.array - ): + ): # pragma: no branch bin_met[0] = self.MODEL["dlogZ"][0] bin_met[-1] = self.MODEL["dlogZ"][1] diff --git a/posydon/popsyn/transient_select_funcs.py b/posydon/popsyn/transient_select_funcs.py index 38ca3f6283..c3169e5d9c 100644 --- a/posydon/popsyn/transient_select_funcs.py +++ b/posydon/popsyn/transient_select_funcs.py @@ -82,7 +82,7 @@ def GRB_selection(history_chunk, oneline_chunk, formation_channels_chunk=None, S selection = history_chunk.loc[indices_selection] if S1_S2 == 'S1': S_mask = (selection['S1_state'] == 'BH') & (selection['S1_state'] != 'BH').shift(1) & (selection['step_names'] == 'step_SN') - elif S1_S2 == 'S2': + elif S1_S2 == 'S2': # pragma: no branch S_mask = (selection['S2_state'] == 'BH') & (selection['S2_state'] != 'BH').shift(1) & (selection['step_names'] == 'step_SN') GRB_df_synthetic = pd.DataFrame(index=indices_selection) @@ -94,11 +94,10 @@ def GRB_selection(history_chunk, oneline_chunk, formation_channels_chunk=None, S if S1_S2 == 'S1': columns_pre_post.append('S1_mass') columns.append('S2_mass') - elif S1_S2 == 'S2': + elif S1_S2 == 'S2': # pragma: no branch columns_pre_post.append('S2_mass') columns.append('S1_mass') - for col in columns_pre_post: GRB_df_synthetic[col+'_preSN'] = pre_SN_hist[col].values GRB_df_synthetic[col+'_postSN'] = post_SN_hist[col].values @@ -111,11 +110,11 @@ def GRB_selection(history_chunk, oneline_chunk, formation_channels_chunk=None, S for col in oneline_chunk.columns: GRB_df_synthetic[col] = oneline_chunk.loc[indices_selection][col].values - if any(formation_channels_chunk != None): + if formation_channels_chunk is not None: formation_channels_chunk = formation_channels_chunk.loc[indices_selection] if S1_S2 == 'S1': GRB_df_synthetic['channel'] = formation_channels_chunk['channel'].str.split('_CC1').str[0].apply(lambda x: x+'_CC1') - elif S1_S2 == 'S2': + elif S1_S2 == 'S2': # pragma: no branch GRB_df_synthetic['channel'] = formation_channels_chunk['channel'].str.split('_CC2').str[0].apply(lambda x: x+'_CC2') # calculate the time! @@ -294,9 +293,9 @@ def DCO_detectability(sensitivity, transient_pop_chunk, z_events_chunk, z_weight These have to be present and a valid value. If not, the function will raise an error! ''' - available_sensitiveies = ['O3actual_H1L1V1', 'O4low_H1L1V1', 'O4high_H1L1V1', 'design_H1L1V1'] - if sensitivity not in available_sensitiveies: - raise ValueError(f'Unknown sensitivity {sensitivity}. Available sensitivities are {available_sensitiveies}') + available_sensitivities = ['O3actual_H1L1V1', 'O4low_H1L1V1', 'O4high_H1L1V1', 'design_H1L1V1'] + if sensitivity not in available_sensitivities: + raise ValueError(f'Unknown sensitivity {sensitivity}. Available sensitivities are {available_sensitivities}') else: sel_eff = selection_effects.KNNmodel(grid_path=PATH_TO_PDET_GRID, sensitivity_key=sensitivity) diff --git a/posydon/tests/active_learning/psy_cris/test_Classifier.py b/posydon/tests/active_learning/psy_cris/test_Classifier.py deleted file mode 100644 index 7c7d7f2953..0000000000 --- a/posydon/tests/active_learning/psy_cris/test_Classifier.py +++ /dev/null @@ -1,159 +0,0 @@ -"""Unit test for posydon.active_learning.psy_cris classes -""" -import unittest - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from posydon.active_learning.psy_cris.classify import Classifier -from posydon.active_learning.psy_cris.data import TableData -from posydon.active_learning.psy_cris.synthetic_data.synth_data_3D import get_output_3D -from posydon.active_learning.psy_cris.utils import ( - get_random_grid_df, - get_regular_grid_df, -) - -# True for faster runtime ~ 3s vs 15s -SKIP_GP_TESTS = True - -SKIP_TEST_PLOTS = False -SHOW_PLOTS = False - - -class TestClassifier(unittest.TestCase): - """Test Classifier class on the 3d synthetic data set.""" - - @classmethod - def setUpClass(cls): - np.random.seed(12345) - cls.TEST_DATA_GRID = get_regular_grid_df(N=10 ** 3, dim=3) - cls.TEST_DATA_RAND = get_random_grid_df(N=10 ** 3, dim=3) - cls.UNIQUE_CLASSES = [1, 2, 3, 4, 6, 8] - - cls.TEST_INPUT_POINTS = np.array([[0, 0, 0], [-0.5, 0.5, 0.5]]) - cls.TEST_OUTPUT_TRUTH = get_output_3D(*cls.TEST_INPUT_POINTS.T) - - def setUp(self): - my_kwargs = {"n_neighbors": [2, 3]} - self.table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - self.table_rand = self.create_TableData(self.TEST_DATA_RAND, **my_kwargs) - self.cls_obj_grid = Classifier(self.table_grid) - self.cls_obj_rand = Classifier(self.table_rand) - - def create_TableData(self, data_frame, **kwargs): - files = None - input_cols = ["input_1", "input_2", "input_3"] - output_cols = ["class", "output_1"] - class_col_name = "class" - table_obj = TableData( - files, - input_cols, - output_cols, - class_col_name, - my_DataFrame=data_frame, - verbose=False, - **kwargs - ) - return table_obj - - def test_train_classifiers_1(self): - # di can be [16,19,24,32,41,48,64] to avoid multi class error for gp - data_range = np.arange(0, 1000)[::64] - for cls in ["rbf", "linear", "gp"]: - with self.subTest("Train grid:", classifier=cls): - self.cls_obj_grid.train(cls, di=data_range, verbose=False) - - def test_train_classifiers_2(self): - # skipping gp here - data_range = np.arange(0, 1000)[::10] - for cls in ["rbf", "linear"]: - with self.subTest("Train random:", classifier=cls): - self.cls_obj_rand.train(cls, di=data_range, verbose=False) - - def train_grid_classifiers(self, cls_names, **kwargs): - for name in cls_names: - self.cls_obj_grid.train(name, **kwargs) - - def test_predictions(self): - self.train_grid_classifiers(["linear", "rbf"]) - correct_probabilities = [0.98519722, 1.00000, 0.5883452] - for i, cls_name in enumerate(["rbf", "linear"]): - with self.subTest("Get class predictions:", classifier=cls_name): - tup_out = self.cls_obj_grid.get_class_predictions( - cls_name, self.TEST_INPUT_POINTS, return_ids=False - ) - class_pred, probs, where_not_nan = tup_out - - self.assertTrue( - (3 in class_pred) and (8 in class_pred), - msg="All predictions should contain class 3 and 8", - ) - self.assertAlmostEqual(probs[0], correct_probabilities[i], places=3) - self.assertTrue(len(where_not_nan) == 2, msg="Should not get any nans.") - - @unittest.skipIf(SKIP_GP_TESTS, "GP train / predict - long runtime.") - def test_predictions_gp(self): - self.train_grid_classifiers(["gp"], di=np.arange(0, 1000)[::6]) - tup_out = self.cls_obj_grid.get_class_predictions( - "gp", self.TEST_INPUT_POINTS, return_ids=False - ) - class_pred, probs, where_not_nan = tup_out - - self.assertTrue( - (3 in class_pred) and (8 in class_pred), - msg="All predictions should contain class 3 and 8", - ) - self.assertAlmostEqual(probs[0], 0.5883452, places=3) - self.assertTrue(len(where_not_nan) == 2, msg="Should not get any nans.") - - def test_pred_train_err(self): - # Trying to predict without training - names = ["grid", "random"] - for i, classifier in enumerate([self.cls_obj_grid, self.cls_obj_rand]): - with self.subTest(classifier_name=names[i]): - with self.assertRaisesRegex( - Exception, "No trained interpolators exist" - ): - classifier.get_class_predictions("linear", [[0, 0, 0]]) - - def test_pred_linear_err(self): - self.cls_obj_rand.train("linear", di=np.arange(0, 1000, 50)) - tup_out = self.cls_obj_rand.get_class_predictions( - "lin", [[-1, -1, -1], [1, 1, 1]], return_ids=False - ) - self.assertTrue(len(tup_out[2]) == 0, msg="Should return no valid values.") - - # def test_cross_val(self): - # correct_ans = [67.36842105263158, 66.66666666666666] - # acc, times = self.cls_obj_grid.cross_validate( - # ["rbf", "linear"], 0.05, verbose=False - # ) - # for i, percent_acc in enumerate(acc): - # with self.subTest("Cross Val", i=i, percent_acc=percent_acc): - # self.assertAlmostEqual(acc[i], correct_ans[i], places=3) - - @unittest.skipIf(SKIP_GP_TESTS, "GP cross_val - long runtime.") - def test_cross_val_gp(self): - correct_ans = [73.76470588235294] - acc, times = self.cls_obj_grid.cross_validate(["gp"], 0.15, verbose=False) - for i, percent_acc in enumerate(acc): - with self.subTest("Cross Val", i=i, percent_acc=percent_acc): - self.assertAlmostEqual(acc[i], correct_ans[i], places=3) - - @unittest.skipIf(SKIP_TEST_PLOTS, "Skipping maximum class P plot.") - def test_max_cls_plot(self): - N = int(2e4) if SHOW_PLOTS else 100 - self.train_grid_classifiers(["rbf"]) - fig, axes = self.cls_obj_grid.make_max_cls_plot( - "rbf", ("input_1", "input_2"), N=N, s=3, alpha=0.6, cmap="bone" - ) - if SHOW_PLOTS: - fig.show() - else: - print("To show plots set SHOW_PLOTS to True.") - plt.close(fig) - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/active_learning/psy_cris/test_Regressor.py b/posydon/tests/active_learning/psy_cris/test_Regressor.py deleted file mode 100644 index 94695db620..0000000000 --- a/posydon/tests/active_learning/psy_cris/test_Regressor.py +++ /dev/null @@ -1,174 +0,0 @@ -"""Unit test for posydon.active_learning.psy_cris classes -""" -import math -import unittest - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from posydon.active_learning.psy_cris.data import TableData -from posydon.active_learning.psy_cris.regress import Regressor -from posydon.active_learning.psy_cris.synthetic_data.synth_data_3D import get_output_3D -from posydon.active_learning.psy_cris.utils import ( - get_random_grid_df, - get_regular_grid_df, -) - -# True for faster runtime ~ 1s vs 3s -SKIP_GP_TESTS = False - -SKIP_TEST_PLOTS = False -SHOW_PLOTS = False - - -class TestRegressor(unittest.TestCase): - """Test Regressor class on the 3d synthetic data set.""" - - @classmethod - def setUpClass(cls): - np.random.seed(12345) - cls.TEST_DATA_GRID = get_regular_grid_df(N=10 ** 3, dim=3) - cls.TEST_DATA_RAND = get_random_grid_df(N=10 ** 3, dim=3) - cls.UNIQUE_CLASSES = [1, 2, 3, 4, 6, 8] - - cls.TEST_INPUT_POINTS = np.array([[0, 0, 0], [-0.5, 0.5, 0.5]]) - cls.TEST_OUTPUT_TRUTH = get_output_3D(*cls.TEST_INPUT_POINTS.T) - - def setUp(self): - my_kwargs = {"n_neighbors": [2, 3, 5]} - self.table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - self.table_rand = self.create_TableData(self.TEST_DATA_RAND, **my_kwargs) - self.regr_grid = Regressor(self.table_grid) - self.regr_rand = Regressor(self.table_rand) - - def create_TableData(self, data_frame, **kwargs): - files = None - input_cols = ["input_1", "input_2", "input_3"] - output_cols = ["class", "output_1"] - class_col_name = "class" - table_obj = TableData( - files, - input_cols, - output_cols, - class_col_name, - my_DataFrame=data_frame, - verbose=False, - **kwargs - ) - return table_obj - - def train_grid_regressors(self, names, *args, **kwargs): - for i, regr_name in enumerate(names): - self.regr_grid.train(regr_name, *args, **kwargs) - - def test_train_regressors_1(self): - # di can be [16,19,24,32,41,48,64] to avoid multi class error for gp - classes_to_train = [ [1,2,3,4,6,8], [1,6,8], [1,6,8] ] - col_keys = ["output_1"] - for i, regr in enumerate(["rbf", "linear", "gp"]): - with self.subTest("Train grid:", regressor=regr): - self.regr_grid.train(regr, classes_to_train[i], col_keys, - verbose=False) - - def test_train_regressors_2(self): - # skipping gp here - classes_to_train = [ [1,2,3,4,6,8], [1,6,8], [1,6,8] ] - col_keys = ["output_1"] - for i, regr in enumerate(["rbf", "linear"]): - with self.subTest("Train random:", regressor=regr): - self.regr_rand.train(regr, classes_to_train[i], - col_keys, verbose=False) - - def test_predictions(self): - """Checking for consistency only, not true values.""" - self.train_grid_regressors(["rbf", "linear"], [6], ["output_1"]) - - regr_out = self.regr_grid.get_predictions( - ["rbf", "linear"], [6], ["output_1"], self.TEST_INPUT_POINTS - ) - # RBF check - for i, corr_ans in enumerate([-0.78295781, -0.7401497]): - with self.subTest("RBF regr", correct_ans=corr_ans): - pred = regr_out["RBF"][6]["output_1"][i] - self.assertAlmostEqual(pred, corr_ans, places=5) - - # LinearNDInterpolator check - for i, corr_ans in enumerate([0.13898644, float("Nan")]): - with self.subTest("LinearNDInterpolator regr", correct_ans=corr_ans): - pred = regr_out["LinearNDInterpolator"][6]["output_1"][i] - if i == 1: - self.assertTrue(math.isnan(pred), msg="Prediction should be Nan. {}".format(pred)) - else: - self.assertAlmostEqual(pred, corr_ans, places=5) - - @unittest.skipIf(SKIP_GP_TESTS, "GP train / predict - longer runtime.") - def test_predictions_gp(self): - self.train_grid_regressors(["gp"], [6], ["output_1"], di=None) - - regr_out = self.regr_grid.get_predictions( - ["gp"], [6], ["output_1"], self.TEST_INPUT_POINTS - ) - for i, corr_ans in enumerate([0,0]): - with self.subTest("GaussianProcessRegressor", correct_ans=corr_ans): - pred = regr_out["GaussianProcessRegressor"][6]["output_1"][i] - self.assertAlmostEqual(pred, corr_ans, places=5) - - - def test_pred_train_err(self): - # Trying to predict without training - names = ["grid", "random"] - for i, regressor in enumerate([self.regr_grid, self.regr_rand]): - with self.subTest(classifier_name=names[i]): - with self.assertRaisesRegex( - Exception, "No trained interpolators exist" - ): - regressor.get_predictions( ["linear"], [6], ["output_1"], [[0, 0, 0]]) - - def test_cross_val(self): - corr_ans = [-16.528141760898365, -61.41327730988214, -10.621903342287425] - for index, cls in enumerate([1,6,8]): - with self.subTest("Cross Validation Regression", class_key=cls): - perc_diffs, actual_diffs = self.regr_grid.cross_validate("rbf", cls, "output_1", 0.5 ) - self.assertAlmostEqual( np.mean(perc_diffs), corr_ans[index], places=5) - - plt.hist(perc_diffs, bins=40, density=True, range=(-300,300), - histtype="step", label="class "+str(cls)) - plt.xlabel("Regression Percent Difference") - plt.title("Test Regression CV") - plt.legend() - if SHOW_PLOTS: - plt.show() - plt.close() - - - @unittest.skipIf(SKIP_GP_TESTS, "GP cross_val - longer runtime.") - def test_cross_val_gp(self): - corr_ans = [-36.08266156603506, -100.0, -70.79557229587994] - for index, cls in enumerate([1,6,8]): - with self.subTest("Cross Validation Regression GP", class_key=cls, ans=corr_ans[index]): - perc_diffs, actual_diffs = self.regr_grid.cross_validate("gp", cls, "output_1", 0.5 ) - self.assertAlmostEqual( np.mean(perc_diffs), corr_ans[index], places=5) - - plt.hist(perc_diffs, bins=40, density=True, range=(-300,300), - histtype="step", label="class "+str(cls)) - plt.xlabel("Regression Percent Difference") - plt.title("Test GP Regression CV") - plt.legend() - if SHOW_PLOTS: - plt.show() - plt.close() - - @unittest.skipIf(SKIP_TEST_PLOTS, "All regression data plot.") - def test_max_cls_plot(self): - class_key = 1 - fig = self.regr_grid.plot_regr_data(class_key) - if SHOW_PLOTS: - plt.show() - else: - print("To show plots set SHOW_PLOTS to True.") - plt.close() - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/active_learning/psy_cris/test_Sampler.py b/posydon/tests/active_learning/psy_cris/test_Sampler.py deleted file mode 100644 index 8ac35611dc..0000000000 --- a/posydon/tests/active_learning/psy_cris/test_Sampler.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Unit test for posydon.active_learning.psy_cris classes -""" -import math -import unittest - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from posydon.active_learning.psy_cris.classify import Classifier -from posydon.active_learning.psy_cris.data import TableData -from posydon.active_learning.psy_cris.regress import Regressor -from posydon.active_learning.psy_cris.sample import Sampler -from posydon.active_learning.psy_cris.synthetic_data.synth_data_3D import get_output_3D -from posydon.active_learning.psy_cris.utils import ( - get_random_grid_df, - get_regular_grid_df, -) - -SKIP_TEST_PLOTS = False -SHOW_PLOTS = False - - -class TestSampler(unittest.TestCase): - """Test Sampler class on the 3d synthetic data set.""" - - @classmethod - def setUpClass(cls): - np.random.seed(12345) - cls.TEST_DATA_GRID = get_regular_grid_df(N=10 ** 3, dim=3) - cls.TEST_DATA_RAND = get_random_grid_df(N=10 ** 3, dim=3) - cls.UNIQUE_CLASSES = [1, 2, 3, 4, 6, 8] - - cls.TEST_INPUT_POINTS = np.array([[0, 0, 0], [-0.5, 0.5, 0.5]]) - cls.TEST_OUTPUT_TRUTH = get_output_3D(*cls.TEST_INPUT_POINTS.T) - - def setUp(self): - my_kwargs = {"n_neighbors": [2, 3, 5]} - self.table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - self.regr_grid = Regressor(self.table_grid) - self.cls_grid = Classifier(self.table_grid) - - def create_TableData(self, data_frame, **kwargs): - files = None - input_cols = ["input_1", "input_2", "input_3"] - output_cols = ["class", "output_1"] - class_col_name = "class" - table_obj = TableData( - files, - input_cols, - output_cols, - class_col_name, - my_DataFrame=data_frame, - verbose=False, - **kwargs - ) - return table_obj - - def train_everything_grid(self, cls_names, regr_names): - if cls_names is not None: - self.cls_grid.train_everything(cls_names) - if regr_names is not None: - self.regr_grid.train_everything(regr_names) - - def test_init_0(self): - test_cases = [ - (None, None), - (self.cls_grid, self.regr_grid), - (None, self.regr_grid), - ] - for i, tup_input in enumerate([(None, None), ()]): - with self.subTest("Sampler init", iter=i): - samp = Sampler(*tup_input) - - def test_mcmc(self): - self.train_everything_grid(["rbf"], None) - samp = Sampler(classifier=self.cls_grid, regressor=None) - - steps, acc, rej = samp.run_MCMC( - 15, 0.25, [0, 0, 0], samp.TD_classification, "rbf", T=1, **{"TD_BETA": 2} - ) - self.assertTrue(len(steps) == (acc + 1), msg="steps taken should match acc.") - - def test_ptmcmc(self): - self.train_everything_grid(["rbf"], ["rbf"]) - samp = Sampler(classifier=self.cls_grid, regressor=self.regr_grid) - - chain_step_hist, T_list = samp.run_PTMCMC( - 5, - 15, - samp.TD_classification_regression, - ("rbf", "rbf"), - init_pos=[0, 0, 0], - alpha=0.25, - verbose=False, - trace_plots=False, - TD_BETA=1, - ) - # try with default values - chain_step_hist, T_list = samp.run_PTMCMC( - 10, 15, samp.TD_classification_regression, ("rbf", "rbf"), - verbose=False, trace_plots=False) - - - def test_simple_density_logic(self): - self.cls_grid.train("rbf") - samp = Sampler(classifier=self.cls_grid, regressor=None) - steps, acc, rej = samp.run_MCMC( - 200, 0.25, [0, 0, 0], samp.TD_classification, "rbf", T=1 - ) - acc_pts, rej_pts = samp.do_simple_density_logic(steps, 10, 0.05) - return samp, steps - - def test_get_proposed_points(self): - N = 10 - samp, step_hist = self.test_simple_density_logic() - prop_points, kappa = samp.get_proposed_points(step_hist, N, 0.046) - self.assertTrue(len(prop_points) == N) - - @unittest.skipIf(SKIP_TEST_PLOTS, "Plotting C, C+R target distributions") - def test_TD_plots(self): - self.train_everything_grid(["rbf"], ["rbf"]) - samp = Sampler(classifier=self.cls_grid, regressor=self.regr_grid) - - N = 70 if SHOW_PLOTS else 5 - zed = 0 - x, y = np.meshgrid(np.linspace(-1, 1, N), np.linspace(-1, 1, N)) - z = np.ones(x.shape) * zed - data_points = np.concatenate( - (x.flatten()[:, None], y.flatten()[:, None], z.flatten()[:, None]), axis=1 - ) - - max_probs, pos, cls_keys = samp.get_TD_classification_data("rbf", data_points) - - kwargs = {"TD_BETA": 2, "TD_TAU": 0.5} - cls_regr_td_vals = [ - float(samp.TD_classification_regression(["rbf", "rbf"], dat, **kwargs)) - for dat in data_points - ] - - fig, subs = plt.subplots(1, 2, figsize=(13, 5)) - subs[0].set_title("TD_classification at z = {}".format(zed)) - cls_plot = subs[0].pcolormesh( - x, y, (1 - max_probs).reshape(N, N), shading="auto" - ) - - subs[1].set_title("TD_classification_regression at z = {}".format(zed)) - cls_regr_plot = subs[1].pcolormesh( - x, y, np.array(cls_regr_td_vals).reshape(N, N), shading="auto" - ) - - fig.colorbar(cls_plot, ax=subs[0]) - fig.colorbar(cls_regr_plot, ax=subs[1]) - - for i in range(2): - subs[i].set_xlabel("input_1") - subs[1].set_ylabel("input_2") - subs[i].axis("equal") - - if SHOW_PLOTS: - plt.show() - plt.close() - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/active_learning/psy_cris/test_TableData.py b/posydon/tests/active_learning/psy_cris/test_TableData.py deleted file mode 100644 index ae1795e4e8..0000000000 --- a/posydon/tests/active_learning/psy_cris/test_TableData.py +++ /dev/null @@ -1,247 +0,0 @@ -"""Unit test for posydon.active_learning.psy_cris classes -""" -import unittest - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from posydon.active_learning.psy_cris.data import TableData -from posydon.active_learning.psy_cris.synthetic_data.synth_data_3D import get_output_3D -from posydon.active_learning.psy_cris.utils import ( - get_random_grid_df, - get_regular_grid_df, -) - -SKIP_TEST_PLOTS = False -SHOW_PLOTS = False - - -class TestTableData(unittest.TestCase): - """Test TableData class on the 3d synthetic data set.""" - - @classmethod - def setUpClass(cls): - np.random.seed(12345) - cls.TEST_DATA_GRID = get_regular_grid_df(N=10 ** 3, dim=3) - cls.TEST_DATA_RAND = get_random_grid_df(N=10 ** 3, dim=3) - cls.UNIQUE_CLASSES = [1, 2, 3, 4, 6, 8] - - def create_TableData(self, data_frame, **kwargs): - files = None - input_cols = ["input_1", "input_2", "input_3"] - output_cols = ["class", "output_1"] - class_col_name = "class" - table_obj = TableData( - files, - input_cols, - output_cols, - class_col_name, - my_DataFrame=data_frame, - verbose=False, - **kwargs - ) - return table_obj - - def test_init_0(self): - td = self.create_TableData(self.TEST_DATA_GRID) - self.assertTrue(isinstance(td, TableData)) - - def test_init_1_grid(self): - my_kwargs = {"n_neighbors": [2, 3]} - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - # N classes - self.assertTrue( - table_grid.num_classes == len(self.UNIQUE_CLASSES), - msg="Should find 6 classes. Found {}".format(table_grid.num_classes), - ) - # Unique classes + APC cols - regr_data = table_grid.get_regr_data(what_data="output") - for cls_key in regr_data.keys(): - with self.subTest("Checking data by class.", cls_key=cls_key): - self.assertIn(cls_key, self.UNIQUE_CLASSES) - self.assertIn("APC2_output_1", regr_data[cls_key].columns) - if cls_key != 2: - self.assertIn("APC3_output_1", regr_data[cls_key].columns) - - def test_init_2_rand(self): - my_kwargs = {"n_neighbors": [2, 3]} - table_rand = self.create_TableData(self.TEST_DATA_RAND, **my_kwargs) - - # N classes - self.assertTrue( - table_rand.num_classes == len(self.UNIQUE_CLASSES), - msg="Should find 6 classes. Found {}".format(table_rand.num_classes), - ) - # Unique classes + APC cols - regr_data = table_rand.get_regr_data(what_data="output") - for cls_key in regr_data.keys(): - with self.subTest("Checking data by class.", cls_key=cls_key): - self.assertIn(cls_key, self.UNIQUE_CLASSES) - self.assertIn("APC2_output_1", regr_data[cls_key].columns) - self.assertIn("APC3_output_1", regr_data[cls_key].columns) - - def test_init_3_clean_data(self): - my_kwargs = {"n_neighbors": [2, 3], "omit_vals": [-1]} - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - - self.assertTrue( - len(table_grid.get_data()) == 729, - msg="Should remove 271 rows from grid data set with value -1.", - ) - - def test_init_4_clean_data(self): - my_kwargs = {"n_neighbors": [2, 3], "omit_vals": [-1]} - table_rand = self.create_TableData(self.TEST_DATA_RAND, **my_kwargs) - - self.assertTrue( - len(table_rand.get_data()) == 1000, - msg="Should remove 0 rows from random data set with value -1.", - ) - - def test_classification_data(self): - my_kwargs = {"n_neighbors": [2, 3], "omit_vals": [-1]} - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - - binary_classification_data = table_grid.get_binary_mapping_per_class() - self.assertTrue( - binary_classification_data.shape == (len(self.UNIQUE_CLASSES), 729) - ) - self.assertTrue( - all( - [ - sum((row == 1) + (row == 0)) == 729 - for row in binary_classification_data - ] - ) - ) - - def test_regression_data_1(self): - my_kwargs = { - "n_neighbors": [2, 3], - } - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - output_dat = table_grid.get_regr_data(what_data="output") - for cls in [5, 7]: - with self.subTest(cls=cls): - # Raise KeyError for classes that shouldn't exist - with self.assertRaisesRegex(KeyError, str(cls)): - output_dat[cls] - - def test_regression_data_2(self): - my_kwargs = {"n_neighbors": [2, 3, 5, 10, 50]} - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - table_rand = self.create_TableData(self.TEST_DATA_RAND, **my_kwargs) - - for key, val in table_grid._regr_dfs_per_class_.items(): - col_names = list(val.columns) - with self.subTest(key=key, data="grid"): - self.assertTrue(any(["APC2" in item for item in col_names])) - if key in [1, 3, 4, 8]: - self.assertTrue(any(["APC50" in item for item in col_names])) - - for key, val in table_rand._regr_dfs_per_class_.items(): - col_names = list(val.columns) - with self.subTest(key=key, data="random"): - self.assertTrue(any(["APC2" in item for item in col_names])) - self.assertTrue(any(["APC3" in item for item in col_names])) - if key in [1, 3, 6, 8]: - self.assertTrue(any(["APC50" in item for item in col_names])) - - @unittest.skipIf( SKIP_TEST_PLOTS, "Test by plotting nearest neighbors skipped") - def test_nearest_neighbhors(self): - my_kwargs = { - "n_neighbors": [2, 3], - } - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - n_neigh = 3 - - dat = np.random.uniform(low=(-1, -1), high=(1, 1), size=(15, 2)) - output = table_grid.find_n_neighbors(dat, [n_neigh]) - - plt.figure(figsize=(4, 4), dpi=100) - plt.title("NearestNeighbors test") - plt.plot( - dat.T[0][0], dat.T[1][0], "+", color="r", markersize=10, label="reference" - ) - plt.scatter(dat.T[0], dat.T[1], label="data") - for i in range(n_neigh): - plt.scatter( - dat.T[0][output[n_neigh][0][i]], - dat.T[1][output[n_neigh][0][i]], - marker="+", - color="lime", - s=29, - label="nearest", - ) - plt.axis("equal") - if SHOW_PLOTS: - plt.show() - else: - print("To show plots set SHOW_PLOTS to True.") - plt.close() - - @unittest.skipIf( SKIP_TEST_PLOTS, "Test for general plotting skipped") - def test_plotting_1(self): - my_kwargs = { - "n_neighbors": [2, 3], - } - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - table_rand = self.create_TableData(self.TEST_DATA_RAND, **my_kwargs) - - zed_val = 1 - N = 40 - X, Y = np.meshgrid(np.linspace(-1, 1, N), np.linspace(-1, 1, N)) - Z = np.ones(X.shape) * zed_val - f_out = get_output_3D(X, Y, Z) - print("ZED VAL: {}".format(zed_val)) - - fig, subs = plt.subplots(1, 3, figsize=(14, 4), dpi=100) - subs[0].set_title("TableData - even grid") - fig, subs[0], handles = table_grid.make_class_data_plot( - fig, - subs[0], - ["input_1", "input_2"], - my_slice_vals={0: (0.9, 1.1)}, - return_legend_handles=True, - ) - subs[0].legend( - handles, table_grid._unique_class_keys_, bbox_to_anchor=(-0.25, 0.5) - ) - - subs[1].set_title("TableData - random points") - fig, subs[1], handles = table_rand.make_class_data_plot( - fig, - subs[1], - ["input_1", "input_2"], - my_slice_vals={0: (0.9, 1.1)}, - return_legend_handles=True, - ) - - subs[2].set_title("Analytic Classification") - subs[2].pcolormesh(X, Y, f_out["class"].values.reshape(N, N), shading="auto") - for i in range(3): - subs[i].axis("equal") - if SHOW_PLOTS: - plt.show() - else: - print("To show plots set SHOW_PLOTS to True.") - plt.close() - - @unittest.skipIf( SKIP_TEST_PLOTS, "Test for plotting 3d skipped") - def test_plotting_2(self): - my_kwargs = { - "n_neighbors": [2, 3], - } - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - table_grid.plot_3D_class_data() - plt.title("plot_3D_class_data") - if SHOW_PLOTS: - plt.show() - else: - print("To show plots set SHOW_PLOTS to True.") - plt.close() - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/active_learning/psy_cris/test_utils.py b/posydon/tests/active_learning/psy_cris/test_utils.py deleted file mode 100644 index bd666c1ae2..0000000000 --- a/posydon/tests/active_learning/psy_cris/test_utils.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Unit test for posydon.active_learning.psy_cris classes -""" -import os -import unittest - -import numpy as np -import pandas as pd - -from posydon.active_learning.psy_cris.utils import ( - check_dist, - get_new_query_points, - get_random_grid_df, - get_regular_grid_df, - parse_inifile, -) -from posydon.config import PATH_TO_POSYDON - - -class TestUtils(unittest.TestCase): - """Test methods in utils.""" - - @classmethod - def setUpClass(cls): - np.random.seed(12345) - psy_cris_dir = os.path.join( - PATH_TO_POSYDON, "posydon/active_learning/psy_cris") - cls.INI_FILE_PATH = os.path.join(psy_cris_dir, - "run_params/psycris_default.ini") - - def test_parse_inifile(self): - self.assertTrue(os.path.isfile(self.INI_FILE_PATH), msg="Can't find file.") - my_kwargs = parse_inifile(self.INI_FILE_PATH) - self.assertTrue(isinstance(my_kwargs, dict)) - return my_kwargs - - def test_get_new_query_points(self): - my_kwargs = self.test_parse_inifile() - holder = my_kwargs["TableData_kwargs"] - holder["my_DataFrame"] = get_regular_grid_df(N=10 ** 3, dim=3) - my_kwargs["TableData_kwargs"] = holder - - holder_1 = my_kwargs["Sampler_kwargs"] - holder_1["N_tot"] = 50 - holder_1["T_max"] = 5 - holder_1["verbose"] = False - my_kwargs["Sampler_kwargs"] = holder_1 - - query_pts, preds = get_new_query_points(3, **my_kwargs) - self.assertTrue(len(query_pts) == 3) - - def test_check_dist(self): - original_pts = np.random.uniform( - low=(-1, -1, -1), high=(1, 1, 1), size=(500, 3) - ) - proposed_pts = get_regular_grid_df(N=10 ** 3, dim=3).values[:, 0:3] - result = check_dist(original_pts, proposed_pts, threshold=1e-2) - self.assertTrue( - sum(result) == len(proposed_pts), - msg="All points should not be within 1e-2 of eachother.", - ) - - def test_get_regular_grid_df(self): - for config in [dict(N=1000, dim=3), dict(N=50, dim=2), dict(jitter=True)]: - with self.subTest("regular_grid_df", config=config): - df = get_regular_grid_df(**config) - self.assertTrue(isinstance(df, pd.DataFrame)) - - def test_get_random_grid_df(self): - for config in [dict(N=1000, dim=3), dict(N=50, dim=2)]: - with self.subTest("random_grid_df", config=config): - df = get_random_grid_df(**config) - self.assertTrue(isinstance(df, pd.DataFrame)) - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/binary_evol/CE/test_CEE.py b/posydon/tests/binary_evol/CE/test_CEE.py deleted file mode 100644 index ddeea07647..0000000000 --- a/posydon/tests/binary_evol/CE/test_CEE.py +++ /dev/null @@ -1,655 +0,0 @@ -import os -import unittest - -import numpy as np - -from posydon.binary_evol.binarystar import BinaryStar -from posydon.binary_evol.CE.step_CEE import StepCEE -from posydon.binary_evol.singlestar import SingleStar -from posydon.config import PATH_TO_POSYDON -from posydon.utils import common_functions as cf - -# spaces are read '\\ ' instead of ' ' -PATH_TO_DATA = os.path.join( - PATH_TO_POSYDON, "posydon/tests/data/POSYDON-UNIT-TESTS/binary_evol/CE/") - - -class TestCommonEnvelope(unittest.TestCase): - def test_common_envelope_1(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'default_lambda'} - - CEE = StepCEE(verbose=False, **kwargs) - - # simple binary system which will experience CEE with default_lambda - # option on. - # no profiles needed for this - PROPERTIES_STAR1 = { - 'mass': 10.0, - 'log_R': np.log10(1000.0), - 'he_core_mass': 3.0, - 'he_core_radius': 0.5, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - PROPERTIES_STAR2 = { - 'mass': 2.0, - 'log_R': np.log10(2.0), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'H-rich_Core_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.49, - "center_h1" : 0.49, - "center_c12" : 0.01, - } - giantstar = SingleStar(**PROPERTIES_STAR1) - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar.log_R / cf.roche_lobe_radius( - giantstar.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar.mass, compstar.mass) - PROPERTIES_BINARY = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary = BinaryStar(star_1=giantstar, - star_2=compstar, - **PROPERTIES_BINARY) - - CEE(binary) - #self.assertTrue(binary.event == 'redirect', "CEE test 1 failed") - self.assertTrue( - abs(binary.orbital_period - 5.056621408721529) < - 1.0, "CEE test 1 failed") - self.assertTrue("stripped_He" in binary.star_1.state, "CEE test 1 failed") - - def test_common_envelope_2(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational'} - - CEE = StepCEE(verbose=False, **kwargs) - - # testing with loading a profile of the donor at the moment of CEE to - # calculate the lamda CEE - profile_donor_name = os.path.join(PATH_TO_DATA, - 'simple_giant_profile_for_CEE.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 2.0, - 'log_R': np.log10(2.0), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'H-rich_Core_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.5, - "center_h1" : 0.5, - "center_c12" : 0.01, - } - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational" - CEE(binary_withprofile) - #print(binary_withprofile.event) - self.assertTrue(binary_withprofile.state == "merged", - "CEE test 2 failed") - - def test_common_envelope_3(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational_plus_internal_minus_recombination'} - - CEE = StepCEE(verbose=False, **kwargs) - - # testing with loading a profile of the donor at the moment of CEE to - # calculate the lamda CEE, taking into account also the internal energy - # - recombination energy - profile_donor_name = os.path.join( - PATH_TO_DATA, - 'giant_profile_for_CEE_with_recombinationenergy_calculation.npy') - #profile_donor = np.load(profile_donor_name, mmap_mode = "r") - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 2.0, - 'log_R': np.log10(2.0), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'H-rich_Core_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.49, - "center_h1" : 0.49, - "center_c12" : 0.01, - } - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational_plus_internal_minus_recombination" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - self.assertTrue(binary_withprofile.state == "merged", - "CEE test 3 failed") - - def test_common_envelope_4(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational'} - - CEE = StepCEE(verbose=False, **kwargs) - - profile_donor_name = os.path.join(PATH_TO_DATA, - 'simple_giant_profile_for_CEE.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 10.0, - 'log_R': np.log10(7.0), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'H-rich_Core_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.49, - "center_h1" : 0.49, - "center_c12" : 0.01, - } - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - self.assertTrue(binary_withprofile.state == "merged", - "CEE test 4 failed") - - def test_common_envelope_5(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational_plus_internal_minus_recombination'} - - CEE = StepCEE(verbose=False, **kwargs) - # testing with loading a profile of the donor at the moment of CEE to - # calculate the lamda CEE, taking into account also the internal - # energy - recombination energy - profile_donor_name = os.path.join( - PATH_TO_DATA, - 'giant_profile_for_CEE_with_recombinationenergy_calculation.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 10.0, - 'log_R': np.log10(7.0), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'H-rich_Core_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.49, - "center_h1" : 0.49, - "center_c12" : 0.01, - } - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational_plus_internal_minus_recombination" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - self.assertTrue(binary_withprofile.state == "merged", - "CEE test 5 failed") - - def test_common_envelope_6(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational'} - - CEE = StepCEE(verbose=False, **kwargs) - - profile_donor_name = os.path.join(PATH_TO_DATA, - 'simple_giant_profile_for_CEE.npy') - #profile_donor = np.genfromtxt(profile_donor_name, skip_header=5, names=True, dtype=None) - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 10., - 'log_R': np.log10(0.0001), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'BH' - } - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational" # options: 'default_lambda', 'lambda_from_profile_gravitational', 'lambda_from_profile_gravitational_plus_internal', 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print("new state of the star that triggered CEE = ",giantstar_withprofile.state) - #print("new mass of the star that triggered CEE = ",giantstar_withprofile.mass) - print(binary_withprofile.event) - #self.assertTrue(binary_withprofile.event == 'redirect', - # "CEE test 6 failed") - self.assertTrue((10**giantstar_withprofile.log_R - cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, - a_orb=cf.orbital_separation_from_period( - binary_withprofile.orbital_period, giantstar_withprofile.mass, - compstar.mass))), "CEE test 6 failed") - - self.assertTrue( - (abs(binary_withprofile.orbital_period - 0.12123905531545925) < - 1.0), - "CEE test 6 failed") - - def test_common_envelope_7(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational_plus_internal_minus_recombination'} - - CEE = StepCEE(verbose=False, **kwargs) - - # testing with loading a profile of the donor at the moment of CEE to - # calculate the lamda CEE, taking into account also the internal - # energy - recombination energy - profile_donor_name = os.path.join( - PATH_TO_DATA, - 'giant_profile_for_CEE_with_recombinationenergy_calculation.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 10., - 'log_R': np.log10(0.0001), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'BH' - } - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational_plus_internal_minus_recombination" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - #self.assertTrue(binary_withprofile.event == 'redirect', - # "CEE test 7 failed") - self.assertTrue((10**giantstar_withprofile.log_R - cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, - a_orb=cf.orbital_separation_from_period( - binary_withprofile.orbital_period, giantstar_withprofile.mass, - compstar.mass))), "CEE test 7 failed") - self.assertTrue( - (abs(binary_withprofile.orbital_period - 0.3287114957064215) < - 1.0), - "CEE test 7 failed") - - def test_common_envelope_8(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational_plus_internal_minus_recombination'} - - CEE = StepCEE(verbose=False, **kwargs) - - profile_donor_name = os.path.join( - PATH_TO_DATA, - 'giant_profile_for_CEE_with_recombinationenergy_calculation.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 20., - 'log_R': np.log10(0.0001), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'BH' - } #radius = 10km - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational_plus_internal_minus_recombination" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - #self.assertTrue(binary_withprofile.event == 'redirect', - # "CEE test 8 failed") - #self.assertTrue(binary_withprofile.star_1.state == "stripped_He_Core_He_burning", - # "CEE test 8 failed") - self.assertTrue("stripped_He" in binary_withprofile.star_1.state, - "CEE test 8 failed") - self.assertTrue( - (abs(binary_withprofile.orbital_period - 0.7636524660283687) < - 1.0), - "CEE test 8 failed") - - def test_common_envelope_9(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational_plus_internal_minus_recombination'} - - CEE = StepCEE(verbose=False, **kwargs) - - profile_donor_name = os.path.join(PATH_TO_DATA, - 'caseB_CEE_profile.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 28.04, - 'log_R': np.log10(927.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 20., - 'log_R': np.log10(0.0001), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'BH' - } #radius = 10km - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational_plus_internal_minus_recombination" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - #self.assertTrue(binary_withprofile.event == 'redirect', - # "CEE test 9 failed event") - self.assertTrue("stripped_He" in binary_withprofile.star_1.state, - "CEE test 9 failed state") - self.assertTrue( - (abs(binary_withprofile.orbital_period - 0.166535882054919) < - 1.0), - "CEE test 9 failed tolerance") - - def test_common_envelope_10(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational_plus_internal_minus_recombination'} - - CEE = StepCEE(verbose=False, **kwargs) - - profile_donor_name = os.path.join(PATH_TO_DATA, - 'caseB_CEE_profile.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 28.04, - 'log_R': np.log10(927.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 2., - 'log_R': np.log10(1.5), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'H-rich_Core_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.49, - "center_h1" : 0.49, - "center_c12" : 0.01, - } #radius = 10km - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational_plus_internal_minus_recombination" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - self.assertTrue(binary_withprofile.state == "merged", - "CEE test 10 failed") diff --git a/posydon/tests/binary_evol/DT/test_step_detached.py b/posydon/tests/binary_evol/DT/test_step_detached.py deleted file mode 100644 index 00fbb9db43..0000000000 --- a/posydon/tests/binary_evol/DT/test_step_detached.py +++ /dev/null @@ -1,404 +0,0 @@ -import os -import unittest - -from posydon.binary_evol.binarystar import BinaryStar -from posydon.binary_evol.DT.step_detached import detached_step, diffeq -from posydon.binary_evol.simulationproperties import SimulationProperties -from posydon.binary_evol.singlestar import SingleStar -from posydon.config import PATH_TO_POSYDON -from posydon.utils import common_functions as cf -from posydon.utils import constants as const - -PATH_TO_DATA = os.path.join( - PATH_TO_POSYDON, "posydon/tests/data/POSYDON-UNIT-TESTS/binary_evol/detached/") -#eep_version = "POSYDON" - - -class TestDetached_step(unittest.TestCase): - def test_matching1_root(self): - method = "root" - matching = detached_step(#grid='POSYDON', - path=PATH_TO_DATA, - matching_method=method, - #eep_version=eep_version, - verbose=False) - get_mist0 = detached_step.get_mist0 - get_track_val = detached_step.get_track_val - htrack = True - PROPERTIES_STAR = { - "mass": 60.0, - "log_R": 1.0, - "mdot": -(10.0**(-5)), - "state": "H-rich_Core_H_burning", - "center_he4": 0.48, - "center_h1": 0.5, - "total_moment_of_inertia": 10.0**57, - "log_total_angular_momentum": 52, - "he_core_mass": 0.0, - "surface_he4": 0.28, - "surface_h1": 0.7, - } - m0, t = get_mist0(matching, SingleStar(**PROPERTIES_STAR),htrack) - - self.assertAlmostEqual( - m0, - 64.2410914922183, - places=1, - msg= - "Initial mass in MIST matching not exactly what expected. Should be 64.68538051198551", - ) - self.assertAlmostEqual( - get_track_val(matching, "mass",htrack, m0, t), - 60.0000000000419, - places=3, - msg= - "Current mass in matching not exactly what expected. Should be 60.0000000000419", - ) - self.assertAlmostEqual( - get_track_val(matching, "log_R",htrack, m0, t), - 1.1282247490900794, - places=1, - msg= - "Current log_R in matching not exactly what expected. Should be 1.1282247490900794", - ) - self.assertAlmostEqual( - get_track_val(matching, "center_he4",htrack, m0, t), - 0.4861139708172655, - places=2, - msg= - "Current center_he4 in matching not exactly what expected. Should be 0.4861139708172655", - ) - self.assertAlmostEqual( - get_track_val(matching, "he_core_mass",htrack, m0, t), - 0.0, - places=1, - msg= - "Current he_core_mass matching not exactly what expected. Should be 0.0", - ) - - def test_matching1_minimize(self): - method = "minimize" - matching = detached_step(#grid='POSYDON', - path=PATH_TO_DATA, - matching_method=method, - #eep_version=eep_version, - verbose=False) - get_mist0 = detached_step.get_mist0 - get_track_val = detached_step.get_track_val - htrack = True - PROPERTIES_STAR = { - "mass": 60.0, - "log_R": 1.0, - "mdot": -(10.0**(-5)), - "state": "H-rich_Core_H_burning", - "center_he4": 0.48, - "center_h1": 0.5, - "total_moment_of_inertia": 10.0**57, - "log_total_angular_momentum": 52, - "he_core_mass": 0.0, - "surface_he4": 0.28, - "surface_h1": 0.7, - } - m0, t = get_mist0(matching, SingleStar(**PROPERTIES_STAR),htrack) - - #self.assertAlmostEqual( - # m0, - # 62.88453923015954, - # places= - # 1, # less accuracy because we try to fit more alternative parameters than "root" method at the same time - # msg= - # "Initial mass in MIST matching not exactly what expected. Should be 62.78903050084804", - #) - #self.assertAlmostEqual( - # get_track_val(matching, "mass",htrack, m0, t), - # 59.96234634155157, - # places=1, - # msg= - # "Current mass in matching not exactly what expected. Should be 59.96234634155157", - #) - #self.assertAlmostEqual( - # get_track_val(matching, "log_R",htrack, m0, t), - # 1.0973066851601672, - # places=1, - # msg= - # "Current log_R in matching not exactly what expected. Should be 1.0973066851601672", - #) - #self.assertAlmostEqual( - # get_track_val(matching, "center_he4",htrack, m0, t), - # 0.4252496982220549, - # places=1, - # msg= - # "Current center_he4 in matching not exactly what expected. Should be 0.4252496982220549", - #) - self.assertAlmostEqual( - get_track_val(matching, "he_core_mass",htrack, m0, t), - 0.0, - places=1, - msg= - "Current mass he_core_mass matching not exactly what expected. Should be 0.0", - ) - - def test_only_tides(self): - method = "minimize" - matching = detached_step(#grid='POSYDON', - path=PATH_TO_DATA, - matching_method=method, - #eep_version=eep_version, - verbose=False) - step_ODE_minimize_hist = detached_step( - #grid='POSYDON', - path=PATH_TO_DATA, - n_o_steps_history=30, - #eep_version=eep_version, - matching_method=method, - verbose=False, - ) - step_ODE_minimize_hist_onlytides = detached_step( - #grid='POSYDON', - path=PATH_TO_DATA, - n_o_steps_history=30, - matching_method=method, - #eep_version=eep_version, - verbose=False, - do_wind_loss=False, - do_tides=True, - do_gravitational_radiation=False, - do_magnetic_braking=False, - do_stellar_evolution_and_spin_from_winds=False, - ) - PROPERTIES_STAR1 = {"mass": 10.0, "state": "BH"} - LOW_MS_PROPERTIES_STAR2_non_rot = { - "mass": 8.0, - "log_R": 0.6, - "mdot": -(10.0**(-7)), - "state": "H-rich_Core_H_burning", - "center_he4": 0.28, - "center_h1": 0.7, - "total_moment_of_inertia": 10.0**57, - "log_total_angular_momentum": -10.99, # non-rotating - "he_core_mass": 0.0, - "surface_he4": 0.28, - "surface_h1": 0.7, - } - init_orbital_period = 10 - init_separation = cf.orbital_separation_from_period( - init_orbital_period, - PROPERTIES_STAR1["mass"], - LOW_MS_PROPERTIES_STAR2_non_rot["mass"], - ) - CLOSE_BINARY = { - "time": 5 * 10.0**6, - "orbital_period": init_orbital_period, - "separation": init_separation, - "state": "detached", - "eccentricity": 0.0, - "event": "None", - } - - binary = BinaryStar( - star_1=SingleStar(**PROPERTIES_STAR1), - star_2=SingleStar(**LOW_MS_PROPERTIES_STAR2_non_rot), - **CLOSE_BINARY) - binary.properties.max_simulation_time = 10.0**10 - step_ODE_minimize_hist_onlytides(binary) - - self.assertLessEqual( - getattr(binary, "separation_history")[-1], - getattr(binary, "separation_history")[0], - msg= - "final sepertation with tides only and a non-rotating donor should decrease.", - ) - - def test_tides_vs_tides_and_winds(self): - method = "minimize" - matching = detached_step(#grid='POSYDON', - path=PATH_TO_DATA, - matching_method=method, - #eep_version=eep_version, - verbose=False) - step_ODE_minimize_hist_tides_and_winds = detached_step( - #grid='POSYDON', - path=PATH_TO_DATA, - n_o_steps_history=30, - matching_method=method, - #eep_version=eep_version, - verbose=False, - do_wind_loss=True, - do_tides=True, - do_gravitational_radiation=False, - do_magnetic_braking=False, - do_stellar_evolution_and_spin_from_winds=False, - ) - step_ODE_minimize_hist_onlytides = detached_step( - #grid='POSYDON', - path=PATH_TO_DATA, - n_o_steps_history=30, - matching_method=method, - #eep_version=eep_version, - verbose=False, - do_wind_loss=False, - do_tides=True, - do_gravitational_radiation=False, - do_magnetic_braking=False, - do_stellar_evolution_and_spin_from_winds=False, - ) - - PROPERTIES_STAR1 = {"mass": 10.0, "state": "BH"} - LOW_MS_PROPERTIES_STAR2_non_rot = { - "mass": 8.0, - "log_R": 0.6, - "mdot": -(10.0**(-7)), - "state": "H-rich_Core_H_burning", - "center_he4": 0.28, - "center_h1": 0.7, - "total_moment_of_inertia": 10.0**57, - "log_total_angular_momentum": -10.99, # non-rotating - "he_core_mass": 0.0, - "surface_he4": 0.28, - "surface_h1": 0.7, - } - init_orbital_period = 10 - init_separation = cf.orbital_separation_from_period( - init_orbital_period, - PROPERTIES_STAR1["mass"], - LOW_MS_PROPERTIES_STAR2_non_rot["mass"], - ) - CLOSE_BINARY = { - "time": 5 * 10.0**6, - "orbital_period": init_orbital_period, - "separation": init_separation, - "state": "detached", - "eccentricity": 0.0, - "event": "None", - } - - binary = BinaryStar( - star_1=SingleStar(**PROPERTIES_STAR1), - star_2=SingleStar(**LOW_MS_PROPERTIES_STAR2_non_rot), - **CLOSE_BINARY) - binary_test = BinaryStar( - star_1=SingleStar(**PROPERTIES_STAR1), - star_2=SingleStar(**LOW_MS_PROPERTIES_STAR2_non_rot), - **CLOSE_BINARY) - - binary.properties.max_simulation_time = 10.0**10 - binary_test.properties.max_simulation_time = 10.0**10 - - step_ODE_minimize_hist_tides_and_winds(binary) - step_ODE_minimize_hist_onlytides(binary_test) - - self.assertLessEqual( - getattr(binary_test, "separation_history")[-1], - getattr(binary, "separation_history")[-1], - msg= - "final sepertation with tides only and a non-rotating donor should be lower than including winds too that widen the orbit too.", - ) - - # the following tests are out because they need more EEPS MIST models around their mass. If included they should work. - """ - def test_matching2_root(self): - method = "root" - matching = HMS_detached_step(PATH_TO_EEPS, matching_method=method, verbose=True) - get_mist0 = HMS_detached_step.get_mist0 - get_track_val = HMS_detached_step.get_track_val - PROPERTIES_STAR = { - "mass": 20.0, - "log_R": 2.5, - "mdot": -(10.0 ** (-5)), - "state": "PostMS", - "center_he4": 0.8, - "center_h1": 0.0, - "total_moment_of_inertia": 10.0 ** 57, - "log_total_angular_momentum": 52, - "he_core_mass": 7.0, - "surface_he4": 0.2, - "surface_h1": 0.7, - } - m0, t = get_mist0(matching, SingleStar(**PROPERTIES_STAR)) - - self.assertAlmostEqual( - m0, - 23.092430444226363, - places=5, - msg="Initial mass in MIST matching not exactly what expected. Should be 23.092430444226363", - ) - self.assertAlmostEqual( - get_track_val(matching, "mass", m0, t), - 20.000000000003112, - places=5, - msg="Current mass in matching not exactly what expected. Should be 20.000000000003112", - ) - self.assertAlmostEqual( - get_track_val(matching, "log_R", m0, t), - 3.006026700371161, - places=5, - msg="Current log_R in matching not exactly what expected. Should be 3.006026700371161", - ) - self.assertAlmostEqual( - get_track_val(matching, "center_he4", m0, t), - 0.6426242107646186, - places=5, - msg="Current center_he4 in matching not exactly what expected. Should be 0.6426242107646186", - ) - self.assertAlmostEqual( - get_track_val(matching, "he_core_mass", m0, t), - 6.99999999999913, - places=5, - msg="Current mass he_core_mass matching not exactly what expected. Should be 6.99999999999913", - ) - - def test_matching2_minimize(self): - method = "minimize" - matching = HMS_detached_step(PATH_TO_EEPS, matching_method=method, verbose=True) - get_mist0 = HMS_detached_step.get_mist0 - get_track_val = HMS_detached_step.get_track_val - PROPERTIES_STAR = { - "mass": 20.0, - "log_R": 2.5, - "mdot": -(10.0 ** (-5)), - "state": "PostMS", - "center_he4": 0.8, - "center_h1": 0.0, - "total_moment_of_inertia": 10.0 ** 57, - "log_total_angular_momentum": 52, - "he_core_mass": 7.0, - "surface_he4": 0.2, - "surface_h1": 0.7, - } - m0, t = get_mist0(matching, SingleStar(**PROPERTIES_STAR)) - - self.assertAlmostEqual( - m0, - 23.441360274390483, - places=5, - msg="Initial mass in MIST matching not exactly what expected. Should be 23.441360274390483", - ) - self.assertAlmostEqual( - get_track_val(matching, "mass", m0, t), - 21.931950016322705, - places=5, - msg="Current mass in matching not exactly what expected. Should be 21.931950016322705", - ) - self.assertAlmostEqual( - get_track_val(matching, "log_R", m0, t), - 2.5019520817980974, - places=5, - msg="Current log_R in matching not exactly what expected. Should be 2.5019520817980974", - ) - self.assertAlmostEqual( - get_track_val(matching, "center_he4", m0, t), - 0.9095107795447867, - places=5, - msg="Current center_he4 in matching not exactly what expected. Should be 0.9095107795447867", - ) - self.assertAlmostEqual( - get_track_val(matching, "he_core_mass", m0, t), - 6.86182071726521, - places=5, - msg="Current mass he_core_mass matching not exactly what expected. Should be 6.86182071726521", - ) - """ - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/binary_evol/SN/test_profile_collapse.py b/posydon/tests/binary_evol/SN/test_profile_collapse.py deleted file mode 100644 index eb137cd6fa..0000000000 --- a/posydon/tests/binary_evol/SN/test_profile_collapse.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -import unittest - -import posydon.utils.constants as const -from posydon.binary_evol.singlestar import SingleStar -from posydon.binary_evol.SN.profile_collapse import ( - compute_isco_properties, - do_core_collapse_BH, - get_initial_BH_properties, -) -from posydon.config import PATH_TO_POSYDON -from posydon.grids.psygrid import PSyGrid - -PATH_TO_GRID = os.path.join( - PATH_TO_POSYDON, "posydon/tests/data/POSYDON-UNIT-TESTS/" - "visualization/grid_unit_test_plot.h5") - -if not os.path.isfile(PATH_TO_GRID): - print(PATH_TO_GRID) - raise ValueError("Test grid for unit testing was not found!") - -# constants in CGS -G = const.standard_cgrav -c = const.clight -Mo = const.Msun - - -class TestProfileCollapse(unittest.TestCase): - def test_r_isco(self): - m_BH = 1. * Mo - self.assertAlmostEqual(compute_isco_properties(0., m_BH)[0] / - (G * m_BH / c**2), - 6.0, - places=5) - self.assertAlmostEqual(compute_isco_properties(0.999, m_BH)[0] / - (G * m_BH / c**2), - 1.1817646130335708, - places=5) - - def test_j_isco(self): - m_BH = 1. * Mo - self.assertAlmostEqual(compute_isco_properties(0, m_BH)[1] / - (G * m_BH / c), - 3.464101615137754, - places=5) - self.assertAlmostEqual(compute_isco_properties(0.999, m_BH)[1] / - (G * m_BH / c), - 1.3418378380509774, - places=5) - - def test_radiation_efficiency(self): - m_BH = 1. * Mo - self.assertAlmostEqual((1 - compute_isco_properties(0., m_BH)[2]), - 0.057190958417936644, - places=5) - self.assertAlmostEqual((1 - compute_isco_properties(0.999, m_BH)[2]), - 0.3397940734762088, - places=5) - - def test_low_spinning_He_star(self): - grid = PSyGrid(PATH_TO_GRID) - i = 42 - star = SingleStar(**{'profile': grid[i].final_profile1}) - m_rembar = grid[i].final_values['star_1_mass'] - mass_direct_collapse = 3. # Msun - delta_M = 0.5 # Msun - results = do_core_collapse_BH(star, m_rembar, mass_direct_collapse, - delta_M) - self.assertAlmostEqual(results[0], 13.365071929231409, places=5) - self.assertAlmostEqual(results[1], 8.98074719361575e-09, places=5) - - def test_midly_spinning_He_star(self): - grid = PSyGrid(PATH_TO_GRID) - i = 13 - star = SingleStar(**{'profile': grid[i].final_profile1}) - m_rembar = grid[i].final_values['star_1_mass'] - mass_direct_collapse = 3. # Msun - delta_M = 0.5 # Msun - results = do_core_collapse_BH(star, m_rembar, mass_direct_collapse, - delta_M) - self.assertAlmostEqual(results[0], 5.60832288900688, places=5) - self.assertAlmostEqual(results[1], 0.42583967572001924, places=5) - - def test_rapidly_spinning_He_star(self): - grid = PSyGrid(PATH_TO_GRID) - i = 6 - star = SingleStar(**{'profile': grid[i].final_profile1}) - m_rembar = grid[i].final_values['star_1_mass'] - mass_direct_collapse = 3. # Msun - delta_M = 0.5 # Msun - results = do_core_collapse_BH(star, m_rembar, mass_direct_collapse, - delta_M) - self.assertAlmostEqual(results[0], 38.50844589130613, places=5) - self.assertAlmostEqual(results[1], 0.9835226614001595, places=5) - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/binary_evol/SN/test_step_SN.py b/posydon/tests/binary_evol/SN/test_step_SN.py deleted file mode 100644 index 065a8d0399..0000000000 --- a/posydon/tests/binary_evol/SN/test_step_SN.py +++ /dev/null @@ -1,567 +0,0 @@ -import os -import unittest - -import matplotlib.cm as cm -import numpy as np -import pandas as pd -from scipy.stats import maxwell - -from posydon.binary_evol.binarystar import BinaryStar -from posydon.binary_evol.singlestar import SingleStar -from posydon.binary_evol.SN.step_SN import StepSN -from posydon.config import PATH_TO_POSYDON - -# github action are not cloning the data submoule, data for unit testing -# are therefore stored to the unit test submodule - -path_to_Sukhbold_datasets = os.path.join( - PATH_TO_POSYDON, "posydon/tests/data/POSYDON-UNIT-TESTS/binary_evol/SN/") - -class TestStepSN(unittest.TestCase): - # TODO - ''' - """ - Test WD formation - """ - def test_WD_formation_RAPID(self): - M_CO = 1.3 - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_he_core = star.he_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'WD') - self.assertEqual( M_rembar , star_he_core) - - def test_WD_formation_DELAYED(self): - M_CO = 1.3 - SN = StepSN( **{'mechanism' : 'Fryer+12-delayed', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_he_core = star.he_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'WD') - self.assertEqual( M_rembar , star_he_core) - - def test_WD_formation_SUKHBOLDN20(self): - M_CO = 1.3 - SN = StepSN( **{'mechanism' : 'Sukhbold+16-engine', - 'engine' : 'N20', - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False, - 'path_to_datasets': path_to_Sukhbold_datasets} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_he_core = star.he_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'WD') - self.assertEqual( M_rembar , star_he_core) - - - """ - Test ECSN formation - """ - def test_WD_formation_RAPID(self): - M_CO = 1.38 - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_c_core = star.co_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'ECSN') - self.assertEqual( M_rembar , star_c_core) - - def test_WD_formation_DELAYED(self): - M_CO = 1.38 - SN = StepSN( **{'mechanism' : 'Fryer+12-delayed', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_c_core = star.co_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'ECSN') - self.assertEqual( M_rembar , star_c_core) - - def test_WD_formation_SUKHBOLDN20(self): - M_CO = 1.38 - SN = StepSN( **{'mechanism' : 'Sukhbold+16-engine', - 'engine' : 'N20', - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False, - 'path_to_datasets': path_to_Sukhbold_datasets} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_c_core = star.co_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'ECSN') - self.assertEqual( M_rembar , star_c_core) - - - """ - Test CCSN formation - """ - def test_CCSN_formation_RAPID(self): - M_CO = 2.0 - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_he_core = star.he_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'CCSN') - - def test_CCSN_formation_DELAYED(self): - M_CO = 2.0 - SN = StepSN( **{'mechanism' : 'Fryer+12-delayed', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_he_core = star.he_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'CCSN') - - def test_CCSN_formation_SUKHBOLDN20(self): - M_CO = 2.0 - SN = StepSN( **{'mechanism' : 'Sukhbold+16-engine', - 'engine' : 'N20', - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False, - 'path_to_datasets': path_to_Sukhbold_datasets} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_he_core = star.he_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'CCSN') - - """ - Test PPISN - """ - def test_remnant_mass_PPISN(self): - M_He = 35.0 - - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_He, - 'co_core_mass': M_He * 0.7638113015667961 , - 'he_core_mass': M_He , - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - - star = SingleStar(**star_prop) - - m_PISN = SN.PISN_prescription(star) - - SN.compute_m_rembar(star , m_PISN)[0] - - self.assertTrue( m_PISN > 0.0 ) - self.assertTrue( m_PISN <= 50.0 ) - self.assertEqual(SN.SN_type , 'PPISN') - - """ - Test PISN - """ - def test_remnant_mass_PPISN(self): - M_He = 70.0 - - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_He, - 'co_core_mass': M_He * 0.7638113015667961 , - 'he_core_mass': M_He , - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - - star = SingleStar(**star_prop) - - m_PISN = SN.PISN_prescription(star) - - SN.compute_m_rembar(star , m_PISN)[0] - - self.assertTrue( np.isnan(m_PISN) ) - self.assertEqual(SN.SN_type , 'PISN') - - """ - Test kick distribution for ECSN - """ - def test_kick_ECSN(self): - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - SN_type = np.array([]) - Vkick = np.array([]) - M_co = np.full_like(np.arange(50000)*1.0 , 1.38) - - # The He stars are created - for m_co in M_co: - star_prop = {'mass':m_co / 0.7638113015667961, - 'co_core_mass':m_co, - 'he_core_mass':m_co / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - - - star = SingleStar(**star_prop) - - # The fallback fraction is stracted, this is not a random - # variable then is a fixed value for all explotions - f_fb = SN.compute_m_rembar(star, None)[1] - - # We perform the collapse to extract the SN type of the - # from the code - SN.collapse_star(star) - - if (SN.SN_type == 'CCSN') + (SN.SN_type == 'PPISN') : - kick = SN.generate_kick(star , SN.sigma_kick_CCSN) - sigma = SN.sigma_kick_CCSN - elif SN.SN_type == 'ECSN': - kick = SN.generate_kick(star , SN.sigma_kick_ECSN) - sigma = SN.sigma_kick_ECSN - - - SN_type = np.append(SN_type , SN.SN_type) - Vkick = np.append(Vkick , kick) - - star = None - - dist = (Vkick[SN_type == 'ECSN'] / (1.0 - f_fb)) - - sigma_ECSN = np.round(np.std(dist) / np.sqrt((3*np.pi - 8)/np.pi) , 2) - - print(sigma_ECSN) - - lower = sigma_ECSN <= (SN.sigma_kick_ECSN + 2) - upper = sigma_ECSN >= (SN.sigma_kick_ECSN - 2) - - self.assertTrue( lower ) - self.assertTrue( upper ) - - """ - Test kick distribution for CCSN - """ - def test_kick_CCSN(self): - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - SN_type = np.array([]) - Vkick = np.array([]) - M_co = np.full_like(np.arange(50000)*1.0 , 8.0) - - # The He stars are created - for m_co in M_co: - star_prop = {'mass':m_co / 0.7638113015667961, - 'co_core_mass':m_co, - 'he_core_mass':m_co / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - - - star = SingleStar(**star_prop) - - # The fallback fraction is stracted, this is not a random - # variable then is a fixed value for all explotions - f_fb = SN.compute_m_rembar(star, None)[1] - - # We perform the collapse to extract the SN type of the - # from the code - SN.collapse_star(star) - - if (SN.SN_type == 'CCSN') + (SN.SN_type == 'PPISN') : - kick = SN.generate_kick(star , SN.sigma_kick_CCSN) - sigma = SN.sigma_kick_CCSN - elif SN.SN_type == 'ECSN': - kick = SN.generate_kick(star , SN.sigma_kick_ECSN) - sigma = SN.sigma_kick_ECSN - - - SN_type = np.append(SN_type , SN.SN_type) - Vkick = np.append(Vkick , kick) - - star = None - - dist = (Vkick[SN_type == 'CCSN'] / (1.0 - f_fb)) - - sigma_CCSN = np.round(np.std(dist) / np.sqrt((3*np.pi - 8)/np.pi) , 2) - - print(sigma_CCSN) - - lower = sigma_CCSN <= (SN.sigma_kick_CCSN + 2) - upper = sigma_CCSN >= (SN.sigma_kick_CCSN - 2) - - self.assertTrue( lower ) - self.assertTrue( upper ) - - """ - Test generate kick for expanding orbit - """ - def test_generate_kick(self): - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - fallback = [] - - sep_i = [] - ecc_i = [] - - - sep_f = [] - ecc_f = [] - Vsys_f = [] - - # Loading the test data - def end(binary): - binary.event = 'END' - - properties_star1 = {"mass": 16.200984100257546, "state": "BH", "profile": None} - properties_star2 = {"mass": 5.497560636139926, - "state": "stripped_He_Core_C_depleted", - "profile": None, - 'he_core_mass': 5.497560636139926, - 'co_core_mass': 4.1990989449324205} - - BH = SingleStar(**properties_star1) - He_star = SingleStar(**properties_star2) - properties_binary = { - 'orbital_period' : 6.182118856988261, - 'eccentricity' : 0.0, - 'separation': 39.5265173131476, - 'state' : 'ZAMS', - 'event' : 'CC2', - 'V_sys' : [0, 0, 0], - 'mass_transfer_case' : None, - } - binary = BinaryStar(BH, He_star, **properties_binary) - - pop = [binary] - - - for i in range(len(pop)): - binary = pop[i] - - # We consider that the kicks will have the same direction - # as the velocity of the He star at the periapsis - binary.star_2.natal_kick_array = [None , 0., 0., 0.] - - # We save the orbital separation end eccentricity pre-supernova - sep_i.append( binary.separation ) - ecc_i.append( binary.eccentricity ) - - # We save the fallback fraction f_fb of the remnant - fallback.append(SN.compute_m_rembar(binary.star_2, None)[1]) - - # The orbital kick is applied to the three dimensional orbit - SN.orbital_kick(binary) - - # We save the orbital separation, eccentricity and kick velocity post-supernova - sep_f.append( binary.separation ) - ecc_f.append( binary.eccentricity ) - Vsys_f.append( binary.V_sys ) - - index = [sep_f[i] < sep_i[i] for i in range(len(sep_i))] - - # See if there is any orbit post supernova that shrinked more than one meter - smaller_orbits = np.array(np.array(sep_i)[index] - np.array(sep_f)[index] > 10**-8) - - orbit_comparision = np.sum(smaller_orbits) - - self.assertEqual( orbit_comparision , 0.0) - ''' - - -if __name__ == '__main__': - unittest.main() diff --git a/posydon/tests/binary_evol/test_BinaryStar.py b/posydon/tests/binary_evol/test_BinaryStar.py deleted file mode 100644 index b0f12fe7a7..0000000000 --- a/posydon/tests/binary_evol/test_BinaryStar.py +++ /dev/null @@ -1,127 +0,0 @@ -import os -import unittest - -import numpy as np - -from posydon.binary_evol.binarystar import BinaryStar -from posydon.binary_evol.singlestar import SingleStar -from posydon.config import PATH_TO_POSYDON -from posydon.grids.psygrid import PSyGrid - -PATH_TO_GRID = os.path.join( - PATH_TO_POSYDON, - "posydon/tests/data/POSYDON-UNIT-TESTS/" - "visualization/grid_unit_test_plot.h5" -) - -if not os.path.exists(PATH_TO_GRID): - raise ValueError("Test grid for unit testing was not found!") - - -class TestSingleStar(unittest.TestCase): - def test_BinaryStar_initialisation(self): - # load an example grid: compact object + He-star - grid = PSyGrid(PATH_TO_GRID) - - # initialise a star with the properties of run i=42 - i = 42 - - kwargs1 = { - 'state': 'stripped_He_Core_C_depleted', - 'metallicity': grid[i].initial_values['Z'], - 'mass': grid[i].history1['star_mass'][-1], - 'log_R': np.nan, - 'log_L': grid[i].history1['log_L'][-1], - 'lg_mdot': np.nan, - 'lg_system_mdot' : np.nan, - 'lg_wind_mdot': np.nan, - 'he_core_mass': grid[i].history1['he_core_mass'][-1], - 'he_core_radius': np.nan, - 'c_core_radius': grid[i].history1['he_core_mass'][-1], - 'o_core_mass': np.nan, - 'o_core_radius': np.nan, - 'center_h1': grid[i].history1['center_h1'][-1], - 'center_he4': grid[i].history1['center_he4'][-1], - 'center_c12': grid[i].history1['center_c12'][-1], - 'center_n14': np.nan, - 'center_o16': np.nan, - 'surface_h1': grid[i].history1['surface_h1'][-1], - 'surface_he4': np.nan, - 'surface_c12': np.nan, - 'surface_n14': np.nan, - 'surface_o16': np.nan, - 'log_LH': grid[i].history1['log_LH'][-1], - 'log_LHe': grid[i].history1['log_LHe'][-1], - 'log_LZ': grid[i].history1['log_LZ'][-1], - 'log_Lnuc': grid[i].history1['log_Lnuc'][-1], - 'c12_c12': grid[i].history1['c12_c12'][-1], - 'surf_avg_omega_div_omega_crit': np.nan, - 'total_moment_of_inertia': np.nan, - 'log_total_angular_momentum': np.nan, - 'spin': np.nan, - 'profile': grid[i].final_profile1 - } - - star_1 = SingleStar(**kwargs1) - - kwargs2 = { - 'state': 'stripped_He_Core_C_depleted', - 'metallicity': grid[i].initial_values['Z'], - 'mass': grid[i].initial_values['star_2_mass'], - 'log_R': np.nan, - 'log_L': np.nan, - 'lg_mdot': np.nan, - 'lg_system_mdot' : np.nan, - 'lg_wind_mdot': np.nan, - 'he_core_mass': np.nan, - 'he_core_radius': np.nan, - 'c_core_radius': np.nan, - 'o_core_mass': np.nan, - 'o_core_radius': np.nan, - 'center_h1': np.nan, - 'center_he4': np.nan, - 'center_c12': np.nan, - 'center_n14': np.nan, - 'center_o16': np.nan, - 'surface_h1': np.nan, - 'surface_he4': np.nan, - 'surface_c12': np.nan, - 'surface_n14': np.nan, - 'surface_o16': np.nan, - 'log_LH': np.nan, - 'log_LHe': np.nan, - 'log_LZ': np.nan, - 'log_Lnuc': np.nan, - 'c12_c12': np.nan, - 'surf_avg_omega_div_omega_crit': np.nan, - 'total_moment_of_inertia': np.nan, - 'log_total_angular_momentum': np.nan, - 'spin': np.nan, - 'profile': None - } - - star_2 = SingleStar(**kwargs2) - - kwargs3 = { - 'state': 'detached', - 'event': 'CC1', - 'time': grid.final_values['age'][i], - 'orbital_period': grid.final_values['period_days'][i], - 'eccentricity': 0., - 'separation': grid.final_values['binary_separation'][i], - 'V_sys': [0, 0, 0], - 'rl_relative_overflow_1' : np.nan, - 'rl_relative_overflow_2' : np.nan, - 'lg_mtransfer_rate': np.nan, - #'mass_transfer_case': None - } - - binary = BinaryStar(star_1, star_2, **kwargs3) - - # check that the above kwars have a history - for item in kwargs3.keys(): - self.assertIsInstance(getattr(binary, item + '_history'), list) - - -if __name__ == '__main__': - unittest.main() diff --git a/posydon/tests/binary_evol/test_SingleStar.py b/posydon/tests/binary_evol/test_SingleStar.py deleted file mode 100644 index 4fabf41981..0000000000 --- a/posydon/tests/binary_evol/test_SingleStar.py +++ /dev/null @@ -1,73 +0,0 @@ -import os -import unittest - -import numpy as np - -from posydon.binary_evol.singlestar import SingleStar -from posydon.config import PATH_TO_POSYDON -from posydon.grids.psygrid import PSyGrid - -PATH_TO_GRID = os.path.join( - PATH_TO_POSYDON, - "posydon/tests/data/POSYDON-UNIT-TESTS/" - "visualization/grid_unit_test_plot.h5" -) - -if not os.path.exists(PATH_TO_GRID): - raise ValueError("Test grid for unit testing was not found!") - - -class TestSingleStar(unittest.TestCase): - def test_SingleStar_initialisation(self): - # load an example grid: compact object + He-star - grid = PSyGrid(PATH_TO_GRID) - - # initialise a star with the properties of run i=42 - i = 42 - - # all STARPROPERTIES - kwargs = { - 'state': 'stripped_He_Central_C_depletion', - 'metallicity': grid[i].initial_values['Z'], - 'mass': grid[i].history1['star_mass'][-1], - 'log_R': np.nan, - 'log_L': grid[i].history1['log_L'][-1], - 'lg_mdot': np.nan, - 'lg_system_mdot': np.nan, - 'lg_wind_mdot': np.nan, - 'he_core_mass': grid[i].history1['he_core_mass'][-1], - 'he_core_radius': np.nan, - 'c_core_radius': grid[i].history1['he_core_mass'][-1], - 'o_core_mass': np.nan, - 'o_core_radius': np.nan, - 'center_h1': grid[i].history1['center_h1'][-1], - 'center_he4': grid[i].history1['center_he4'][-1], - 'center_c12': grid[i].history1['center_c12'][-1], - 'center_n14': np.nan, - 'center_o16': np.nan, - 'surface_h1': grid[i].history1['surface_h1'][-1], - 'surface_he4': np.nan, - 'surface_c12': np.nan, - 'surface_n14': np.nan, - 'surface_o16': np.nan, - 'log_LH': grid[i].history1['log_LH'][-1], - 'log_LHe': grid[i].history1['log_LHe'][-1], - 'log_LZ': grid[i].history1['log_LZ'][-1], - 'log_Lnuc': grid[i].history1['log_Lnuc'][-1], - 'c12_c12': grid[i].history1['c12_c12'][-1], - 'surf_avg_omega_div_omega_crit': np.nan, - 'total_moment_of_inertia': np.nan, - 'log_total_angular_momentum': np.nan, - 'spin': np.nan, - 'profile': grid[i].final_profile1 - } - - star = SingleStar(**kwargs) - - # check that the above kwars have a history - for item in kwargs.keys(): - self.assertIsInstance(getattr(star, item + '_history'), list) - - -if __name__ == '__main__': - unittest.main() diff --git a/posydon/tests/data/POSYDON-UNIT-TESTS b/posydon/tests/data/POSYDON-UNIT-TESTS deleted file mode 160000 index eaf9d59229..0000000000 --- a/posydon/tests/data/POSYDON-UNIT-TESTS +++ /dev/null @@ -1 +0,0 @@ -Subproject commit eaf9d592291f093cc0095e13c93d431c5b6051da diff --git a/posydon/tests/interpolation/test_data_scaling.py b/posydon/tests/interpolation/test_data_scaling.py deleted file mode 100644 index ba9e29fd73..0000000000 --- a/posydon/tests/interpolation/test_data_scaling.py +++ /dev/null @@ -1,251 +0,0 @@ -from unittest import TestCase - -import numpy as np - -from posydon.interpolation.data_scaling import DataScaler - - -class DataScaler_test(TestCase): - def setUp(self): - self.sc = DataScaler() - self.x = np.array([1,2,3,4]) - self.y = -self.x.copy() - - def test_fit(self): - # not a 1D array - with self.assertRaises(AssertionError): - self.sc.fit([12,2,3]) #list - with self.assertRaises(AssertionError): - self.sc.fit(np.ones((5,1))) # list - # default value 'none' - with self.subTest(i=0): - self.sc.fit(self.x) - self.assertIsInstance(self.sc.params, list) - self.assertEqual(self.sc.method,'none') - self.assertEqual(len(self.sc.params),0) - # min_max - with self.subTest(i=1): - self.sc.fit(self.x, method='min_max') - self.assertEqual(self.sc.method, 'min_max') - self.assertEqual(len(self.sc.params),2) - self.assertEqual(self.sc.params[0], 1) - self.assertEqual(self.sc.params[1], 4) - self.assertEqual(self.sc.lower, -1) - self.assertEqual(self.sc.upper, 1) - # min_max modifying lower/upper - with self.subTest(i=2): - with self.assertRaises(AssertionError): - self.sc.fit(self.x, method='min_max', lower=2) - self.sc.fit(self.x, method='min_max', lower=-2, upper=0.5) - self.assertEqual(self.sc.params[0], 1) - self.assertEqual(self.sc.params[1], 4) - self.assertEqual(self.sc.lower, -2) - self.assertEqual(self.sc.upper, 0.5) - # max_abs - with self.subTest(i=3): - self.sc.fit(self.x, method='max_abs') - self.assertEqual(self.sc.method, 'max_abs') - self.assertEqual(len(self.sc.params), 1) - self.assertEqual(self.sc.params[0], 4) - self.sc.fit(self.y, method='max_abs') # check with negative numbers - self.assertEqual(self.sc.params[0], 4) - # standarize - with self.subTest(i=4): - self.sc.fit(self.x, method='standarize') - self.assertEqual(self.sc.method, 'standarize') - self.assertEqual(len(self.sc.params), 2) - self.assertEqual(self.sc.params[0], np.mean(self.x)) - self.assertEqual(self.sc.params[1], np.std(self.x)) - # log_min_max - with self.subTest(i=5): - self.sc.fit(self.x, method='log_min_max') - self.assertEqual(self.sc.method, 'log_min_max') - self.assertEqual(len(self.sc.params), 2) - self.assertEqual(self.sc.params[0], 0) - self.assertEqual(self.sc.params[1], np.log10(4)) - self.assertEqual(self.sc.lower, -1) - self.assertEqual(self.sc.upper, 1) - # log_min_max modifying lower/upper - with self.subTest(i=6): - with self.assertRaises(AssertionError): - self.sc.fit(self.x, method='log_min_max', lower=2) - self.sc.fit(self.x, method='log_min_max', lower=-2, upper=0.5) - self.assertEqual(self.sc.params[0], 0) - self.assertEqual(self.sc.params[1], np.log10(4)) - self.assertEqual(self.sc.lower, -2) - self.assertEqual(self.sc.upper, 0.5) - # log_max_abs - with self.subTest(i=7): - self.sc.fit(self.x, method='log_max_abs') - self.assertEqual(self.sc.method, 'log_max_abs') - self.assertEqual(len(self.sc.params), 1) - self.assertEqual(self.sc.params[0], np.log10(4)) - self.sc.fit(self.y, method='log_max_abs') # check with negative numbers - self.assertTrue(np.isnan(self.sc.params[0])) - # log_standarize - with self.subTest(i=8): - self.sc.fit(self.x, method='log_standarize') - self.assertEqual(self.sc.method, 'log_standarize') - self.assertEqual(len(self.sc.params), 2) - self.assertEqual(self.sc.params[0], np.mean(np.log10(self.x))) - self.assertEqual(self.sc.params[1], np.std(np.log10(self.x))) - # wrong method string - with self.assertRaises(ValueError): - self.sc.fit(self.x, method='wrong') - - def test_transform(self): - # check .fit has been run first - with self.assertRaises(AssertionError): - sc = DataScaler() - sc.transform(self.x) - # default value 'none' - with self.subTest(i=0): - self.sc.fit(self.x) - xt = self.sc.transform(self.x) - self.assertIsInstance(xt, np.ndarray) - self.assertEqual(len(xt.shape),1) - self.assertEqual(np.sum(np.abs(xt-self.x)),0) - # min_max - with self.subTest(i=1): - self.sc.fit(self.x, method='min_max') - xt = self.sc.transform(self.x) - self.assertAlmostEqual(xt.min(),self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # min_max modifying lower/upper - with self.subTest(i=2): - self.sc.fit(self.x, method='min_max', lower=-2, upper=0.5) - xt = self.sc.transform(self.x) - self.assertAlmostEqual(xt.min(), self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # max_abs - with self.subTest(i=3): - self.sc.fit(self.x, method='max_abs') - xt = self.sc.transform(self.x) - self.assertEqual(np.abs(xt).max(), 1) - self.assertGreaterEqual(np.abs(xt).min(),-1) - self.sc.fit(self.y, method='max_abs') # check with negative numbers - xt = self.sc.transform(self.x) - self.assertEqual(np.abs(xt).max(), 1) - self.assertGreaterEqual(np.abs(xt).min(), -1) - # standarize - with self.subTest(i=4): - self.sc.fit(self.x, method='standarize') - xt = self.sc.transform(self.x) - self.assertAlmostEqual(xt.mean(),0) - self.assertAlmostEqual(xt.std(), 1) - # log_min_max - with self.subTest(i=5): - self.sc.fit(self.x, method='log_min_max') - xt = self.sc.transform(self.x) - self.assertAlmostEqual(xt.min(), self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # log_min_max modifying lower/upper - with self.subTest(i=6): - self.sc.fit(self.x, method='log_min_max', lower=-2, upper=0.5) - xt = self.sc.transform(self.x) - self.assertAlmostEqual(xt.min(), self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # log_max_abs - with self.subTest(i=7): - self.sc.fit(self.x, method='log_max_abs') - xt = self.sc.transform(self.x) - self.assertEqual(np.abs(xt).max(), 1) - self.assertGreaterEqual(np.abs(xt).min(), -1) - # log_standarize - with self.subTest(i=8): - self.sc.fit(self.x, method='log_standarize') - xt = self.sc.transform(self.x) - self.assertAlmostEqual(xt.mean(), 0) - self.assertAlmostEqual(xt.std(), 1) - - def test_fit_and_transform(self): - - # default value 'none' - with self.subTest(i=0): - xt = self.sc.fit_and_transform(self.x) - self.assertIsInstance(xt, np.ndarray) - self.assertEqual(len(xt.shape),1) - self.assertEqual(np.sum(np.abs(xt-self.x)),0) - # min_max - with self.subTest(i=1): - xt = self.sc.fit_and_transform(self.x, method='min_max') - self.assertAlmostEqual(xt.min(),self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # min_max modifying lower/upper - with self.subTest(i=2): - xt = self.sc.fit_and_transform(self.x, method='min_max', lower=-2, upper=0.5) - self.assertAlmostEqual(xt.min(), self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # max_abs - with self.subTest(i=3): - xt = self.sc.fit_and_transform(self.x, method='max_abs') - self.assertEqual(np.abs(xt).max(), 1) - self.assertGreaterEqual(np.abs(xt).min(),-1) - xt = self.sc.fit_and_transform(self.y, method='max_abs') # check with negative numbers - self.assertEqual(np.abs(xt).max(), 1) - self.assertGreaterEqual(np.abs(xt).min(), -1) - # standarize - with self.subTest(i=4): - xt = self.sc.fit_and_transform(self.x, method='standarize') - self.assertAlmostEqual(xt.mean(),0) - self.assertAlmostEqual(xt.std(), 1) - # log_min_max - with self.subTest(i=5): - xt = self.sc.fit_and_transform(self.x, method='log_min_max') - self.assertAlmostEqual(xt.min(), self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # log_min_max modifying lower/upper - with self.subTest(i=6): - xt = self.sc.fit_and_transform(self.x, method='log_min_max', lower=-2, upper=0.5) - self.assertAlmostEqual(xt.min(), self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # log_max_abs - with self.subTest(i=7): - xt = self.sc.fit_and_transform(self.x, method='log_max_abs') - self.assertEqual(np.abs(xt).max(), 1) - self.assertGreaterEqual(np.abs(xt).min(), -1) - # log_standarize - with self.subTest(i=8): - xt = self.sc.fit_and_transform(self.x, method='log_standarize') - self.assertAlmostEqual(xt.mean(), 0) - self.assertAlmostEqual(xt.std(), 1) - - def test_inv_transform(self): - # default value 'none' - with self.subTest(i=0): - xt = self.sc.fit_and_transform(self.x) - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # min_max - with self.subTest(i=1): - xt = self.sc.fit_and_transform(self.x, method='min_max') - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # min_max modifying lower/upper - with self.subTest(i=2): - xt = self.sc.fit_and_transform(self.x, method='min_max', lower=-2, upper=0.5) - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # max_abs - with self.subTest(i=3): - xt = self.sc.fit_and_transform(self.x, method='max_abs') - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - xt = self.sc.fit_and_transform(self.y, method='max_abs') # check with negative numbers - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.y)), 0) - # standarize - with self.subTest(i=4): - xt = self.sc.fit_and_transform(self.x, method='standarize') - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # log_min_max - with self.subTest(i=5): - xt = self.sc.fit_and_transform(self.x, method='log_min_max') - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # log_min_max modifying lower/upper - with self.subTest(i=6): - xt = self.sc.fit_and_transform(self.x, method='log_min_max', lower=-2, upper=0.5) - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # log_max_abs - with self.subTest(i=7): - xt = self.sc.fit_and_transform(self.x, method='log_max_abs') - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # log_standarize - with self.subTest(i=8): - xt = self.sc.fit_and_transform(self.x, method='log_standarize') - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) diff --git a/posydon/tests/interpolation/test_interpolation.py b/posydon/tests/interpolation/test_interpolation.py deleted file mode 100644 index 3e791481bb..0000000000 --- a/posydon/tests/interpolation/test_interpolation.py +++ /dev/null @@ -1,37 +0,0 @@ -# from unittest import TestCase -# -# import numpy as np -# import posydon.grids.psygrid as psg -# import posydon.interpolation.interpolation as psi -# -# try: -# import gpflow -# except ImportError: -# print("Import Error for TensorFlow and/or GPFlow, most, if not all " -# "features of the psyInterp class will not work, please check your installation " -# "of gpflow or tensorflow or install the correct gpflow by running pip install .[ml]") -# -# -# class Interpolation_test(TestCase): -# def setUp(self): -# # FIX PATH YOU CANNOT GIVE A LOCAL PATH -# self.grid = psg.PSyGrid() -# self.grid.load("/home/juanga/Desktop/data/grid_BH_He_star.h5") -# self.input_keys = self.grid.initial_values.dtype.names -# self.output_keys = self.grid.final_values.dtype.names[2:4] -# self.input_norms = ['log_min_max', 'log_min_max', 'log_min_max'] -# self.output_norms = ['log_standarize', 'log_standarize'] -# -# def test_init(self): -# m = psi.psyInterp(grid=self.grid, -# in_keys=self.input_keys, -# out_keys=self.output_keys, -# in_scaling=self.input_norms, -# out_scaling=self.output_norms) -# self.assertEqual(len(m.in_keys), len(self.input_keys)) -# self.assertEqual(len(m.out_keys), len(self.output_keys)) -# self.assertEqual(m.XYT.shape[0], m.N) -# self.assertEqual(m.XYT.shape[1], m.n_in+m.n_out) -# -# class SGPInterp_test(TestCase): -# pass diff --git a/posydon/tests/popsyn/test_binarypopulation.py b/posydon/tests/popsyn/test_binarypopulation.py deleted file mode 100644 index 8388acb69a..0000000000 --- a/posydon/tests/popsyn/test_binarypopulation.py +++ /dev/null @@ -1,180 +0,0 @@ -import os -import unittest - -import matplotlib.pyplot as plt -import numpy as np - -from posydon.binary_evol.flow_chart import flow_chart -from posydon.binary_evol.simulationproperties import SimulationProperties -from posydon.binary_evol.step_end import step_end -from posydon.popsyn.binarypopulation import BinaryPopulation - - -class TestBinaryPopulation(unittest.TestCase): - @classmethod - def setUpClass(cls): - np.random.seed(12345) - cls.POP_KWARGS = { - "number_of_binaries": int(500), - "primary_mass_min": 15 - } - - class MyEndStep(step_end): - def __call__(self, binary): - step_end.__call__(binary) - binary.star_1.mass = np.sqrt(binary.star_1.mass) - # change star_1 mass - - cls.SIM_PROP = SimulationProperties( - flow = (flow_chart, {}), - step_HMS_HMS = (MyEndStep, {}), - step_CO_HeMS = (MyEndStep, {}), - step_CO_HMS_RLO = (MyEndStep, {}), - step_detached = (MyEndStep, {}), - step_CE = (MyEndStep, {}), - step_SN = (MyEndStep, {}), - step_end = (MyEndStep, {}), - ) - - def test_init_0(self): - bin_pop = BinaryPopulation() - self.assertTrue(isinstance(bin_pop, BinaryPopulation)) - self.assertTrue(hasattr(bin_pop, 'population_properties')) - self.assertTrue(hasattr(bin_pop, 'entropy')) - - - def test_generate(self): - bin_pop = BinaryPopulation(**self.POP_KWARGS) - for i in range(bin_pop.number_of_binaries): - bin_pop.manager.generate(**self.POP_KWARGS) - self.assertTrue(len(bin_pop) == self.POP_KWARGS["number_of_binaries"]) - - self.assertTrue( [b.star_1.mass > self.POP_KWARGS["primary_mass_min"] - for b in bin_pop] ) - - # def test_generate_initial_binaries(self): - # bin_pop = BinaryPopulation(generate_initial_population=False, **self.POP_KWARGS) - # bin_pop.generate_initial_binaries() - # first_bin = bin_pop[1] - # bin_pop.generate_initial_binaries(overwrite=True) - # second_bin = bin_pop[1] - # self.assertFalse( - # first_bin is second_bin, msg="Binaries should not be the same object." - # ) - # - # def test_gen_init_bin_err(self): - # bin_pop = BinaryPopulation() - # with self.assertRaisesRegex( - # ValueError, "set overwrite=True to overwrite existing population" - # ): - # bin_pop.generate_initial_binaries() - # - def test_sim_properties(self): - bin_pop = BinaryPopulation(population_properties=self.SIM_PROP) - self.assertTrue(isinstance(bin_pop, BinaryPopulation)) - self.assertTrue(bin_pop.population_properties is self.SIM_PROP) - bin_pop.population_properties.load_steps() - return bin_pop - - # def test_evolve(self): - # bin_pop = self.test_sim_properties() - # test_ids = np.arange(0, 15, 1) - # original_bins = bin_pop.copy(ids=test_ids) - # bin_pop.evolve() - # for b in bin_pop: - # with self.subTest("Check event END", binary_ind=b.index): - # self.assertTrue(b.event == "END") - # - # for j, b in enumerate(bin_pop[test_ids]): - # with self.subTest( - # "Check mass changed", - # binary_ind=b.index, - # test_ind=original_bins[j].index, - # ): - # self.assertAlmostEqual( - # b.star_1.mass, np.sqrt(original_bins[j].star_1.mass), places=8 - # ) - - # def test_evolve_binary_population(self): - # # It is unclear how to test multiprocessing at the moment - # POP_KWARGS = {"number_of_binaries": int(500), "primary_mass_min": 15} - # - # def end(binary): - # binary.star_1.mass = np.sqrt(binary.star_1.mass) - # binary.event = "END" - # - # def get_sim_prop(): - # SIM_PROP = SimulationProperties( - # flow={("H-rich_Core_H_burning", "H-rich_Core_H_burning", - # "detached", "ZAMS"): "step_end"}, step_end=end, max_simulation_time=13.7e9) - # return SIM_PROP - # - # bin_pop = BinaryPopulation(population_properties=get_sim_prop, **POP_KWARGS) - # bin_pop.evolve_binary_population(num_batches=4, verbose=True, use_df=True) - - # def test_evolve_each_binary(self): - # bin_pop = self.test_sim_properties() - # for num, evolved_bin in enumerate(bin_pop.evolve_each_binary()): - # with self.subTest("Evolve generator", num=num): - # self.assertTrue(num == evolved_bin.index) - # self.assertTrue(evolved_bin.event == "END") - - # def test_copy(self): - # bin_pop = self.test_sim_properties() - # binary_copy = bin_pop.copy(ids=0) - # self.assertFalse(binary_copy is bin_pop[0]) - # all_binaries_copy = bin_pop.copy() - # self.assertFalse( - # any([copy_b is b for copy_b, b in zip(all_binaries_copy, bin_pop)]) - # ) - - # TODO: step_times is breaking to_df with only initialized binary / pop - # def test_to_df(self): - # bin_pop = self.test_sim_properties() - # self.assertTrue( isinstance(bin_pop.to_df(), pd.DataFrame) ) - - # def test_get_bin_by_index(self): - # bin_pop = self.test_sim_properties() - # test_indicies = [1, 6, 9, 8, 2] - # out_bins = bin_pop.get_binaries_by_index(test_indicies) - # self.assertTrue([b.index for b in out_bins] == test_indicies) - - # def test_bool_and_len(self): - # bin_pop = BinaryPopulation(population_properties=self.SIM_PROP) - # self.assertTrue(bool(bin_pop), msg="True if len self > 0") - # self.assertTrue( - # len(bin_pop) == bin_pop.number_of_binaries, msg="Should be len __binaries" - # ) - - # def test_get_subpopulation(self): - # bin_pop = BinaryPopulation(population_properties=self.SIM_PROP, **self.POP_KWARGS) - # for i in range(200, 300): - # bin_pop[i].star_2.state = "BH" - # subpop = bin_pop.get_subpopulation(star_1_states=None, star_2_states="BH") - # self.assertTrue(all([bin.index == 200 + j for j, bin in enumerate(subpop)])) - # self.assertTrue(all([bin.star_2.state == "BH" for bin in subpop])) - - # def test_pickle_and_load(self): - # bin_pop = BinaryPopulation() - # bin_pop.pickle("saved_population.pkl") - # self.assertTrue(os.path.isfile("saved_population.pkl")) - # - # loaded_pop = BinaryPopulation.load("saved_population.pkl") - # self.assertTrue(isinstance(loaded_pop, BinaryPopulation)) - # - # def test_unique_sim_prop(self): - # bin_pop = BinaryPopulation() - # prop = bin_pop.population_properties - # self.assertTrue( - # all([prop is b.properties for b in bin_pop]), - # msg="All binary properties should map to the same object.", - # ) - - def tearDown(self): - # remove pickled files - if os.path.isfile("saved_population.pkl"): - os.remove("saved_population.pkl") - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/popsyn/test_synthetic_population.py b/posydon/tests/popsyn/test_synthetic_population.py deleted file mode 100644 index 58fc01a033..0000000000 --- a/posydon/tests/popsyn/test_synthetic_population.py +++ /dev/null @@ -1,645 +0,0 @@ -import os -import tempfile - -import numpy as np -import pandas as pd -import pytest - -from posydon.config import PATH_TO_POSYDON -from posydon.popsyn.synthetic_population import ( - History, - Oneline, - Population, - PopulationIO, - PopulationRunner, - parameter_array, -) -from posydon.utils.constants import Zsun - - -# Test the PopulationRunner class -class TestPopulationRunner: - # Test the initialisation of the PopulationRunner class - def test_init(self): - # Test the initialisation of the PopulationRunner class - poprun = PopulationRunner(PATH_TO_POSYDON+'posydon/popsyn/population_params_default.ini', verbose=True) - - # Check if the verbose attribute is set correctly - assert poprun.verbose == True, 'Verbose attribute is not set correctly' - - # Check if the solar_metallicities attribute is a list - assert isinstance(poprun.solar_metallicities, list), 'solar_metallicities attribute is not a list' - - # Check if the binary_populations attribute is a list - assert isinstance(poprun.binary_populations, list), 'binary_populations attribute is not a list' - - def test_init_invalid_ini_file(self): - with pytest.raises(ValueError): - PopulationRunner('invalid_file') - - def test_single_metallicity(self): - # copy the default ini file to a new file - new_ini_file = 'test_population_params.ini' - with open(PATH_TO_POSYDON+'posydon/popsyn/population_params_default.ini', 'r') as file: - data = file.read() - start = data.find('metallicity') - replace_str = 'metallicity = 0.0001' - data_new = data[:start] + replace_str + data[start+22:] - with open(new_ini_file, 'w') as file: - file.write(data_new) - - poprun = PopulationRunner(new_ini_file) - assert poprun.binary_populations[0].metallicity == 0.0001 - - def test_evolve(self, mocker): - mocker.patch('posydon.popsyn.binarypopulation.BinaryPopulation.evolve', return_value=None) - mocker.patch('posydon.popsyn.binarypopulation.BinaryPopulation.combine_saved_files', return_value=None) - - poprun = PopulationRunner(PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - # set population to 1 binary - for pop in poprun.binary_populations: - pop.number_of_systems = 1 - - # create a temporary directory with 1e-04_Zsun_batches - os.makedirs('1e-04_Zsun_batches', exist_ok=True) - - poprun.evolve() - assert poprun.binary_populations, 'binary_populations attribute is empty after calling the evolve method' - - - def test_evolve_file_exists(self, mocker): - mocker.patch('posydon.popsyn.binarypopulation.BinaryPopulation.evolve', return_value=None) - mocker.patch('posydon.popsyn.binarypopulation.BinaryPopulation.combine_saved_files', return_value=None) - - # Create a temporary file with the 1e-04_ZSun_population.h5 name - open('1e-04_Zsun_population.h5', 'w').close() - - poprun = PopulationRunner(PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - # set population to 1 binary - for pop in poprun.binary_populations: - pop.number_of_systems = 1 - - with pytest.raises(FileExistsError): - poprun.evolve() - - - # Test the evolve method - def test_changed_binarypop(self): - poprun = PopulationRunner(PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - # test - assert poprun.binary_populations[0].metallicity == 0.0001 - # Check if the temp_directory attribute is set correctly - assert poprun.binary_populations[0].kwargs['temp_directory'] == '1e-04_Zsun_batches', 'temp_directory attribute is not set correctly' - - # Check if the binary_populations attribute is not empty after calling the evolve method - assert poprun.binary_populations, 'binary_populations attribute is empty after calling the evolve method' - - @classmethod - def teardown_class(cls): - if os.path.exists('1e-04_Zsun_batches'): - os.rmdir('1e-04_Zsun_batches') - if os.path.exists('test_population_params.ini'): - os.remove('test_population_params.ini') - if os.path.exists('1e-04_Zsun_population.h5'): - os.remove('1e-04_Zsun_population.h5') - - - -# Test the History class -class TestHistory: - - @classmethod - def setup_class(cls): - # Set up a test HDF5 file using pandas HDFStore - cls.filename = 'test_population.h5' - with pd.HDFStore(cls.filename, 'w') as store: - # Create a history dataframe - history_data = pd.DataFrame({'time': [1, 2, 3], 'event': ['ZAMS','oRLO1', 'CEE']}) - store.append('history',history_data, data_columns=True) - - cls.filename2 = 'test_population2.h5' - with pd.HDFStore(cls.filename2, 'w') as store: - # Create a history dataframe - history_data = pd.DataFrame({'time': [1, 2, 3], 'event': ['ZAMS','oRLO1', 'CEE']}) - store.append('history',history_data, data_columns=True) - - @classmethod - def teardown_class(cls): - os.remove(cls.filename) - - def setup_method(self): - self.history = History(self.filename, verbose=False, chunksize=10000) - - def test_init(self): - history = History(self.filename2, verbose=True, chunksize=10000) - assert history.filename == self.filename2, 'Filename is not set correctly' - assert history.verbose == True, 'Verbose attribute is not set correctly' - assert history.chunksize == 10000, 'Chunksize attribute is not set correctly' - - expected_lengths = pd.DataFrame(index=[0, 1, 2],data={'index': [1, 1, 1]}) - expected_lengths.index.name = 'index' - pd.testing.assert_frame_equal(history.lengths, expected_lengths, 'Lengths attribute is not equal to the expected dataframe') - - assert history.number_of_systems == 3, 'Number of systems attribute is not None' - assert history.columns.to_list() == ['time', 'event'], 'Columns attribute is not None' - - assert isinstance(history.indices, np.ndarray), 'Indices attribute is not an ndarray' - np.testing.assert_array_equal(history.indices, np.array([0, 1, 2]), 'Indices attribute is not equal to the expected list') - - with pytest.raises(FileNotFoundError): - History('invalid_filename.h5', verbose=False, chunksize=10000) - - def test_init_verbose_true(self): - history = History(self.filename, chunksize=10000) - assert history.verbose == False, 'Verbose attribute is not set correctly' - - - def test_getitem_single_index(self): - df = self.history[0] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 1, 'Returned DataFrame does not have the correct length' - - - def test_getitem_multiple_indices(self): - df = self.history[[0, 1, 2]] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 3, 'Returned DataFrame does not have the correct length' - - def test_getitem_index_array(self): - indices = np.array([0, 1, 2]) - df = self.history[indices] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 3, 'Returned DataFrame does not have the correct length' - - def test_getitem_single_column(self): - column = 'time' - df = self.history[column] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df.columns) == 1, 'Returned DataFrame does not have the correct number of columns' - - def test_getitem_invalid_column(self): - column = 'invalid_column' - with pytest.raises(ValueError): - self.history[column] - - def test_getitem_invalid_keys(self): - columns = ['time', 'invalid_column'] - with pytest.raises(ValueError): - self.history[columns] - - def test_getitem_boolean_mask_numpy(self): - mask = (self.history['time'] > 1).to_numpy() - df = self.history[mask] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - - def test_getitem_boolean_mask_pandas(self): - mask = self.history['time'] > 1 - df = self.history[mask] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - - def test_getitem_multiple_columns(self): - columns = ['time', 'event'] - df = self.history[columns] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df.columns) == 2, 'Returned DataFrame does not have the correct number of columns' - - def test_getitem_invalid_key(self): - with pytest.raises(ValueError): - self.history[{1: 2}] - - def test_len(self): - length = len(self.history) - assert isinstance(length, int), 'Returned object is not an integer' - assert length == 3, 'Returned length is not correct' - - def test_head(self): - n = 2 - df = self.history.head(n) - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == n, 'Returned DataFrame does not have the correct length' - - def test_tail(self): - n = 2 - df = self.history.tail(n) - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == n, 'Returned DataFrame does not have the correct length' - - def test_repr(self): - representation = self.history.__repr__() - assert isinstance(representation, str), 'Returned object is not a string' - - def test_repr_html(self): - html_representation = self.history._repr_html_() - assert isinstance(html_representation, str), 'Returned object is not a string' - - - def test_select(self): - df = self.history.select(where="time > 1", start=0, stop=10, columns=['event', 'time']) - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df.columns) == 2, 'Returned DataFrame does not have the correct number of columns' - assert len(df) == 2, 'Returned DataFrame does not have the correct length' - - - -# Test the Oneline class -class TestOneline: - - @classmethod - def setup_class(cls): - # Set up a test HDF5 file using pandas HDFStore - cls.filename = 'test_oneline.h5' - with pd.HDFStore(cls.filename, 'w') as store: - # Create a oneline dataframe - oneline_data = pd.DataFrame({'time': [1, 2, 3], 'S1_mass_i': ['30','30', '70']}) - store.append('oneline', oneline_data, data_columns=True) - - def setup_method(self): - self.oneline = Oneline(self.filename, verbose=False, chunksize=10000) - - def test_init(self): - oneline = Oneline(self.filename, verbose=True, chunksize=5000) - - assert oneline.filename == self.filename, 'Filename is not set correctly' - assert oneline.verbose == True, 'Verbose attribute is not set correctly' - assert oneline.chunksize == 5000, 'Chunksize attribute is not set correctly' - assert oneline.number_of_systems == 3, 'Number of systems attribute is not set correctly' - assert oneline.columns.to_list() == ['time', 'S1_mass_i'], 'Columns attribute is not set correctly' - assert oneline.number_of_systems == 3, 'Number of systems attribute is not set correctly' - - assert isinstance(oneline.indices, np.ndarray), 'Indices attribute is not an ndarray' - np.testing.assert_array_equal(oneline.indices, np.array([0, 1, 2]), 'Indices attribute is not equal to the expected list') - - with pytest.raises(FileNotFoundError): - Oneline('invalid_filename.h5', verbose=False, chunksize=10000) - - def test_getitem_single_index(self): - df = self.oneline[0] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 1, 'Returned DataFrame does not have the correct length' - - def test_getitem_multiple_indices(self): - df = self.oneline[[0, 1, 2]] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 3, 'Returned DataFrame does not have the correct length' - - def test_getitem_index_array(self): - indices = np.array([0, 1, 2]) - df = self.oneline[indices] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 3, 'Returned DataFrame does not have the correct length' - pd.testing.assert_frame_equal(df, - pd.DataFrame({'time': [1, 2, 3], - 'S1_mass_i': ['30','30', '70']}), - 'Returned DataFrame is not equal to the expected DataFrame') - - def test_getitem_slice(self): - df = self.oneline[0:2] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 2, 'Returned DataFrame does not have the correct length' - pd.testing.assert_frame_equal(df, - pd.DataFrame({'time': [1, 2], - 'S1_mass_i': ['30','30']}),) - def test_getitem_endslice(self): - df = self.oneline[:2] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 2, 'Returned DataFrame does not have the correct length' - pd.testing.assert_frame_equal(df, - pd.DataFrame({'time': [1, 2], - 'S1_mass_i': ['30','30']}),) - def test_getitem_beginslice(self): - df = self.oneline[1:] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 2, 'Returned DataFrame does not have the correct length' - pd.testing.assert_frame_equal(df, - pd.DataFrame(index=[1, 2], - data={'time': [2, 3], - 'S1_mass_i': ['30', '70']}),) - def test_getitem_float_indices(self): - with pytest.raises(ValueError): - self.oneline[[0.5, 1.2]] - - - - def test_getitem_single_column(self): - column = 'time' - df = self.oneline[column] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df.columns) == 1, 'Returned DataFrame does not have the correct number of columns' - - def test_getitem_boolean_mask_numpy(self): - mask = (self.oneline['time'] > 1).to_numpy().flatten() - df = self.oneline[mask] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - - def test_getitem_boolean_mask_pandas(self): - mask = self.oneline['time'] > 1 - print(mask) - df = self.oneline[mask] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - - def test_getitem_multiple_columns(self): - columns = ['time', 'S1_mass_i'] - df = self.oneline[columns] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df.columns) == 2, 'Returned DataFrame does not have the correct number of columns' - - def test_getitems_multiple_columns_invalid(self): - columns = ['time', 'invalid_column'] - with pytest.raises(ValueError): - self.oneline[columns] - - def test_getitem_invalid_key_type(self): - with pytest.raises(ValueError): - self.oneline[{1: 2}] - - def test_getitem_invalid_key(self): - with pytest.raises(ValueError): - self.oneline['invalid_key'] - - def test_len(self): - length = len(self.oneline) - assert isinstance(length, int), 'Returned object is not an integer' - assert length == 3, 'Returned length is not correct' - - def test_head(self): - n = 2 - df = self.oneline.head(n) - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == n, 'Returned DataFrame does not have the correct length' - - def test_tail(self): - n = 2 - df = self.oneline.tail(n) - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == n, 'Returned DataFrame does not have the correct length' - - def test_repr(self): - representation = self.oneline.__repr__() - assert isinstance(representation, str), 'Returned object is not a string' - - def test_repr_html(self): - html_representation = self.oneline._repr_html_() - assert isinstance(html_representation, str), 'Returned object is not a string' - - def test_select(self): - df = self.oneline.select(where="time > 1", start=0, stop=10, columns=['S1_mass_i', 'time']) - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df.columns) == 2, 'Returned DataFrame does not have the correct number of columns' - assert len(df) == 2, 'Returned DataFrame does not have the correct length' - - - @classmethod - def teardown_class(cls): - os.remove(cls.filename) - -# Test the PopulationIO class -class TestPopulationIO: - - def setup_method(self): - self.filename = "test_population.h5" - - def teardown_method(self): - if os.path.exists(self.filename): - os.remove(self.filename) - - def test_init(self): - pop_io = PopulationIO() - assert pop_io.verbose == False, "Verbose attribute is not set correctly" - - - def test_invalid_filename(self): - pop_io = PopulationIO() - with pytest.raises(ValueError): - pop_io._load_metadata("invalid_filename") - - def test_save_and_load_mass_per_met(self): - population_io = PopulationIO() - population_io.verbose = True - population_io.mass_per_metallicity = pd.DataFrame({"metallicity": [0.02, 0.04], "mass": [1.0, 2.0]}) - population_io._save_mass_per_metallicity(self.filename) - - loaded_io = PopulationIO() - loaded_io.verbose = True - loaded_io._load_mass_per_metallicity(self.filename) - pd.testing.assert_frame_equal(population_io.mass_per_metallicity, loaded_io.mass_per_metallicity) - - def test_save_and_load_ini_params(self): - population_io = PopulationIO() - population_io.ini_params = {i:10 for i in parameter_array} - population_io._save_ini_params(self.filename) - - loaded_io = PopulationIO() - loaded_io._load_ini_params(self.filename) - - assert population_io.ini_params == loaded_io.ini_params, "Loaded ini_params are not equal to the saved ini_params" - - def test_save_and_load_metadata(self): - pop_io = PopulationIO() - pop_io.verbose = True - pop_io.ini_params = {i:10 for i in parameter_array} - pop_io._save_ini_params(self.filename) - pop_io.mass_per_metallicity = pd.DataFrame({"metallicity": [0.02, 0.04], "mass": [1.0, 2.0]}) - pop_io._save_mass_per_metallicity(self.filename) - - load_io = PopulationIO() - load_io._load_metadata(self.filename) - - assert pop_io.ini_params == load_io.ini_params, "Loaded ini_params are not equal to the saved ini_params" - assert pop_io.mass_per_metallicity.equals(load_io.mass_per_metallicity), "Loaded mass_per_metallicity is not equal to the saved mass_per_metallicity" - - - -class TestPopulation: - def setup_method(self): - pass - - def teardown_method(self): - # Clean up any resources used by the test - pass - - def setup_class(self): - self.filename1 = "no_mass_per_met_population.h5" - self.filename2 = "history_population.h5" - self.filename3 = "oneline_population.h5" - self.history_data = pd.DataFrame({'time': [1, 2, 3], 'event': ['ZAMS','oRLO1', 'CEE']}) - self.oneline_data = pd.DataFrame({'time': [1, 2, 3], 'S1_mass_i': [30, 30, 70], 'S2_mass_i': [30, 30, 70.]}) - self.formation_channels = pd.DataFrame({'channel': ['channel1', 'channel2', 'channel3'], 'channel_debug':['debug1', 'debug2', 'debug3']}) - - # create a file with only history and oneline data - with pd.HDFStore(self.filename1, 'w') as store: - store.append('history',self.history_data, data_columns=True) - store.append('oneline', self.oneline_data, data_columns=True) - - with pd.HDFStore(self.filename2, 'w') as store: - store.append('history', self.history_data, data_columns=True) - - with pd.HDFStore(self.filename3, 'w') as store: - store.append('oneline', self.oneline_data, data_columns=True) - - def teardown_class(self): - if os.path.exists(self.filename1): - os.remove(self.filename1) - if os.path.exists(self.filename2): - os.remove(self.filename2) - if os.path.exists(self.filename3): - os.remove(self.filename3) - - @pytest.fixture - def mass_per_met_pop(self): - self.filename = "mass_per_met_population.h5" - with pd.HDFStore(self.filename, 'w') as store: - store.append('history', self.history_data, data_columns=True) - store.append('oneline', self.oneline_data, data_columns=True) - - pop = Population(self.filename, verbose=True, metallicity=0.02, ini_file=PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - yield - if os.path.exists(self.filename): - os.remove(self.filename) - - @pytest.fixture - def no_mass_per_met_pop(self): - self.filename = "no_mass_per_met_population.h5" - with pd.HDFStore(self.filename, 'w') as store: - store.append('history', self.history_data, data_columns=True) - store.append('oneline', self.oneline_data, data_columns=True) - yield - if os.path.exists(self.filename): - os.remove(self.filename) - - @pytest.fixture - def mass_per_met_pop_channels(self): - self.filename = "mass_per_met_population.h5" - with pd.HDFStore(self.filename, 'w') as store: - store.append('history', self.history_data, data_columns=True) - store.append('oneline', self.oneline_data, data_columns=True) - store.append('formation_channels', self.formation_channels, data_columns=True) - pop = Population(self.filename, verbose=True, metallicity=0.02, ini_file=PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - yield - if os.path.exists(self.filename): - os.remove(self.filename) - - @pytest.fixture - def clean_up_selection_file(self): - self.outfile = "test_selection.h5" - yield - if os.path.exists(self.outfile): - os.remove(self.outfile) - - def test_init_invalid_file(self): - with pytest.raises(ValueError): - pop = Population('invalid_filename') - - def test_init_no_history(self): - with pytest.raises(ValueError): - pop = Population(self.filename3) - - def test_init_no_oneline(self): - with pytest.raises(ValueError): - pop = Population(self.filename2) - - def test_init_no_mass_per_met(self): - with pytest.raises(ValueError): - pop = Population(self.filename1, verbose=True) - - - def test_init_mass_per_met_calc(self, no_mass_per_met_pop: None): - pop = Population(self.filename, verbose=True, metallicity=1., ini_file=PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - # check that the history and oneline data are read correctly - pd.testing.assert_frame_equal(pop.history[:], self.history_data) - pd.testing.assert_frame_equal(pop.oneline[:], self.oneline_data) - assert pop.solar_metallicities == [1.] - assert pop.metallicities == [1*Zsun] - pd.testing.assert_frame_equal(pop.mass_per_metallicity, pd.DataFrame(index=[1.], data={'simulated_mass': [260.], 'underlying_mass': [1462.194834], 'number_of_systems': [3]})) - - pop = Population(self.filename, verbose=True, metallicity=1., ini_file=PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - # check that the history and oneline data are the same - pd.testing.assert_frame_equal(pop.history[:], self.history_data) - pd.testing.assert_frame_equal(pop.oneline[:], self.oneline_data) - assert pop.solar_metallicities == [1.] - assert pop.metallicities == [1*Zsun] - pd.testing.assert_frame_equal(pop.mass_per_metallicity, pd.DataFrame(index=[1.], data={'simulated_mass': [260.], 'underlying_mass': [1462.194834], 'number_of_systems': [3]})) - - - def test_init(self,mass_per_met_pop: None): - pop = Population(self.filename) - # check that the history and oneline data are read correctly - pd.testing.assert_frame_equal(pop.history[:], self.history_data) - pd.testing.assert_frame_equal(pop.oneline[:], self.oneline_data) - assert pop.metallicities == [0.02*Zsun] - assert pop.solar_metallicities == [0.02] - tmp_df = pd.DataFrame(index=[0, 1, 2], data={'index': [1, 1, 1]}) - tmp_df.index.name = 'index' - pd.testing.assert_frame_equal(pop.history_lengths, tmp_df) - pd.testing.assert_frame_equal(pop.mass_per_metallicity, pd.DataFrame(index=[0.02], data={'simulated_mass': [260.], 'underlying_mass': [1462.194834], 'number_of_systems': [3]})) - - - def test_read_formation_channels(self, mass_per_met_pop_channels: None): - pop = Population(self.filename) - # check that the formation channels are read correctly - assert pop.formation_channels.equals(self.formation_channels) - - def test_export_selection(self, mass_per_met_pop: None, clean_up_selection_file: None): - selection = [1, 2] - chunksize = 1000 - pop = Population(self.filename) - pop.export_selection(selection, self.outfile, chunksize) - assert os.path.exists(self.outfile) - assert pd.read_hdf(self.outfile, 'history').shape[0] == 2 - assert pd.read_hdf(self.outfile, 'oneline').shape[0] == 2 - - def test_bad_name_export_selection(self, mass_per_met_pop: None): - selection = [1, 2] - chunksize = 1000 - pop = Population(self.filename) - with pytest.raises(ValueError): - pop.export_selection(selection, 'test_selection.csv', history_chunksize=chunksize) - - def test_append_selection(self, mass_per_met_pop: None, clean_up_selection_file: None): - selection = [1, 2] - chunksize = 1000 - pop = Population(self.filename) - pop.export_selection(selection, self.outfile, overwrite=True, history_chunksize=chunksize) - pop.export_selection(selection, self.outfile, overwrite=False, history_chunksize=chunksize) - - assert pd.read_hdf(self.outfile, 'history').shape[0] == 4 - assert pd.read_hdf(self.outfile, 'oneline').shape[0] == 4 - - def test_no_formation_channels(self, mass_per_met_pop: None): - pop = Population(self.filename, verbose=True) - assert pop.formation_channels is None - - def test_len(self, mass_per_met_pop: None): - pop = Population(self.filename) - assert len(pop) == 3 - - def test_columns(self, mass_per_met_pop: None): - pop = Population(self.filename) - columns = pop.columns - - assert columns['history'].tolist() == self.history_data.columns.tolist() - assert columns['oneline'].tolist() == self.oneline_data.columns.tolist() - - - - - # Test formation channel calculation, I need a specific test file for this, - # since it requires specific columns to be present in the oneline and history dataframes - - # Test create_transient_population method requires a specific test file for this, - # since it requires specific columns to be present in the oneline and history dataframes - - -class TestTransientPopulation: - pass - # to implement - - - -class TestRates: - pass - # to implement - -# Run the tests - -if __name__ == '__main__': - pytest.main() diff --git a/posydon/tests/visualization/test_VHdiagram.py b/posydon/tests/visualization/test_VHdiagram.py deleted file mode 100644 index 4e0fa81276..0000000000 --- a/posydon/tests/visualization/test_VHdiagram.py +++ /dev/null @@ -1,93 +0,0 @@ -import os -import unittest -from unittest.mock import patch - -from PyQt5.QtCore import QTimer -from PyQt5.QtWidgets import QApplication - -from posydon.config import PATH_TO_POSYDON -from posydon.visualization.VH_diagram.Presenter import Presenter, PresenterMode - -PATH_TO_DATASET = os.path.join( - PATH_TO_POSYDON, - "posydon", - "tests", - "data", - "POSYDON-UNIT-TESTS", - "visualization", - "20000_binaries.csv.gz" -) - -# https://stackoverflow.com/questions/60692711/cant-create-python-qapplication-in-github-action - -# if not os.path.exists(PATH_TO_DATASET): -# raise ValueError("Dataset for unit testing (VH diagram) was not found!") -# -# -# class TestVHdiagram(unittest.TestCase): -# def test_termination_detailled_view(self): -# app = QApplication.instance() -# if not app: -# app = QApplication([]) -# -# presenter = Presenter(PATH_TO_DATASET) -# -# with patch('PyQt5.QtWidgets.QMainWindow.show') as show_patch: -# presenter.present(19628, PresenterMode.DETAILED) -# -# QTimer.singleShot(0, lambda : presenter.close() ) -# -# app.exec_() -# -# assert show_patch.called -# -# def test_termination_reduced_view(self): -# app = QApplication.instance() -# if not app: -# app = QApplication([]) -# -# presenter = Presenter(PATH_TO_DATASET) -# -# with patch('PyQt5.QtWidgets.QMainWindow.show') as show_patch: -# presenter.present(19628, PresenterMode.REDUCED) -# -# QTimer.singleShot(0, lambda : presenter.close() ) -# -# app.exec_() -# -# assert show_patch.called -# -# def test_termination_simplified_view(self): -# app = QApplication.instance() -# if not app: -# app = QApplication([]) -# -# presenter = Presenter(PATH_TO_DATASET) -# -# with patch('PyQt5.QtWidgets.QMainWindow.show') as show_patch: -# presenter.present(19628, PresenterMode.SIMPLIFIED) -# -# QTimer.singleShot(0, lambda : presenter.close() ) -# -# app.exec_() -# -# assert show_patch.called -# -# def test_termination_diagram_view(self): -# app = QApplication.instance() -# if not app: -# app = QApplication([]) -# -# presenter = Presenter(PATH_TO_DATASET) -# -# with patch('PyQt5.QtWidgets.QMainWindow.show') as show_patch: -# presenter.present(19628, PresenterMode.DIAGRAM) -# -# QTimer.singleShot(0, lambda : presenter.close() ) -# -# app.exec_() -# -# assert show_patch.called -# -# if __name__ == "__main__": -# unittest.main() diff --git a/posydon/tests/visualization/test_plot1D.py b/posydon/tests/visualization/test_plot1D.py deleted file mode 100644 index 5621832351..0000000000 --- a/posydon/tests/visualization/test_plot1D.py +++ /dev/null @@ -1,110 +0,0 @@ -import os -import unittest -from unittest.mock import patch - -from posydon.config import PATH_TO_POSYDON -from posydon.grids.psygrid import PSyGrid - -PATH_TO_GRID = os.path.join( - PATH_TO_POSYDON, - "posydon/tests/data/POSYDON-UNIT-TESTS/" - "visualization/grid_unit_test_plot.h5" -) - -if not os.path.exists(PATH_TO_GRID): - raise ValueError("Test grid for unit testing was not found!") - - -class TestPlot1D(unittest.TestCase): - def test_one_track_one_var_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot(42, - "star_age", - "center_he4", - history="history1", - **{'show_fig': True}) - assert show_patch.called - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot(42, - "age", - "star_1_mass", - history="binary_history", - **{'show_fig': True}) - assert show_patch.called - - def test_one_track_many_vars_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot(42, - "star_age", ["center_he4", "log_LHe"], - history="history1", - **{'show_fig': True}) - assert show_patch.called - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot( - 42, - "age", - ["star_1_mass", "binary_separation", "rl_relative_overflow_1"], - history="binary_history", - **{'show_fig': True}) - assert show_patch.called - - def test_many_tracks_one_var_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot([42, 43, 44], - "star_age", - "center_he4", - history="history1", - **{'show_fig': True}) - assert show_patch.called - - def test_many_tracks_many_vars_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot( - [42, 43, 44], - "age", - ["star_1_mass", "binary_separation", "rl_relative_overflow_1"], - history="binary_history", - **{'show_fig': True}) - assert show_patch.called - - def test_one_track_one_var_extra_var_color_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot(42, - "age", - "star_1_mass", - "period_days", - history="binary_history", - **{'show_fig': True}) - assert show_patch.called - - def test_many_tracks_one_var_extra_var_color_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot([42, 43], - "age", - "star_1_mass", - "period_days", - history="binary_history", - **{'show_fig': True}) - assert show_patch.called - - def test_one_track_HR_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.HR(42, history="history1", **{'show_fig': True}) - assert show_patch.called - - def test_many_tracks_HR_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.HR([42, 43, 44], history="history1", **{'show_fig': True}) - assert show_patch.called - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/visualization/test_plot2D.py b/posydon/tests/visualization/test_plot2D.py deleted file mode 100644 index 637ebb72e8..0000000000 --- a/posydon/tests/visualization/test_plot2D.py +++ /dev/null @@ -1,157 +0,0 @@ -import os -import unittest -from unittest.mock import patch - -from posydon.config import PATH_TO_POSYDON -from posydon.grids.psygrid import PSyGrid - -PATH_TO_GRID = os.path.join( - PATH_TO_POSYDON, - "posydon/tests/data/POSYDON-UNIT-TESTS/" - "visualization/grid_unit_test_plot.h5" -) -if not os.path.exists(PATH_TO_GRID): - raise ValueError("Test grid for unit testing was not found!") - -# class TestPlot2D(unittest.TestCase): -# def test_termination_flag_1_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# "star_1_mass", -# termination_flag="termination_flag_1", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# "c_core_mass", -# termination_flag="termination_flag_1", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# "binary_separation", -# termination_flag="termination_flag_1", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# -# def test_termination_flag_2_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# None, -# termination_flag="termination_flag_2", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# -# def test_termination_flag_3_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# None, -# termination_flag="termination_flag_3", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# -# def test_termination_flag_4_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# None, -# termination_flag="termination_flag_4", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# -# def test_all_termination_flags_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# "binary_separation", -# termination_flag="all", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# -# def test_RLO_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# # with patch('matplotlib.pyplot.show') as show_patch: -# # grid.plot2D("star_1_mass", -# # "period_days", -# # "c_core_mass", -# # termination_flag="termination_flag_1", -# # grid_3D=True, -# # slice_3D_var_str="star_2_mass", -# # slice_3D_var_range=(2.5, 3.0), -# # slice_at_RLO=True, -# # **{ -# # 'show_fig': True -# # }) -# # assert show_patch.called -# # with patch('matplotlib.pyplot.show') as show_patch: -# # grid.plot2D("star_1_mass", -# # "period_days", -# # "star_1_mass", -# # termination_flag="termination_flag_1", -# # grid_3D=True, -# # slice_3D_var_str="star_2_mass", -# # slice_3D_var_range=(2.5, 3.0), -# # slice_at_RLO=True, -# # **{ -# # 'show_fig': True -# # }) -# # assert show_patch.called -# -# def test_extra_grid_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# "star_1_mass", -# termination_flag="termination_flag_1", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# extra_grid=grid, -# **{'show_fig': True}) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# "star_1_mass", -# termination_flag="termination_flag_1", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# extra_grid=grid, -# **{'show_fig': True}) -# assert show_patch.called -# -# -# if __name__ == "__main__": -# unittest.main() diff --git a/posydon/unit_tests/CLI/popsyn/test_check.py b/posydon/unit_tests/CLI/popsyn/test_check.py index dc0c717002..14b966f960 100644 --- a/posydon/unit_tests/CLI/popsyn/test_check.py +++ b/posydon/unit_tests/CLI/popsyn/test_check.py @@ -397,6 +397,22 @@ def test_get_expected_batch_count_array_without_dash(self, tmp_path): result = totest.get_expected_batch_count(str(tmp_path), str_met) assert result is None # Should return None when array doesn't have dash + + def test_get_expected_batch_count_with_max_concurrent_jobs(self, tmp_path): + """Test get_expected_batch_count with % (max concurrent jobs) in array range.""" + metallicity = 1.0 + str_met = convert_metallicity_to_string(metallicity) + slurm_file = tmp_path / f"{str_met}_Zsun_slurm_array.slurm" + slurm_content = """#!/bin/bash +#SBATCH --array=0-9%5 +#SBATCH --job-name=test +""" + slurm_file.write_text(slurm_content) + + result = totest.get_expected_batch_count(str(tmp_path), str_met) + # Should strip %5 and return 10 (0-9 inclusive) + assert result == 10 + class TestFindMissingBatchIndices: """Test class for find_missing_batch_indices function.""" diff --git a/posydon/unit_tests/CLI/popsyn/test_setup.py b/posydon/unit_tests/CLI/popsyn/test_setup.py index 07f0ac9c4b..ea5d0c65e3 100644 --- a/posydon/unit_tests/CLI/popsyn/test_setup.py +++ b/posydon/unit_tests/CLI/popsyn/test_setup.py @@ -125,7 +125,8 @@ def test_setup_popsyn_function_too_few_binaries(self, mock_binarypop, mock_valid """Test that ValueError is raised when number of binaries is too small.""" mock_binarypop.return_value = { 'metallicities': [1.0], - 'number_of_binaries': 5 # Less than job_array (10) + 'number_of_binaries': 5, # Less than job_array (10) + 'secondary_mass_scheme': 'q=1' } with pytest.raises(ValueError, match="number of binaries is less than the job array"): @@ -147,7 +148,8 @@ def test_setup_popsyn_function_success( metallicities = [0.01, 1.0] mock_binarypop.return_value = { 'metallicities': metallicities, - 'number_of_binaries': 1000 + 'number_of_binaries': 1000, + 'secondary_mass_scheme': 'q=1' } # Call function @@ -174,7 +176,8 @@ def test_setup_popsyn_function_creates_log_directories( metallicities = [0.1, 1.0] mock_binarypop.return_value = { 'metallicities': metallicities, - 'number_of_binaries': 1000 + 'number_of_binaries': 1000, + 'secondary_mass_scheme': 'q=1' } totest.setup_popsyn_function(mock_args) @@ -187,6 +190,90 @@ def test_setup_popsyn_function_creates_log_directories( assert any(['1e-01_logs' in call for call in calls]) +class TestFlatMassRatioWarning: + """Test class for flat_mass_ratio mass ratio warning logic.""" + + @pytest.fixture + def mock_args(self): + """Create mock command-line arguments.""" + args = MagicMock() + args.ini_file = "test.ini" + args.job_array = 10 + return args + + @patch('posydon.CLI.popsyn.setup.validate_ini_file') + @patch('posydon.CLI.popsyn.setup.binarypop_kwargs_from_ini') + @patch('posydon.CLI.popsyn.setup.create_python_scripts') + @patch('posydon.CLI.popsyn.setup.create_slurm_scripts') + @patch('posydon.CLI.popsyn.setup.create_bash_submit_script') + @patch('posydon.CLI.popsyn.setup.Pwarn') + @patch('os.makedirs') + def test_flat_mass_ratio_low_q_issues_warning( + self, mock_makedirs, mock_pwarn, mock_bash, mock_slurm, + mock_python, mock_binarypop, mock_validate, mock_args + ): + """Test warning is issued when q_min < 0.05 with flat_mass_ratio.""" + mock_binarypop.return_value = { + 'metallicities': [1.0], + 'number_of_binaries': 1000, + 'secondary_mass_scheme': 'flat_mass_ratio', + 'secondary_mass_min': 0.1, + 'primary_mass_max': 150.0, + } + + totest.setup_popsyn_function(mock_args) + + mock_pwarn.assert_called_once() + call_args = mock_pwarn.call_args[0] + assert 'InappropriateValueWarning' in call_args + + @patch('posydon.CLI.popsyn.setup.validate_ini_file') + @patch('posydon.CLI.popsyn.setup.binarypop_kwargs_from_ini') + @patch('posydon.CLI.popsyn.setup.create_python_scripts') + @patch('posydon.CLI.popsyn.setup.create_slurm_scripts') + @patch('posydon.CLI.popsyn.setup.create_bash_submit_script') + @patch('posydon.CLI.popsyn.setup.Pwarn') + @patch('os.makedirs') + def test_flat_mass_ratio_acceptable_q_no_warning( + self, mock_makedirs, mock_pwarn, mock_bash, mock_slurm, + mock_python, mock_binarypop, mock_validate, mock_args + ): + """Test no warning is issued when q_min >= 0.05 with flat_mass_ratio.""" + mock_binarypop.return_value = { + 'metallicities': [1.0], + 'number_of_binaries': 1000, + 'secondary_mass_scheme': 'flat_mass_ratio', + 'secondary_mass_min': 5.0, + 'primary_mass_max': 80.0, + } + + totest.setup_popsyn_function(mock_args) + + mock_pwarn.assert_not_called() + + @patch('posydon.CLI.popsyn.setup.validate_ini_file') + @patch('posydon.CLI.popsyn.setup.binarypop_kwargs_from_ini') + @patch('posydon.CLI.popsyn.setup.create_python_scripts') + @patch('posydon.CLI.popsyn.setup.create_slurm_scripts') + @patch('posydon.CLI.popsyn.setup.create_bash_submit_script') + @patch('posydon.CLI.popsyn.setup.Pwarn') + @patch('os.makedirs') + def test_non_flat_mass_ratio_scheme_no_warning( + self, mock_makedirs, mock_pwarn, mock_bash, mock_slurm, + mock_python, mock_binarypop, mock_validate, mock_args + ): + """Test no warning is issued when secondary_mass_scheme is not flat_mass_ratio.""" + mock_binarypop.return_value = { + 'metallicities': [1.0], + 'number_of_binaries': 1000, + 'secondary_mass_scheme': 'q=1', + } + + totest.setup_popsyn_function(mock_args) + + mock_pwarn.assert_not_called() + + class TestIntegration: """Integration tests for the setup module.""" @@ -211,7 +298,8 @@ def test_full_setup_workflow( } mock_binarypop.return_value = { 'metallicities': [1.0], - 'number_of_binaries': 100 + 'number_of_binaries': 100, + 'secondary_mass_scheme': 'q=1' } # Create mock args diff --git a/posydon/unit_tests/CLI/test_io.py b/posydon/unit_tests/CLI/test_io.py index 9229e48adf..767f820812 100644 --- a/posydon/unit_tests/CLI/test_io.py +++ b/posydon/unit_tests/CLI/test_io.py @@ -175,7 +175,9 @@ def test_create_slurm_array(self, tmp_path): totest.create_slurm_array( metallicity, job_array_length, partition, email, walltime, account, mem_per_cpu, - path_to_posydon, path_to_posydon_data + max_concurrent_jobs=None, exclude=None, + path_to_posydon=path_to_posydon, + path_to_posydon_data=path_to_posydon_data ) # Check that the file was created @@ -212,6 +214,7 @@ def test_create_slurm_array_minimal(self, tmp_path): partition=None, email=None, walltime="12:00:00", account=None, mem_per_cpu="2G", + max_concurrent_jobs=None, exclude=None, path_to_posydon="/posydon", path_to_posydon_data="/data" ) @@ -227,6 +230,67 @@ def test_create_slurm_array_minimal(self, tmp_path): finally: os.chdir(original_dir) + def test_create_slurm_array_with_max_concurrent_jobs(self, tmp_path): + """Test create_slurm_array with max_concurrent_jobs set.""" + original_dir = os.getcwd() + os.chdir(tmp_path) + + try: + metallicity = 1.0 + job_array_length = 10 + max_concurrent_jobs = 5 + + totest.create_slurm_array( + metallicity, job_array_length, + partition=None, email=None, + walltime="24:00:00", account=None, + mem_per_cpu="4G", + max_concurrent_jobs=max_concurrent_jobs, + exclude=None, + path_to_posydon="/posydon", + path_to_posydon_data="/data" + ) + str_met = convert_metallicity_to_string(metallicity) + filename = f"{str_met}_Zsun_slurm_array.slurm" + assert os.path.exists(filename) + + with open(filename, "r") as f: + content = f.read() + # Array should include the %N max concurrent jobs specifier + assert f"#SBATCH --array=0-{job_array_length - 1}%{max_concurrent_jobs}" in content + finally: + os.chdir(original_dir) + + def test_create_slurm_array_with_exclude(self, tmp_path): + """Test create_slurm_array with exclude set.""" + original_dir = os.getcwd() + os.chdir(tmp_path) + + try: + metallicity = 1.0 + job_array_length = 10 + exclude = "node01,node02" + + totest.create_slurm_array( + metallicity, job_array_length, + partition=None, email=None, + walltime="24:00:00", account=None, + mem_per_cpu="4G", + max_concurrent_jobs=None, + exclude=exclude, + path_to_posydon="/posydon", + path_to_posydon_data="/data" + ) + str_met = convert_metallicity_to_string(metallicity) + filename = f"{str_met}_Zsun_slurm_array.slurm" + assert os.path.exists(filename) + + with open(filename, "r") as f: + content = f.read() + assert f"#SBATCH --exclude={exclude}" in content + finally: + os.chdir(original_dir) + def test_create_slurm_merge(self, tmp_path): """Test that create_slurm_merge creates a merge SLURM script.""" original_dir = os.getcwd() @@ -343,6 +407,8 @@ def test_create_slurm_rescue(self, tmp_path): walltime="20:00:00", account="test_account", mem_per_cpu="4G", + max_concurrent_jobs=None, + exclude=None, path_to_posydon="/posydon", path_to_posydon_data="/data" ) @@ -379,6 +445,8 @@ def test_create_slurm_rescue_minimal(self, tmp_path): walltime="20:00:00", account=None, # No account mem_per_cpu="4G", + max_concurrent_jobs=None, + exclude=None, path_to_posydon="/posydon", path_to_posydon_data="/data" ) @@ -399,6 +467,78 @@ def test_create_slurm_rescue_minimal(self, tmp_path): finally: os.chdir(original_dir) + def test_create_slurm_rescue_with_max_concurrent_jobs(self, tmp_path): + """Test create_slurm_rescue with max_concurrent_jobs set.""" + original_dir = os.getcwd() + os.chdir(tmp_path) + + try: + metallicity = 1.0 + missing_indices = [1, 3] + job_array_length = 10 + max_concurrent_jobs = 4 + + totest.create_slurm_rescue( + metallicity=metallicity, + missing_indices=missing_indices, + job_array_length=job_array_length, + partition=None, + email=None, + walltime="20:00:00", + account=None, + mem_per_cpu="4G", + max_concurrent_jobs=max_concurrent_jobs, + exclude=None, + path_to_posydon="/posydon", + path_to_posydon_data="/data" + ) + + str_met = convert_metallicity_to_string(metallicity) + filename = f"{str_met}_Zsun_rescue.slurm" + assert os.path.exists(filename) + + with open(filename, "r") as f: + content = f.read() + assert f"#SBATCH --array=1,3%{max_concurrent_jobs}" in content + finally: + os.chdir(original_dir) + + def test_create_slurm_rescue_with_exclude(self, tmp_path): + """Test create_slurm_rescue with exclude set.""" + original_dir = os.getcwd() + os.chdir(tmp_path) + + try: + metallicity = 1.0 + missing_indices = [2, 4] + job_array_length = 10 + exclude = "node01,node02" + + totest.create_slurm_rescue( + metallicity=metallicity, + missing_indices=missing_indices, + job_array_length=job_array_length, + partition=None, + email=None, + walltime="20:00:00", + account=None, + mem_per_cpu="4G", + max_concurrent_jobs=None, + exclude=exclude, + path_to_posydon="/posydon", + path_to_posydon_data="/data" + ) + + str_met = convert_metallicity_to_string(metallicity) + filename = f"{str_met}_Zsun_rescue.slurm" + assert os.path.exists(filename) + + with open(filename, "r") as f: + content = f.read() + assert f"#SBATCH --exclude={exclude}" in content + finally: + os.chdir(original_dir) + def test_create_bash_submit_script(self, tmp_path): """Test that create_bash_submit_script creates a submission script.""" original_dir = os.getcwd() @@ -471,6 +611,8 @@ def mock_args(self): args.partition = "normal" args.account = "test_account" args.email = "test@example.com" + args.max_concurrent_jobs = None + args.exclude = None return args @pytest.fixture @@ -595,6 +737,8 @@ def test_create_batch_rescue_script_partial_overrides(self, tmp_path, mock_batch args.partition = None # Don't override args.account = None # Don't override args.email = None # Don't override + args.max_concurrent_jobs = None # Don't override + args.exclude = None # Don't override # Create a mock SLURM array script slurm_script = run_folder / "1e+00_Zsun_slurm_array.slurm" @@ -647,6 +791,8 @@ def test_create_batch_rescue_script_no_overrides(self, tmp_path, mock_batch_stat args.partition = None # Don't override args.account = None # Don't override args.email = None # Don't override + args.max_concurrent_jobs = None # Don't override + args.exclude = None # Don't override # Create a mock SLURM array script slurm_script = run_folder / "1e+00_Zsun_slurm_array.slurm" @@ -696,6 +842,8 @@ def test_create_batch_rescue_script_array_without_dash(self, tmp_path, mock_batc args.partition = None args.account = None args.email = None + args.max_concurrent_jobs = None + args.exclude = None # Create a mock SLURM array script with array format that doesn't have a dash # (e.g., just a single number or comma-separated list) @@ -723,3 +871,172 @@ def test_create_batch_rescue_script_array_without_dash(self, tmp_path, mock_batc assert rescue_script.exists() finally: os.chdir(original_dir) + + def test_create_batch_rescue_script_with_exclude_in_slurm(self, tmp_path, mock_batch_status): + """Test create_batch_rescue_script parses --exclude from existing SLURM script.""" + run_folder = tmp_path / "run" + run_folder.mkdir() + + args = MagicMock() + args.run_folder = str(run_folder) + args.walltime = None + args.mem_per_cpu = None + args.partition = None + args.account = None + args.email = None + args.max_concurrent_jobs = None + args.exclude = None + + # SLURM script with --exclude= directive + slurm_script = run_folder / "1e+00_Zsun_slurm_array.slurm" + slurm_content = textwrap.dedent("""\ + #!/bin/bash + #SBATCH --array=0-9 + #SBATCH --time=24:00:00 + #SBATCH --mem-per-cpu=4G + #SBATCH --exclude=node01,node02 + export PATH_TO_POSYDON=/path/to/posydon + export PATH_TO_POSYDON_DATA=/path/to/data + srun python ./run_metallicity.py 1.0 + """) + slurm_script.write_text(slurm_content) + + original_dir = os.getcwd() + os.chdir(run_folder) + + try: + result = totest.create_batch_rescue_script(args, mock_batch_status) + + rescue_script = run_folder / "1e+00_Zsun_rescue.slurm" + content = rescue_script.read_text() + # Verify exclude was parsed from SLURM and passed to rescue script + assert "#SBATCH --exclude=node01,node02" in content + finally: + os.chdir(original_dir) + + def test_create_batch_rescue_script_with_max_concurrent_in_slurm( + self, tmp_path, mock_batch_status + ): + """Test create_batch_rescue_script parses % (max concurrent) from SLURM array.""" + run_folder = tmp_path / "run" + run_folder.mkdir() + + args = MagicMock() + args.run_folder = str(run_folder) + args.walltime = None + args.mem_per_cpu = None + args.partition = None + args.account = None + args.email = None + args.max_concurrent_jobs = None + args.exclude = None + + # SLURM script with %N in array (max concurrent jobs) + slurm_script = run_folder / "1e+00_Zsun_slurm_array.slurm" + slurm_content = textwrap.dedent("""\ + #!/bin/bash + #SBATCH --array=0-9%5 + #SBATCH --time=24:00:00 + #SBATCH --mem-per-cpu=4G + export PATH_TO_POSYDON=/path/to/posydon + export PATH_TO_POSYDON_DATA=/path/to/data + srun python ./run_metallicity.py 1.0 + """) + slurm_script.write_text(slurm_content) + + original_dir = os.getcwd() + os.chdir(run_folder) + + try: + result = totest.create_batch_rescue_script(args, mock_batch_status) + + rescue_script = run_folder / "1e+00_Zsun_rescue.slurm" + content = rescue_script.read_text() + # Verify max_concurrent_jobs was parsed from SLURM and included in rescue script + assert "%5" in content + finally: + os.chdir(original_dir) + + def test_create_batch_rescue_script_with_max_concurrent_arg( + self, tmp_path, mock_batch_status + ): + """Test create_batch_rescue_script with max_concurrent_jobs override from args.""" + run_folder = tmp_path / "run" + run_folder.mkdir() + + args = MagicMock() + args.run_folder = str(run_folder) + args.walltime = None + args.mem_per_cpu = None + args.partition = None + args.account = None + args.email = None + args.max_concurrent_jobs = 3 + args.exclude = None + + slurm_script = run_folder / "1e+00_Zsun_slurm_array.slurm" + slurm_content = textwrap.dedent("""\ + #!/bin/bash + #SBATCH --array=0-9 + #SBATCH --time=24:00:00 + #SBATCH --mem-per-cpu=4G + export PATH_TO_POSYDON=/path/to/posydon + export PATH_TO_POSYDON_DATA=/path/to/data + srun python ./run_metallicity.py 1.0 + """) + slurm_script.write_text(slurm_content) + + original_dir = os.getcwd() + os.chdir(run_folder) + + try: + result = totest.create_batch_rescue_script(args, mock_batch_status) + + rescue_script = run_folder / "1e+00_Zsun_rescue.slurm" + content = rescue_script.read_text() + # Verify max_concurrent_jobs from args is used in rescue script + assert "%3" in content + finally: + os.chdir(original_dir) + + def test_create_batch_rescue_script_with_exclude_arg( + self, tmp_path, mock_batch_status + ): + """Test create_batch_rescue_script with exclude override from args.""" + run_folder = tmp_path / "run" + run_folder.mkdir() + + args = MagicMock() + args.run_folder = str(run_folder) + args.walltime = None + args.mem_per_cpu = None + args.partition = None + args.account = None + args.email = None + args.max_concurrent_jobs = None + args.exclude = "badnode01" + + slurm_script = run_folder / "1e+00_Zsun_slurm_array.slurm" + slurm_content = textwrap.dedent("""\ + #!/bin/bash + #SBATCH --array=0-9 + #SBATCH --time=24:00:00 + #SBATCH --mem-per-cpu=4G + export PATH_TO_POSYDON=/path/to/posydon + export PATH_TO_POSYDON_DATA=/path/to/data + srun python ./run_metallicity.py 1.0 + """) + slurm_script.write_text(slurm_content) + + original_dir = os.getcwd() + os.chdir(run_folder) + + try: + result = totest.create_batch_rescue_script(args, mock_batch_status) + + rescue_script = run_folder / "1e+00_Zsun_rescue.slurm" + content = rescue_script.read_text() + # Verify exclude from args is used in rescue script + assert "#SBATCH --exclude=badnode01" in content + finally: + os.chdir(original_dir) diff --git a/posydon/unit_tests/_data/POSYDON_data/Patton+Sukhbold20/Kepler_sc_table.dat b/posydon/unit_tests/_data/POSYDON_data/Patton+Sukhbold20/Kepler_sc_table.dat new file mode 100644 index 0000000000..50e832da4d --- /dev/null +++ b/posydon/unit_tests/_data/POSYDON_data/Patton+Sukhbold20/Kepler_sc_table.dat @@ -0,0 +1,54 @@ +# Patton & Sukhbold (2020, MNRAS) +# +# s_c, the central entropy per baryon in units of K_B (Boltzmann constant), evaluated at the presupernova stage (needed for example in the Maltsev+2025 criterion) +# for KEPLER models of a given CO-core mass (columns) and initial carbon mass fraction (rows) +# +# All masses are in units of Msun +# +X_c 2.5 2.6 2.7 2.8 2.9 3.0 3.1 3.2 3.3 3.4 3.5 3.6 3.7 3.8 3.9 4.0 4.1 4.2 4.3 4.4 4.5 4.6 4.7 4.8 4.9 5.0 5.1 5.2 5.3 5.4 5.5 5.6 5.7 5.8 5.9 6.0 6.1 6.2 6.3 6.4 6.5 6.6 6.7 6.8 6.9 7.0 7.1 7.2 7.3 7.4 7.5 7.6 7.7 7.8 7.9 8.0 8.1 8.2 8.3 8.4 8.5 8.6 8.7 8.8 8.9 9.0 9.1 9.2 9.3 9.4 9.5 9.6 9.7 9.8 9.9 10.0 +0.05 0.610768691 0.629214872 0.652130574 0.679378435 0.713697783 0.74581048 0.766641043 0.792775501 0.814823111 0.836142529 0.859851158 0.886006531 0.912686043 0.932691598 0.957916148 0.969091329 0.99016173 1.010252072 1.028069634 1.051504705 1.075993752 1.089255614 1.111628723 1.119726359 1.14275482 1.161087743 1.18040405 1.197066773 1.216365632 1.214596008 1.229566027 1.235289616 1.238242492 1.24917912 1.245350981 1.220772176 1.21836697 1.210971843 1.187712915 1.171558872 0.91886142 0.86374049 0.865662993 0.870884007 0.661530736 0.673704051 0.678886681 0.688039343 0.696985998 0.702601761 0.711639237 0.721744358 0.731312803 0.734804539 0.744305905 0.754773775 0.76451156 0.77518963 0.78722218 0.795089722 0.806110319 0.815227124 0.825915333 0.835420584 0.845608485 0.860265491 0.868023291 0.875491152 0.88321075 0.890889423 0.89890877 0.905702072 0.918266901 0.925577573 0.93295046 0.937262003 +0.06 0.605510653 0.623915014 0.645836118 0.669321322 0.711392924 0.733214322 0.757603034 0.779210939 0.809051813 0.827016041 0.853383266 0.876592594 0.900028228 0.917641138 0.940458175 0.955522327 0.953821922 0.970036773 0.980775063 0.993971811 1.018976141 1.059055157 1.059484498 1.091019601 1.128510811 1.148597138 1.169798336 1.182814096 1.193163283 1.204827891 1.218106081 1.216802697 1.223924857 1.232778163 1.233200008 1.241112042 1.242004225 1.234769423 1.215049219 1.215025812 1.201259291 1.183634618 1.171101742 1.155320234 1.108348597 1.096416255 0.977886043 0.921079589 0.677173147 0.682344755 0.692446217 0.700938994 0.704854752 0.714058995 0.721762359 0.73015152 0.734381046 0.743366991 0.753205934 0.76265074 0.772110114 0.782041086 0.792809845 0.805655556 0.816805382 0.825751237 0.837665069 0.841273076 0.833332233 0.841220896 0.848222083 0.860692799 0.868035718 0.875329517 0.882247869 0.890472758 +0.15 0.728703373 0.741488304 0.743411687 0.766181023 0.613251415 0.654098815 0.706728262 0.710525257 0.645557443 0.792594011 0.782761187 0.727139362 0.714101691 0.710218882 0.688158065 0.697339026 0.764769326 0.81681644 0.891546699 0.999853931 1.00023813 1.014926909 1.022942369 1.009845222 1.010410148 1.013922446 1.015498105 1.014668779 1.015341347 1.019183579 1.027168877 1.028267606 1.023165366 1.002707222 0.996271088 0.999200651 1.000902452 0.990751311 0.981018618 0.961410751 0.959386139 0.960218557 0.996352029 0.949127941 0.967258548 0.946368996 0.952897947 0.946629509 0.92697324 0.958720712 0.946238579 0.940932286 0.947880852 0.933168069 0.989886803 0.994774084 1.013139995 1.00647957 1.018824226 1.041982809 1.056247591 1.054407715 1.03772044 1.046278875 1.047502012 1.04529673 1.044989069 1.054270488 1.055515471 1.061478717 1.074239667 1.074599382 1.080888843 1.082550062 1.085902628 1.086620885 +0.16 0.599633585 0.611961543 0.625996328 0.633822125 0.615217121 0.68483262 0.741640804 0.619520461 0.689847198 0.682656235 0.637590107 0.80817751 0.83878263 0.760413159 0.739164805 0.700305141 0.712754622 0.677125565 0.730130902 0.780742243 0.843040634 0.977177974 0.993821747 1.023210575 1.045642263 1.049752601 1.045145577 1.036772519 1.038826626 1.040209459 1.039261589 1.042671897 1.0474753 1.042759345 1.044256182 1.038306369 1.029397954 1.029626247 1.003515831 0.983398453 0.968150634 0.962061126 0.943464758 0.938685426 0.980288044 0.93565932 0.932571021 0.925900812 0.924422837 0.91899856 0.919840663 0.909938799 0.908222729 0.884624715 0.947441916 0.935440271 0.959698716 0.978727859 0.985858279 0.994486455 0.916248277 1.025814764 0.987339092 0.93338029 0.99845715 1.003853643 0.997222447 0.999187517 1.010113832 1.017798073 1.021035711 1.020547659 1.012971681 1.023295175 1.018392363 1.02134024 +0.17 0.587725025 0.612759 0.627221458 0.646763752 0.681311163 0.710167506 0.721270885 0.699205477 0.696677518 0.715640212 0.624307784 0.749942997 0.65038581 0.732325339 0.875772174 0.789570186 0.747227081 0.688299955 0.70436639 0.659003087 0.701630785 0.757770746 0.813137437 0.851168643 0.993620809 1.025745324 1.049771566 1.071080624 1.078912683 1.068253829 1.06231945 1.065611582 1.064376071 1.061958916 1.057772624 1.054217852 1.048283199 1.050983592 1.045242996 1.032177621 1.020595686 1.001124084 0.969305095 0.965945212 0.954382256 0.927264762 0.922587195 0.967355137 0.92057521 0.931032706 0.897565613 0.903844817 0.894341806 0.886701797 0.922287303 0.925330893 0.920176669 0.927682868 0.93509453 0.930533478 0.912689298 0.906929262 0.927766559 0.934492348 0.935035565 0.940913949 0.93153897 0.929876437 0.925307186 0.915233431 0.928474516 0.952779735 0.958731999 0.924558995 0.923206804 0.934589435 +0.18 0.714812461 0.692773794 0.613792284 0.630027041 0.649855255 0.672508811 0.701349874 0.728603143 0.724875224 0.742257278 0.811857801 0.770264626 0.656872767 0.699842862 0.646901719 0.764874888 0.913510515 0.89138468 0.792259055 0.730158077 0.689693931 0.639964685 0.67411517 0.727404807 0.770459075 0.821487178 0.863029641 1.059560295 1.035655298 1.076871488 1.101405818 1.1066164 1.101483525 1.087786831 1.093537301 1.081583575 1.071869165 1.0671787 1.062486977 1.051066843 1.042141229 1.049765835 1.028267987 1.026518791 1.000759463 0.991332417 0.963493146 0.940345205 0.947860654 0.903517802 0.9009087 0.914452068 0.883628099 0.876459656 0.88037106 0.868060338 0.875999065 0.894374463 0.901717925 0.893990531 0.901202933 0.865691672 0.860992171 0.855043814 0.857375624 0.862012287 0.870245295 0.864464627 0.869944295 0.868347402 0.871126038 0.875259993 0.877083873 0.880593155 0.884207058 0.890785358 +0.19 0.727233715 0.743851809 0.76354611 0.760753509 0.617168234 0.61794141 0.650806325 0.660348827 0.706062145 0.707048881 0.74496703 0.764045885 0.784318085 0.785524333 0.670872529 0.708017104 0.632780578 0.807783283 0.924338934 0.906389152 0.7833206 0.750296728 0.677112811 0.705436654 0.683518706 0.706021226 0.740130761 0.790800401 0.846049857 0.882190769 0.994632623 1.055943214 1.094053795 1.108537171 1.123517082 1.114169814 1.101315029 1.103504987 1.098985184 1.076504483 1.075215846 1.065854465 1.057549715 1.039080143 1.040481551 1.043411307 1.011084281 0.9946586 0.982224421 0.975246055 0.922378203 0.921905954 0.890951127 0.851420352 0.879906609 0.853852975 0.84432161 0.847062915 0.864089848 0.864962826 0.881890994 0.847187889 0.826110122 0.815214058 0.807663401 0.810570517 0.8103254 0.827159101 0.815621849 0.825291956 0.826375372 0.833898273 0.838969986 0.844355918 0.851350684 0.856840041 +0.2 0.725408675 0.665036581 0.660357254 0.72321987 0.676935978 0.607083394 0.737230097 0.627402677 0.642831359 0.638348081 0.650153976 0.696334225 0.697186155 0.715268179 0.730252938 0.815231584 0.812987491 0.794369068 0.637120215 0.66280771 0.862751357 0.924217324 0.919134428 0.74458322 0.688406117 0.699657403 0.638443575 0.66110335 0.718083155 0.754260088 0.797524686 0.833732443 0.943479358 1.037992689 1.084115839 1.136359028 1.13136764 1.131865793 1.115464271 1.120548625 1.110858932 1.100940163 1.08511353 1.074885521 1.072853982 1.063516903 1.053632589 1.035061315 1.040954895 1.016377375 0.987752848 0.963146375 0.951000966 0.899426588 0.885562717 0.858724861 0.838486966 0.83842367 0.82286416 0.825247071 0.836047606 0.816389719 0.803362609 0.750579898 0.770668116 0.771567862 0.775730331 0.782189531 0.764858975 0.753347005 0.785706695 0.810970859 0.815701883 0.825304187 0.830636744 0.833193849 +0.21 0.735126773 0.744628837 0.645749284 0.654571107 0.659779622 0.676540063 0.69309537 0.686906054 0.692523999 0.713091576 0.697731622 0.714895199 0.72855382 0.672382652 0.68018181 0.670527191 0.700748121 0.721272835 0.828538225 0.819732078 0.77269786 0.670726498 0.830413287 0.862263732 0.914938654 0.751235882 0.77395706 0.691449011 0.673080647 0.641296124 0.680259896 0.702148518 0.743657283 0.81745519 0.847451739 0.927289367 0.968268163 1.10271581 1.160478325 1.156932162 1.106469019 1.112365303 1.124255286 1.12093272 1.122682189 1.105946567 1.107622071 1.07231254 1.059866214 1.057039918 1.029399634 1.037075677 1.015201954 0.969069512 0.970802023 0.949553777 0.915996532 0.828294009 0.829624809 0.811956314 0.796417264 0.795673675 0.820751574 0.823976182 0.814909962 0.742422034 0.745382346 0.74915411 0.734188658 0.71282445 0.715909514 0.729040972 0.74339648 0.752799785 0.715557119 0.703158514 +0.22 0.701398866 0.711386416 0.726749635 0.675935837 0.692545086 0.613977848 0.653415306 0.693918532 0.698899073 0.70752421 0.698111522 0.631634111 0.630324906 0.633126497 0.635687214 0.643507131 0.646162186 0.654918893 0.754520876 0.694397342 0.709439984 0.746461911 0.813749933 0.627419073 0.803238027 0.768233454 0.946523224 0.924993001 0.829934978 0.692961439 0.683272168 0.622423238 0.659676979 0.693452996 0.694382079 0.720235104 0.811223938 0.826530647 0.850492336 0.973606635 1.098183492 1.170232748 1.178672548 1.120467503 1.107065636 1.115182021 1.118926609 1.099597647 1.121216982 1.11167808 1.08538358 1.077721256 1.051007097 1.045073783 1.02819266 1.012672673 0.965328456 0.96519974 0.939342945 0.841751651 0.837567496 0.802709817 0.799928885 0.799525233 0.802391004 0.797193407 0.72400731 0.709400902 0.71011076 0.68167114 0.690300333 0.705918457 0.696770802 0.695384452 0.706007033 0.709845331 +0.23 0.658451411 0.658211544 0.72008665 0.663961997 0.608773249 0.767288307 0.750697968 0.721680885 0.723213208 0.708280227 0.647791405 0.677654161 0.651894377 0.669186928 0.653146245 0.616553276 0.603232839 0.599819042 0.690478821 0.698952129 0.715817204 0.664047424 0.660235423 0.756301266 0.810886066 0.89325317 0.815302286 0.803072628 0.947570826 0.925897138 0.747721421 0.696804569 0.682187708 0.740896246 0.638532656 0.672612467 0.706223969 0.681518374 0.732676513 0.808662416 0.85616011 0.889549038 1.024724837 1.068332024 1.18598479 0.992473428 1.034451038 1.09827728 1.112214288 1.12912739 1.109564791 1.105677334 1.111320494 1.097623584 1.085324496 1.054961066 1.046582892 1.037505401 1.015485476 1.000589101 0.954740459 0.920999465 0.90826714 0.802312534 0.775737812 0.761084778 0.770015126 0.778706811 0.698187162 0.696664823 0.669757003 0.672539602 0.670864418 0.677396089 0.683695123 0.696839995 +0.24 0.704386171 0.708339231 0.703784872 0.676793868 0.703329384 0.702250478 0.716516172 0.699388092 0.701383125 0.610823047 0.645956098 0.728628556 0.678827012 0.647876043 0.646526662 0.639506743 0.662445226 0.652024836 0.713104442 0.699785106 0.675601315 0.670440501 0.60670898 0.635678117 0.592668839 0.667837807 0.709553617 0.781214981 0.67871891 0.738804204 0.817301997 0.954952508 0.807959618 0.691695008 0.689354743 0.594279444 0.6478417 0.695044656 0.733582623 0.712450562 0.75624058 0.804116146 0.813306826 0.867075469 0.906719598 0.947998439 1.103944855 1.171471978 0.944103579 0.991219219 1.022312779 1.138247345 1.120713222 1.122846631 1.120523415 1.104092606 1.090414107 1.092920718 1.072445562 1.043424124 1.041582503 1.009936191 1.019911548 0.937973455 0.903119754 0.889136115 0.792004575 0.758036277 0.740187954 0.758494189 0.754387651 0.682106069 0.666932574 0.667560261 0.670734494 0.673272902 +0.07 0.602024189 0.622541816 0.641925427 0.661061987 0.682199475 0.71608768 0.736130711 0.756423012 0.78220097 0.805954179 0.833453975 0.850680206 0.873652473 0.895387393 0.919000237 0.936650553 0.956749168 0.958428265 0.976431461 1.00027234 1.014072395 1.037385533 1.061917497 1.074569321 1.100016978 1.119481309 1.138185408 1.15516439 1.177634754 1.191928795 1.212483174 1.229613407 1.245076232 1.250715207 1.268432907 1.259737306 1.245551569 1.249007368 1.249760592 1.24559669 1.221035129 1.064478441 1.132249041 0.958213568 0.877971319 0.85894617 0.858601039 0.861637054 0.85968838 0.854605375 0.866231732 0.666231413 0.678038778 0.681786241 1.241645492 1.248384477 1.251460329 1.241026878 1.230966744 1.222895047 1.203731852 1.210402905 1.204245193 1.199608974 1.196047686 1.159717642 1.144563403 1.150349404 1.135195411 1.14361162 1.145764216 1.146450762 1.144878776 1.141819996 0.848542549 0.848886815 +0.25 0.716741845 0.71652531 0.730556964 0.732498272 0.692143561 0.694080449 0.69078604 0.69228101 0.699390694 0.703209284 0.72939904 0.695670576 0.681784049 0.711635967 0.723680634 0.682729749 0.730438536 0.731000471 0.724743182 0.736726756 0.733173765 0.746577822 0.701202452 0.705390591 0.656166652 0.720975792 0.712413235 0.668007289 0.680940446 0.674063711 0.727474074 0.753429634 0.789417914 0.991797448 0.854183853 0.748466602 0.669532971 0.673957414 0.610258737 0.666440811 0.700366833 0.744776875 0.738482906 0.785075386 0.852227779 0.85522408 0.902952106 1.000174571 1.073474871 0.862423364 0.918155396 0.97555319 0.976700043 1.02506644 1.050052472 1.149103718 1.128975822 1.125965798 1.120895941 1.116517754 1.105770505 1.099405256 1.069334411 1.033345557 1.011322467 1.018202499 1.000461303 0.914618633 0.875836395 0.851416431 0.727269031 0.734557216 0.750873739 0.787126288 0.737978366 0.714128365 +0.26 0.704446971 0.721169831 0.712893113 0.70820766 0.679560191 0.704907958 0.715057514 0.6918948 0.701139001 0.691232777 0.709673416 0.709262094 0.68306829 0.684419024 0.68679468 0.697990159 0.707121228 0.683780982 0.711374617 0.71063618 0.695620696 0.606977989 0.622181006 0.619561259 0.621536638 0.622779289 0.771995804 0.78320893 0.707279074 0.662875209 0.720812834 0.685710167 0.589247753 0.713615703 0.818133775 0.792878401 0.981497744 0.772903203 0.730205992 0.694643687 0.677360268 0.677080113 0.684114393 0.664199935 0.73483551 0.768444583 0.787001033 0.857254563 0.849327118 0.951819288 0.932847798 1.084784336 1.148996543 0.890544193 0.923256639 0.900231496 0.999316417 1.055970225 1.132200398 1.112469449 1.124044788 1.105739561 1.109378055 1.108443987 1.10150171 1.038735213 1.006590658 1.002410311 1.017222984 1.006555199 0.982159368 0.854080981 0.847835274 0.818685493 0.741078607 0.778403357 +0.27 0.688226045 0.701539399 0.712657602 0.661800295 0.670031133 0.691022283 0.712173324 0.692649489 0.710674744 0.698398626 0.688772972 0.713615637 0.72435416 0.729662961 0.701634315 0.646965586 0.71565246 0.721869439 0.715653875 0.727152061 0.714700927 0.717324645 0.665248254 0.639488288 0.636178219 0.644493515 0.64601868 0.648726091 0.644032369 0.640182839 0.640841862 0.817062924 0.776765747 0.702299277 0.726177613 0.63312702 0.608764391 0.7215841 0.981807613 0.898267736 0.742523874 0.724872391 0.650384957 0.702511296 0.629969494 0.702341296 0.739845968 0.741604625 0.784942608 0.802600102 0.870100039 0.817238257 0.89899934 1.016697398 1.035264925 1.149570449 0.877201696 0.934245767 0.911914938 0.991560878 1.001912353 1.051156303 1.133744782 1.120559077 1.119935266 1.119528917 1.107158387 1.096679434 1.052955993 1.04359871 1.013702736 1.005226334 0.986061548 0.981708933 0.7869379 0.765596838 +0.28 0.672400487 0.699112825 0.685701712 0.669767785 0.666571176 0.677358191 0.625984675 0.63773839 0.640121294 0.626375525 0.618917875 0.623809699 0.662786939 0.62531922 0.636898952 0.623660535 0.63247957 0.635609554 0.655551349 0.646794358 0.655720321 0.686448527 0.671620193 0.674195655 0.681607371 0.675076191 0.695447545 0.692585797 0.686230035 0.692869765 0.69023029 0.665855773 0.662909085 0.671751978 0.677744297 0.657616546 0.800889951 0.807214346 0.713211034 0.629941479 0.788112764 0.969053516 0.79111148 0.75100382 0.717263475 0.671586018 0.615275618 0.643065338 0.729574481 0.663084923 0.766393673 0.807956292 0.822326966 0.901237829 0.830933972 0.890279973 0.931649158 1.091177171 1.091282775 1.115811277 0.900489272 0.978653314 0.977942181 1.009150729 1.031709844 1.026037788 1.049355651 1.138773643 1.124070638 1.110063079 1.092999279 1.090880674 1.058046955 1.015677503 0.985555346 0.980169381 +0.29 0.665453455 0.680790969 0.706939162 0.594713365 0.673527984 0.645034799 0.700104705 0.633307581 0.588352435 0.613282962 0.642647562 0.618123238 0.623913884 0.625743984 0.624161598 0.631020458 0.628384829 0.652071695 0.64106055 0.669700957 0.657876721 0.724868778 0.670861482 0.72798338 0.66871124 0.672119399 0.678860744 0.67551983 0.7724433 0.708646697 0.712679666 0.709883458 0.711452031 0.706935443 0.740712788 0.745365346 0.692741748 0.739111822 0.68351506 0.690338961 0.862242494 0.825828372 0.757635129 0.798618098 0.999407334 0.740818563 0.729327934 0.586937662 0.726740225 0.624854326 0.655861258 0.747535055 0.73605682 0.765121481 0.827010749 0.84183159 0.867000758 0.850370762 0.902490409 0.931232089 1.024361918 1.072452066 1.164462837 0.895891475 0.903569584 0.907620365 0.923169243 0.978376152 1.013577895 1.018344139 1.078516969 1.149707585 1.11546471 1.1064485 1.096717667 1.092731707 +0.3 0.601742495 0.603479351 0.625129097 0.684184559 0.636908287 0.696402014 0.69426112 0.629387991 0.710284152 0.654474783 0.723289138 0.597115699 0.624999849 0.730114118 0.65000875 0.661914733 0.706213782 0.698087305 0.667790691 0.752302519 0.657779234 0.656482567 0.758328733 0.765503764 0.677809254 0.778296252 0.782668806 0.799847207 0.799127183 0.808777483 0.696221926 0.849511672 0.745400241 0.787214935 0.792843381 0.822588308 0.716610653 0.727747945 0.767964028 0.77352834 0.769325662 0.775351453 0.712933338 0.716076188 0.863077702 0.907653407 0.899411478 0.707095793 0.720794209 0.578051347 0.64708457 0.715991872 0.644705111 0.663647839 0.693140744 0.759587089 0.792349194 0.830673659 0.852019003 0.870533958 0.848837912 0.880728817 0.925043699 0.982958649 1.053823489 1.150515811 0.897477284 0.890299938 0.916569743 0.924511122 0.958726464 1.00533473 1.004704377 1.047265598 1.066599926 1.167767699 +0.31 0.691325178 0.688663832 0.628669891 0.65491838 0.663987843 0.751553439 0.688352662 0.702859954 0.729196245 0.71441199 0.699700676 0.740500318 0.751189599 0.722165374 0.712277933 0.755404786 0.657165038 0.675756012 0.688696872 0.710015829 0.741099345 0.726088897 0.738567843 0.741869568 0.735708132 0.752197278 0.721989435 0.773039511 0.790930818 0.798099672 0.772176986 0.763454064 0.814542629 0.788119697 0.800648491 0.86548908 0.900786694 0.915652909 0.885149524 0.916602087 0.743414226 0.724258559 0.776407177 0.769619242 0.783242734 0.785181673 0.778085288 0.745941384 0.931431459 0.776349218 0.6841201 0.6477042 0.707736854 0.655150727 0.655398181 0.669597053 0.684718985 0.699750704 0.707376361 0.789209972 0.835426534 0.866383149 0.932554222 0.86486251 0.920933216 0.924261736 1.002879887 1.029418745 1.059004797 1.177111327 0.883100827 0.904485439 0.907080739 0.950152868 0.98376019 0.988611552 +0.32 0.614205464 0.707685564 0.695794145 0.62714604 0.66354386 0.66760916 0.669178508 0.712956905 0.67331129 0.71034524 0.741519961 0.725952266 0.747104891 0.750470076 0.758894629 0.771347977 0.778942132 0.780919574 0.790540829 0.668344354 0.736686777 0.692279822 0.755144229 0.749198723 0.728248001 0.733540697 0.721989109 0.72605721 0.724460508 0.731752048 0.773062974 0.764405706 0.709526184 0.697809301 0.775846911 0.775440096 0.763850399 0.731742253 0.850527843 0.846444996 0.904891414 0.822097352 0.807234695 0.927121489 0.956917325 0.941331298 0.871427455 0.942177727 0.786615297 0.762742709 0.803596724 0.785857527 0.781854357 0.719171507 0.629905105 0.708443564 0.65697022 0.708995178 0.718098798 0.78044221 0.781739715 0.8154921 0.868462144 0.870102287 0.870338506 0.878094373 0.886983713 0.878299735 0.89785296 1.005292509 1.002907939 1.021172004 1.16909946 0.8985822 0.912158462 0.946479683 +0.33 0.618420958 0.627870203 0.653063536 0.675510681 0.654044285 0.6865992 0.657468191 0.618668143 0.690746854 0.701296762 0.724371023 0.677683994 0.721116581 0.733218909 0.746587992 0.728400247 0.733639341 0.744518313 0.646803952 0.765147657 0.773164938 0.766223764 0.757983877 0.742250963 0.794378482 0.754158674 0.697215021 0.684813648 0.751123424 0.755727298 0.696574592 0.682785986 0.694155016 0.688118193 0.716716827 0.765746155 0.754270772 0.738291004 0.719440101 0.770212775 0.757561602 0.745670819 0.765796057 0.766801745 0.749333244 0.674019385 0.877433771 0.73322129 0.746184313 0.899737179 0.940869412 0.950801596 0.806305219 0.796186736 0.786324364 0.763002814 0.92241866 0.612922368 0.674136589 0.616622786 0.717908227 0.73568542 0.778011254 0.789131949 0.856289482 0.874952745 0.87987338 0.836837514 0.856067715 0.897723723 0.923809497 0.984182292 0.937990451 0.986306689 1.051171448 1.093657488 +0.34 0.623976778 0.62587934 0.61653257 0.688538716 0.644632417 0.642128912 0.683704788 0.682506137 0.610627814 0.694173297 0.666782693 0.622050799 0.604578851 0.716472335 0.746469824 0.638293309 0.645003997 0.644799494 0.648254115 0.567809208 0.685201096 0.525021296 0.642873748 0.650244212 0.657518189 0.622598744 0.671567179 0.529157179 0.62072265 0.678586252 0.72577096 0.674728141 0.631530444 0.617052654 0.620691192 0.687220199 0.630240305 0.630452424 0.641107476 0.676193963 0.638672301 0.612295257 0.636129398 0.62392038 0.692508496 0.627869713 0.642512681 0.652907644 0.593657452 0.71965915 0.745270716 0.736428925 0.750084732 0.716931971 0.738495415 0.922615601 0.943732058 0.820328149 0.823466355 0.83462419 0.908485964 0.625104623 0.713549408 0.722603766 0.742024454 0.767882723 0.779294859 0.812431058 0.892380516 0.88509637 0.815386939 0.881032787 0.997350749 0.964831098 1.013619774 1.068314726 +0.08 0.59946483 0.62466518 0.64115396 0.666850459 0.696117432 0.717627097 0.74087016 0.760716683 0.781089846 0.804275977 0.824570017 0.846120295 0.864545464 0.885557237 0.909352477 0.929262528 0.947116448 0.96286081 0.967209603 0.985834674 1.003689012 1.017730036 1.032072324 1.050534615 1.066174999 1.075637777 1.101602666 1.117614083 1.134563868 1.147788913 1.159258849 1.177976364 1.196524108 1.224050591 1.241773981 1.2656591 1.261187443 1.281158465 1.285816865 1.205875674 1.290542215 1.308435172 1.322762129 1.318352314 1.299106093 1.274909095 0.989418035 0.932090919 0.853466599 0.846094087 0.838995432 0.872481261 0.889740226 0.980815859 1.300537907 1.238393931 1.239621232 1.239771997 1.237183433 1.236921163 1.232863391 1.234718774 1.228788064 1.221935878 1.223619154 1.207667074 1.193440516 1.201778314 1.200433475 1.198735584 1.192203771 1.174757501 1.146414208 1.128386521 1.130287695 1.140840156 +0.35 0.586546748 0.607600414 0.607399902 0.669907955 0.684005414 0.64703291 0.649272715 0.630045833 0.601904461 0.729417362 0.618297752 0.515615925 0.580925389 0.61044274 0.616127515 0.620558943 0.572541324 0.632188649 0.559202322 0.600559192 0.586677533 0.569775981 0.568757384 0.576966256 0.520147011 0.594659399 0.599512181 0.605220512 0.57553778 0.605152804 0.615698444 0.619697916 0.612602067 0.600065335 0.603638 0.56298872 0.604585971 0.611899561 0.611178798 0.647747727 0.598559372 0.633526481 0.599600157 0.656361564 0.651926498 0.664686262 0.60981849 0.617762614 0.61341751 0.555815623 0.615874037 0.65835662 0.64653035 0.65905679 0.680581211 0.690945613 0.680076286 0.682605054 0.702400325 0.772802111 0.929631961 0.976024507 0.960294838 0.975519411 0.749963603 0.66803794 0.676488555 0.747025381 0.76863084 0.77448973 0.812510063 0.863798763 0.88958003 0.85217883 0.871008384 0.926671707 +0.36 0.582105319 0.620203525 0.595675597 0.615235919 0.676602818 0.672769088 0.596384146 0.576457871 0.580546927 0.582846099 0.552126541 0.559349196 0.548369111 0.5261487 0.599726462 0.589684705 0.527610786 0.569816597 0.575722989 0.570861575 0.574357161 0.570810207 0.549642555 0.59310751 0.575570652 0.59335692 0.584586369 0.598704264 0.615098178 0.60313227 0.618614596 0.569811241 0.618544278 0.600811 0.62189145 0.571429572 0.639655083 0.611037605 0.62484862 0.606930767 0.612233416 0.615039493 0.621304805 0.627376886 0.625727875 0.630269726 0.625080346 0.634116147 0.628572679 0.636574382 0.644269025 0.63270565 0.632616816 0.639173439 0.638830095 0.633455233 0.634583351 0.619583561 0.628239559 0.668302574 0.611548543 0.671348449 0.718132776 0.697446449 0.657728717 0.958350298 0.920179034 0.735842592 0.614291609 0.689014849 0.74422803 0.767892257 0.762852671 0.769424911 0.828432998 0.93618781 +0.37 0.595585147 0.636937671 0.657249378 0.679226954 0.667798139 0.655206629 0.668939678 0.655475066 0.671308595 0.614524282 0.535319893 0.55504848 0.683033358 0.539833072 0.55870436 0.562571851 0.565393616 0.536482482 0.564631318 0.564938191 0.59469886 0.577244259 0.570163352 0.569247941 0.543088035 0.583966998 0.602551167 0.554973551 0.550199982 0.604915511 0.629014402 0.613003129 0.605794144 0.62393946 0.618766164 0.616079625 0.60759177 0.626396291 0.65414182 0.629547263 0.64062036 0.643858075 0.652604389 0.64197826 0.638830882 0.631812979 0.659330719 0.656659297 0.645210329 0.669227771 0.655328236 0.647500022 0.656614308 0.659339843 0.664816609 0.65145111 0.662379913 0.659529067 0.666568386 0.649028098 0.646494806 0.656478324 0.644835012 0.642713687 0.619904512 0.62572982 0.641212543 0.655696138 0.653825495 0.748785114 0.665796842 0.660176664 0.631278907 0.70862235 0.763453757 0.754959809 +0.38 0.592833011 0.704818609 0.594451338 0.658198188 0.668121997 0.591102385 0.666120909 0.657161686 0.615719669 0.688786141 0.667952525 0.707683138 0.544112818 0.683055676 0.582407709 0.543954131 0.541405034 0.570448258 0.540448955 0.579295201 0.548562967 0.499830947 0.559108226 0.554845749 0.593442613 0.60473461 0.603224553 0.607367833 0.567599805 0.620938273 0.618081364 0.624898044 0.620198135 0.623372126 0.619234557 0.651889823 0.629775501 0.623686401 0.653599589 0.651741813 0.664812911 0.655072947 0.676042183 0.662819006 0.641532248 0.654926284 0.67356215 0.679385573 0.682306518 0.676609435 0.691135016 0.687256563 0.686095296 0.688552052 0.694754543 0.691794814 0.683760299 0.691281868 0.705116343 0.697870762 0.692574116 0.685269471 0.690198214 0.670911463 0.672783196 0.679814919 0.685106172 0.665107309 0.660139545 0.661559698 0.633551932 0.636501406 0.611955303 0.616952102 0.600244953 0.62973723 +0.39 0.624291332 0.63520331 0.578243162 0.588527991 0.600472857 0.601013841 0.601208169 0.601262058 0.609480317 0.69054831 0.656203602 0.701464526 0.718004324 0.55820416 0.718117702 0.695073883 0.554740013 0.561492885 0.555177719 0.582089784 0.60732467 0.549873952 0.604656355 0.561358873 0.585205693 0.570703292 0.620189898 0.626415045 0.617631104 0.606474154 0.626300549 0.641229123 0.646901938 0.654754546 0.653074986 0.658729884 0.65782185 0.654535819 0.687855468 0.66593057 0.672567539 0.667623129 0.678930234 0.679406631 0.69220033 0.684917985 0.737824713 0.701119696 0.760717056 0.760219946 0.744243189 0.728124662 0.725514724 0.759910968 0.624294121 0.737241282 0.708364815 0.620809946 0.748937777 0.713320022 0.717672944 0.627949035 0.712905968 0.626627608 0.620764261 0.710674989 0.64517188 0.71513794 0.61342383 0.710509047 0.708874684 0.697725623 0.689418015 0.684364332 0.678473755 0.669688182 +0.4 0.617765296 0.595676994 0.606912037 0.67214894 0.592497339 0.667369494 0.643868999 0.59998374 0.673551406 0.689739178 0.664077633 0.687559775 0.696867061 0.697273763 0.681616294 0.659918141 0.653705494 0.659734005 0.57599348 0.538519658 0.596516224 0.613670139 0.568928627 0.632861642 0.615991865 0.647841019 0.611644164 0.621488961 0.653212381 0.652318303 0.656461107 0.653763842 0.64365145 0.648935696 0.660257686 0.671203667 0.670128164 0.677656963 0.678827035 0.667452846 0.683359575 0.678396317 0.693144951 0.702615992 0.703784185 0.665034304 0.729831262 0.741024168 0.624720043 0.637620309 0.623404555 0.734424222 0.654081735 0.658170985 0.636684359 0.630465348 0.653744027 0.64278867 0.649967738 0.658764968 0.681495753 0.674167519 0.666158069 0.664750747 0.656104729 0.648324144 0.675399808 0.651204948 0.683961375 0.631396311 0.645780133 0.625066493 0.624758761 0.621272983 0.751337806 0.773827184 +0.48 0.564871885 0.618157026 0.56983434 0.620567824 0.593076248 0.62667109 0.616667607 0.578653642 0.626996906 0.569953012 0.598133422 0.541354529 0.525511475 0.59338106 0.531979281 0.572112955 0.569999358 0.563659317 0.578363514 0.711357457 0.571104433 0.570808825 0.571212983 0.564056038 0.547238396 0.54173214 0.565230028 0.53993908 0.585180219 0.57365615 0.545622746 0.58699044 0.547747906 0.535047603 0.57845033 0.633419506 0.616053038 0.646312429 0.644428056 0.65497233 0.635782687 0.662241118 0.675308138 0.66822929 0.687769251 0.592535892 0.699682213 0.590381302 0.59477312 0.593390421 0.580520358 0.598022044 0.592623865 0.586044036 0.584743662 0.590352781 0.585628542 0.586064491 0.616298276 0.607514771 0.769014329 0.771914569 0.64082694 0.775994775 0.75132219 0.652609621 0.68170686 0.674776472 0.6851653 0.686685581 0.705217064 0.712118938 0.75410904 0.766109716 0.758109419 0.78672955 +0.49 0.558841673 0.654356753 0.612012669 0.614411561 0.60257284 0.600968786 0.651718826 0.556500512 0.56282105 0.565140434 0.559060181 0.532172237 0.54612094 0.571170876 0.565051531 0.571714805 0.570050173 0.584659598 0.569861418 0.690101768 0.576516388 0.603658865 0.564024481 0.577077442 0.550960302 0.530642723 0.542788292 0.588029293 0.579490949 0.571823482 0.58723218 0.594129322 0.543524808 0.550644479 0.624449113 0.615045467 0.626085205 0.626916069 0.650367095 0.665218934 0.673521462 0.677132101 0.681014716 0.690095369 0.692805358 0.592479582 0.579041264 0.713596152 0.586624196 0.595784275 0.591513299 0.599595623 0.600664767 0.588217105 0.599090703 0.599844626 0.586037322 0.596935039 0.592559586 0.598022731 0.584413348 0.615564674 0.772309413 0.597564227 0.608722883 0.641022628 0.662570487 0.662232783 0.671345521 0.676524351 0.699929379 0.695369357 0.706922255 0.726102911 0.768779755 0.769635352 +0.5 0.562448457 0.565796802 0.631249643 0.652590714 0.575723906 0.648310924 0.56764643 0.589269263 0.570321931 0.566502361 0.591499302 0.547807542 0.567977957 0.570014744 0.566086688 0.536914794 0.579831688 0.56536537 0.579848915 0.575226731 0.579078688 0.574010385 0.592322334 0.597860223 0.522078987 0.524415889 0.502816514 0.542240105 0.545402687 0.576039545 0.580050983 0.549264545 0.586976788 0.61038385 0.620770262 0.596487721 0.634297241 0.599208397 0.653241749 0.667333959 0.67156429 0.687863394 0.689708141 0.6964109 0.700944058 0.59934155 0.595692533 0.603148577 0.608511337 0.612046395 0.610929704 0.596507246 0.698906841 0.60708681 0.600811445 0.611439312 0.60043717 0.602879905 0.602631421 0.615781052 0.599357877 0.597549144 0.609136407 0.605334934 0.619037922 0.604843015 0.785101786 0.617582175 0.78103712 0.641139155 0.6751528 0.706328072 0.703973487 0.703331817 0.705185906 0.736870346 +0.1 0.747092716 0.590847925 0.617432157 0.657462195 0.68658337 0.725229864 0.756821315 0.768952733 0.789060772 0.80521337 0.825468624 0.850788403 0.856828779 0.870514687 0.883968747 0.893920331 0.911877832 0.924943001 0.937657385 0.943427794 0.958872384 0.968817179 0.965429239 0.977260803 0.992368737 1.008846438 1.02128706 1.032905711 1.045226533 1.055484209 1.06515566 1.077873233 1.093423618 1.103091622 1.111348004 1.116626874 1.122247727 1.129664595 1.131873568 1.139901239 1.144000585 1.145961992 1.15794172 1.16021917 1.158915203 1.169100639 1.172816557 1.172687935 1.174158421 1.166891953 1.179779213 1.17676629 1.182542691 1.185568304 1.188758311 1.186229544 1.176815171 1.183793386 1.188206482 1.195207909 1.199190148 1.201286516 1.214652216 1.215348441 1.225417391 1.2184015 1.206439116 1.208353253 1.19380644 1.202197431 1.221026635 1.227981611 1.240972821 1.247993435 1.237654485 1.228281837 +0.11 0.718764613 0.599474914 0.607216525 0.606080287 0.631152995 0.676770738 0.715614175 0.756281421 0.795879915 0.8243039 0.840519675 0.854861125 0.861965524 0.877027369 0.890756804 0.900058734 0.91464235 0.924962402 0.935690213 0.94651692 0.954841167 0.950626392 0.954611757 0.956455688 0.971533839 0.976819977 0.985754828 0.997951157 1.005175121 1.01429147 1.027853457 1.041420673 1.05925269 1.066334532 1.076425272 1.08302171 1.085970885 1.095703574 1.096879828 1.099218389 1.106741546 1.107545456 1.111946468 1.115100427 1.115782733 1.117643172 1.119719213 1.12024187 1.118694324 1.112449979 1.113851184 1.108425691 1.104528295 1.103950625 1.114296804 1.118102214 1.134062661 1.14281573 1.152516804 1.163241831 1.168180537 1.176055894 1.188337869 1.181959214 1.196372523 1.202278461 1.209735882 1.213729939 1.216416917 1.202326073 1.232637318 1.218147558 1.214404399 1.223491509 1.232152291 1.234304834 +0.12 0.727328733 0.744119883 0.63537271 0.635815124 0.633347568 0.659543561 0.632300628 0.674820788 0.733220163 0.804302957 0.842945517 0.873393985 0.891406056 0.903370788 0.913936373 0.914229135 0.91604933 0.9319282 0.943329425 0.95196064 0.953495304 0.956409769 0.951290806 0.965808639 0.961864576 0.965280402 0.970442374 0.980372211 0.985688901 0.986862207 0.993814095 1.008081102 1.021677842 1.028463531 1.037675876 1.045662078 1.049197665 1.059389783 1.048106322 1.058308584 1.061568929 1.061050447 1.062528505 1.065663091 1.064015278 1.062968704 1.071434468 1.063838609 1.064998479 1.056926547 1.058183976 1.052410469 1.050222779 1.061876479 1.079888699 1.093549113 1.103455055 1.11279069 1.117263954 1.125671104 1.133213531 1.132136583 1.154296184 1.141415901 1.154939426 1.161661833 1.167010806 1.196724488 1.187802313 1.200160767 1.21104686 1.211451798 1.213185234 1.222367216 1.229947005 1.227428727 +0.13 0.601867127 0.586773483 0.600291418 0.635146605 0.645121482 0.707495856 0.684254478 0.668184034 0.689693157 0.664674857 0.700340081 0.748996 0.828114863 0.888523246 0.926079973 0.949329116 0.958958635 0.944314766 0.952998876 0.962033423 0.962938722 0.959762199 0.96987308 0.961342607 0.973026501 0.974167819 0.977164689 0.983567328 0.977104285 0.981981924 0.974034967 0.989786089 0.990960553 0.995518647 0.996114567 1.00520115 1.010702681 1.007901328 1.015320633 1.015779304 1.021731042 1.017163301 1.021308647 1.024127036 1.022789429 1.015187487 1.014099054 1.021990534 1.002970087 1.000630188 0.990243921 0.997505765 1.012796707 1.035735923 1.051331727 1.072018285 1.083695963 1.08233286 1.085689989 1.098967014 1.091008424 1.098238325 1.126863425 1.134443507 1.130398957 1.145033687 1.152506408 1.156389894 1.160808781 1.167730219 1.171753802 1.176697302 1.184025482 1.190697257 1.193916506 1.203540054 +0.14 0.748886478 0.730068006 0.616407581 0.648475949 0.646272363 0.677748387 0.762755116 0.758991138 0.718126682 0.710406983 0.686522903 0.710485332 0.685580497 0.727819672 0.770873433 0.867372954 0.937301444 0.961584709 0.976990251 0.987841408 0.983628603 0.984615649 0.992312554 0.994464226 0.997010879 0.995940726 1.000981949 1.001100674 0.982449543 0.985647686 0.987429418 0.997894844 1.011496173 0.994963623 0.978200394 0.973162588 0.981309013 0.969742114 0.974289248 0.98953026 0.994711795 0.989358524 0.994100931 0.964802639 0.99856263 0.995508674 0.991980889 0.976675994 0.958807549 0.953772728 0.942155197 0.950045416 0.968090641 0.976174771 1.022999845 1.043522635 1.036042729 1.03888494 1.041200928 1.057821976 1.084040967 1.091453453 1.098481501 1.091168945 1.097286241 1.101914215 1.105847319 1.10876122 1.113515319 1.119597526 1.124727437 1.12849184 1.132505127 1.137228314 1.140367502 1.141837336 +0.45 0.607336366 0.596646638 0.623324487 0.633359331 0.597850415 0.575882228 0.628021282 0.594751274 0.687516034 0.58636833 0.557394492 0.656069409 0.616621563 0.593168102 0.571953128 0.582513912 0.576053971 0.592576076 0.571620533 0.589199281 0.560479749 0.556940451 0.566455539 0.570328001 0.600213825 0.560885072 0.560768733 0.50882076 0.555416408 0.537861252 0.532894881 0.606344007 0.610878136 0.538564205 0.538723307 0.625178525 0.604666708 0.717882192 0.655003663 0.646592943 0.655158454 0.644452545 0.643259285 0.692133312 0.686971344 0.724577433 0.683139874 0.708940322 0.7690132 0.761211383 0.782134994 0.76612838 0.801533421 0.796045685 0.766638276 0.783916686 0.811115154 0.768788868 0.798619042 0.790371328 0.856350716 0.82351104 0.813627623 0.82424673 0.847097865 0.78766966 0.842094283 0.821012753 0.830837675 0.833104706 0.83826088 0.843812904 0.843232431 0.819840247 0.848627137 0.847493043 +0.46 0.612834094 0.615009195 0.630159514 0.583494408 0.579705563 0.587836924 0.584407859 0.593186692 0.653949366 0.647290085 0.569951248 0.571935303 0.568603413 0.593405221 0.527210018 0.567514475 0.574859912 0.582129032 0.564478528 0.56651943 0.578976364 0.586762486 0.549138771 0.614784214 0.564091051 0.555347638 0.512814897 0.566090909 0.540165987 0.537084089 0.544036632 0.543089471 0.53699811 0.600372241 0.608430944 0.633266692 0.608028314 0.65279974 0.645915645 0.633515183 0.65371355 0.654984623 0.662452097 0.680359237 0.682239812 0.588564111 0.718034969 0.594851301 0.585631418 0.695960461 0.696329503 0.780611983 0.599402453 0.734480927 0.797497767 0.609197701 0.851274106 0.828247878 0.79480671 0.846858599 0.851736442 0.884357518 0.745389717 0.685087721 0.667016056 0.822355792 0.864852498 0.901420734 0.830529746 0.950221392 0.898121308 0.869012859 0.935410484 0.879375009 0.881217336 0.885822834 +0.47 0.610298412 0.567974749 0.621728343 0.572176636 0.585935319 0.653759762 0.621083958 0.620812063 0.627082648 0.580066755 0.591531267 0.573102491 0.56874627 0.571112249 0.599337768 0.584192293 0.569085595 0.575663315 0.561855213 0.710263837 0.710851294 0.568877362 0.572090474 0.571443877 0.53223315 0.55437323 0.571911401 0.513109736 0.512072007 0.552348019 0.551398431 0.596140358 0.550251065 0.547603475 0.588527219 0.621597279 0.601051432 0.611600325 0.638869788 0.613537192 0.650396954 0.678336368 0.656103995 0.686207091 0.687841381 0.699109198 0.692686109 0.594178511 0.587822786 0.588240534 0.592290305 0.702290464 0.593046809 0.582072283 0.601803236 0.594065344 0.594705269 0.72603775 0.593915522 0.761696182 0.730290642 0.647770668 0.647858698 0.899885802 0.673907954 0.831964284 0.69483633 0.695154513 0.735000361 0.713838156 0.757331398 0.741037896 0.737003271 0.942975821 0.73229815 0.946013172 +0.42 0.60073815 0.626327768 0.61813395 0.574251234 0.580484144 0.685181134 0.582386464 0.644761016 0.590398177 0.594626572 0.709430651 0.681248513 0.709472343 0.706530202 0.578535759 0.573522119 0.563868564 0.572295752 0.582054736 0.620825486 0.642583142 0.648980006 0.656866793 0.651911177 0.58168907 0.564268089 0.660487807 0.673402173 0.66044355 0.669134937 0.660481677 0.734031654 0.726770451 0.692529089 0.749037481 0.745947944 0.663558491 0.738407542 0.622693142 0.628597232 0.737291338 0.64497676 0.65456268 0.652630791 0.661740611 0.649077517 0.671132223 0.659783562 0.662387181 0.694065635 0.718406339 0.696965382 0.712613526 0.669660214 0.728182562 0.72067191 0.685976444 0.698067333 0.701362823 0.763910845 0.741530624 0.740983602 0.751575781 0.813831228 0.76931737 0.718780727 0.792156343 0.777750759 0.791762693 0.802952022 0.777494449 0.716257194 0.768248485 0.751861294 0.770082398 0.7711934 +0.43 0.595145706 0.60302457 0.619972538 0.646573564 0.577274758 0.581718241 0.588003732 0.656963502 0.693678591 0.640511035 0.695307031 0.60556636 0.568278648 0.697686627 0.556110828 0.571112008 0.579015978 0.579572861 0.588304509 0.549694135 0.643985945 0.570168592 0.564929586 0.57461237 0.5783386 0.667851978 0.563424244 0.560546132 0.662583923 0.677812457 0.571818449 0.642311224 0.510846383 0.611918455 0.614382448 0.61237248 0.607253359 0.621940123 0.646120977 0.623185082 0.607718376 0.639085034 0.654178767 0.649767542 0.682509912 0.680252279 0.730491782 0.676925199 0.722503086 0.726713362 0.721660788 0.751365296 0.749646418 0.787935786 0.788590938 0.766080843 0.752596013 0.782413938 0.792935968 0.809031322 0.796061434 0.814749252 0.805056675 0.80295323 0.810457668 0.796093616 0.812755194 0.803348552 0.75736775 0.810526153 0.803154665 0.806174223 0.764378794 0.805914234 0.80504915 0.80866908 +0.44 0.591431527 0.607661658 0.624825793 0.63701503 0.578334532 0.583249042 0.580143814 0.588152054 0.596243466 0.603155642 0.600711827 0.692051063 0.652601714 0.566131675 0.583052016 0.579646836 0.565504867 0.53172174 0.548281149 0.588025623 0.558883359 0.575720146 0.570710875 0.583884708 0.573714315 0.567342573 0.557573953 0.582681252 0.554120876 0.497941403 0.530849822 0.531323163 0.616521952 0.589070867 0.587647715 0.596157587 0.638708344 0.637172155 0.674264342 0.663362381 0.63765902 0.662241823 0.63234751 0.661836422 0.662733293 0.680349843 0.684638602 0.698176532 0.719272081 0.750880503 0.71800663 0.754956366 0.780647839 0.787337362 0.774210961 0.759398115 0.788684801 0.761406448 0.799091498 0.820348172 0.783862862 0.792193349 0.794269699 0.794422618 0.799149107 0.822964159 0.810804922 0.812912909 0.81598853 0.821613906 0.811475757 0.80634565 0.805055153 0.782942577 0.810187756 0.793903032 +0.09 0.596538035 0.625069312 0.645460278 0.672991228 0.702029468 0.728526636 0.744991217 0.76241932 0.777531009 0.799178424 0.820307472 0.837222501 0.851724551 0.872162823 0.883935997 0.902890579 0.917373302 0.934540538 0.950370294 0.964459243 0.96412249 0.979381885 0.994877952 1.008583257 1.022977279 1.035159106 1.054731139 1.064920984 1.075309081 1.087892623 1.104720544 1.121528798 1.14606314 1.150958531 1.157456098 1.18504402 1.189461314 1.194722754 1.196936276 1.201613689 1.180256059 1.214436328 1.216189046 1.222237731 1.226484867 1.229942503 1.235186629 1.249791667 1.25063787 1.252734132 1.249424579 1.250225376 1.25165446 1.260442893 1.261156695 1.255134293 1.251315625 1.253844092 1.253015676 1.252175585 1.232453444 1.230006834 1.221369727 1.225072441 1.213266955 1.211952449 1.227619884 1.238857421 1.233353471 1.220974435 1.217326201 1.211256508 1.206030585 1.208002138 1.200204997 1.190296262 +0.41 0.619459584 0.627420422 0.61704851 0.680315574 0.676232938 0.6784373 0.603456204 0.699035925 0.671664135 0.671487655 0.681012424 0.675726082 0.672842609 0.707628195 0.717966721 0.655151405 0.564010832 0.595345308 0.594038618 0.627402002 0.593653818 0.631823948 0.646044618 0.599277034 0.628341952 0.655970211 0.62407952 0.648483097 0.658462558 0.647605758 0.661530793 0.671136451 0.672930617 0.683101898 0.689614277 0.678882844 0.681540214 0.672623252 0.752606 0.752673317 0.753770182 0.74938198 0.764383541 0.74295999 0.731723496 0.675362515 0.650845147 0.640671586 0.6573916 0.633689832 0.655195313 0.683538269 0.671070249 0.709917196 0.673618755 0.699074956 0.705009319 0.695062781 0.719316689 0.683813081 0.694908773 0.741490571 0.692571743 0.736891087 0.750672096 0.757822227 0.691738843 0.700604521 0.739444719 0.736910087 0.733084289 0.696290329 0.703985478 0.736943961 0.712480183 0.693505923 diff --git a/posydon/unit_tests/_helper_functions_for_tests/population.py b/posydon/unit_tests/_helper_functions_for_tests/population.py new file mode 100644 index 0000000000..2813e3df19 --- /dev/null +++ b/posydon/unit_tests/_helper_functions_for_tests/population.py @@ -0,0 +1,322 @@ +"""Helper function(s) for tests requiring a POSYDON Population + +used in: + - posydon/unit_tests/popsyn/test_synthetic_population.py +""" + +__authors__ = [ + "Elizabeth Teng " +] + +import os + +import h5py +import numpy as np +import pandas as pd + +from posydon.popsyn.rate_calculation import ( + DEFAULT_SFH_MODEL, + get_cosmic_time_from_redshift, + get_redshift_bin_centers, +) +from posydon.popsyn.synthetic_population import Population, Rates, TransientPopulation + +# helper functions + +def make_ini(tmp_path,content=None): + """ + Create a minimal dummy .ini file inside the pytest tmp_path using os.path. + + Parameters + ---------- + tmp_path : pathlib.Path + pytest temporary directory. + content : str, optional + Content to write to the ini file. + + Returns + ------- + str + Path (string) to the created dummy ini file. + """ + dir_path = str(tmp_path) + ini_path = os.path.join(dir_path, "dummy.ini") + + if content is None: + content = "[DEFAULT]\nkey=value\n" + + with open(ini_path, "w") as f: + f.write(content) + + return str(ini_path) + +def make_test_pop( + tmp_path, + filename="test_population.h5", + oneline_rows=None, + history_rows=None, + metallicity=0.02): + """ + Create a minimally valid synthetic population HDF5 file and return a + fully initialized Population object. This centralizes and standardizes + population generation across unit tests. + + Parameters + ---------- + tmp_path : Path-like + Directory in which the HDF5 file will be created. + filename : str + Name of the file to write. + oneline_rows : list[dict], optional + Rows for the /oneline table. + history_rows : list[dict], optional + Rows for the /history table. + metallicity : float + Metallicty value for oneline and mass_per_metallicity tables. + + Returns + ------- + Population + A fully initialized Population instance. + """ + + # history and oneline tables + + if history_rows is None: + history_rows = [{"binary_index": 0, "event": "start", "time": 0.0}, + {"binary_index": 0, "event": "end", "time": 1.0}, + {"binary_index": 1, "event": "start", "time": 0.0}, + {"binary_index": 1, "event": "end", "time": 1.0}] + + if oneline_rows is None: + oneline_rows = [{"binary_index": 0, + "S1_mass_i": 1.0, + "S2_mass_i": 1.0, + "state_i": "initial", + "metallicity": metallicity, + "interp_class_HMS_HMS": "initial_MT", + "mt_history_HMS_HMS": "Stable"}, + {"binary_index": 1, + "S1_mass_i": 1.0, + "S2_mass_i": 1.0, + "state_i": "initial", + "metallicity": metallicity, + "interp_class_HMS_HMS": "no_MT", + "mt_history_HMS_HMS": None}] + + # Convert to DataFrames + oneline_df = pd.DataFrame(oneline_rows).sort_values("binary_index") + history_df = pd.DataFrame(history_rows).sort_values(["binary_index", "time"]) + + # history_lengths = number of rows per binary_index + history_lengths_df = history_df.groupby("binary_index").size().to_frame("length") + + # ini_parameters – include all keys that _load_ini_params expects + ini_params = { + "metallicity": metallicity, + "number_of_binaries": len(oneline_df), + "binary_fraction_scheme": "const", + "binary_fraction_const": 0.7, + "star_formation": "burst", + "max_simulation_time": 13800000000.0, + "primary_mass_scheme": "Kroupa2001", + "primary_mass_min": 0.01, + "primary_mass_max": 200.0, + "secondary_mass_scheme": "flat_mass_ratio", + "secondary_mass_min": 0.0005, + "secondary_mass_max": 200.0, + "orbital_scheme": "period", + "orbital_period_scheme": "Sana+12_period_extended", + "orbital_period_min": 0.35, + "orbital_period_max": 6000.0, + "orbital_separation_scheme": "log_uniform", + "orbital_separation_min": 5.0, + "orbital_separation_max": 100000.0, + "eccentricity_scheme": "zero", + "posydon_version": "test", + } + ini_df = pd.DataFrame({k: [v] for k, v in ini_params.items()}) + + + # mass_per_metallicity + mass_df = pd.DataFrame( + {"simulated_mass": [1.0], "number_of_systems": [len(oneline_df)]}, + index=[metallicity] + ) + + # Write HDF5 file using pandas/HDFStore (Population expects PyTables layout) + fpath = os.path.join(tmp_path, filename) + with pd.HDFStore(fpath, "w") as store: + store.put("oneline", oneline_df, format="table") + store.put("history", history_df, format="table") + store.put("history_lengths", history_lengths_df, format="table") + store.put("ini_parameters", ini_df, format="table") + store.put("mass_per_metallicity", mass_df, format="table") + + # Return fully initialized Population object + return Population(fpath) + +def make_test_transient_pop( + tmp_path, + transient_name="test_transient", + filename="test_population.h5", + oneline_rows=None, + history_rows=None, + transient_rows=None, + metallicity=0.02): + """ + Create a minimally valid TransientPopulation HDF5 file and return + a fully initialized TransientPopulation object. + + Builds on make_test_pop by adding a /transients/{transient_name} table. + + Parameters + ---------- + tmp_path : Path-like + Directory in which the HDF5 file will be created. + transient_name : str + Name for the transient population key. + filename : str + Name of the file to write. + oneline_rows : list[dict], optional + Rows for the /oneline table. + history_rows : list[dict], optional + Rows for the /history table. + transient_rows : list[dict], optional + Rows for the /transients/{transient_name} table. + metallicity : float + Metallicity value. + + Returns + ------- + TransientPopulation + A fully initialized TransientPopulation instance. + """ + pop = make_test_pop( + tmp_path, + filename=filename, + oneline_rows=oneline_rows, + history_rows=history_rows, + metallicity=metallicity, + ) + + if transient_rows is None: + transient_rows = [ + {"time": 100.0, "metallicity": metallicity, "channel": "ch_A"}, + {"time": 200.0, "metallicity": metallicity, "channel": "ch_B"}, + ] + + transient_df = pd.DataFrame(transient_rows) + + with pd.HDFStore(pop.filename, "a") as store: + store.append( + "transients/" + transient_name, + transient_df, + format="table", + min_itemsize={"channel": 100}, + ) + + return TransientPopulation( + pop.filename, transient_name, verbose=False + ) + + +def make_test_rates( + tmp_path, + transient_name="test_transient", + SFH_identifier="test_SFH", + filename="test_population.h5", + oneline_rows=None, + history_rows=None, + transient_rows=None, + metallicity=0.02, + MODEL=None): + """ + Create a minimally valid Rates HDF5 file and return + a fully initialized Rates object. + + Builds on make_test_transient_pop by adding the rates tables: + MODEL, weights, z_events, and birth. + + Parameters + ---------- + tmp_path : Path-like + Directory in which the HDF5 file will be created. + transient_name : str + Name for the transient population key. + SFH_identifier : str + Name for the star formation history identifier. + filename : str + Name of the file to write. + oneline_rows : list[dict], optional + Rows for the /oneline table. + history_rows : list[dict], optional + Rows for the /history table. + transient_rows : list[dict], optional + Rows for the /transients/{transient_name} table. + metallicity : float + Metallicity value. + MODEL : dict, optional + The SFH model dict. If None, uses DEFAULT_SFH_MODEL. + + Returns + ------- + Rates + A fully initialized Rates instance. + """ + tpop = make_test_transient_pop( + tmp_path, + transient_name=transient_name, + filename=filename, + oneline_rows=oneline_rows, + history_rows=history_rows, + transient_rows=transient_rows, + metallicity=metallicity, + ) + + if MODEL is None: + MODEL = dict(DEFAULT_SFH_MODEL) + else: + # Merge user overrides into a copy of the defaults + merged = dict(DEFAULT_SFH_MODEL) + merged.update(MODEL) + MODEL = merged + + n_transients = len(pd.read_hdf(tpop.filename, "transients/" + transient_name)) + + # Compute birth bins from MODEL + z_birth = get_redshift_bin_centers(MODEL["delta_t"]) + t_birth = get_cosmic_time_from_redshift(z_birth) + nr_of_birth_bins = len(z_birth) + + base_path = "/transients/" + transient_name + "/rates/" + SFH_identifier + "/" + + with pd.HDFStore(tpop.filename, "a") as store: + # MODEL table — mirror _write_MODEL_data logic for dlogZ lists + if (MODEL["dlogZ"] is not None) and (not isinstance(MODEL["dlogZ"], float)): + store.put(base_path + "MODEL", pd.DataFrame(MODEL)) + else: + store.put(base_path + "MODEL", pd.DataFrame(MODEL, index=[0])) + + # birth table + store.put(base_path + "birth", pd.DataFrame({"z": z_birth, "t": t_birth})) + + # weights table: (n_transients x nr_of_birth_bins) with small dummy values + weights = np.full((n_transients, nr_of_birth_bins), 1e-10) + store.append( + base_path + "weights", + pd.DataFrame(data=weights, index=np.arange(n_transients)), + format="table", + ) + + # z_events table: same shape, dummy redshift values + z_events = np.full((n_transients, nr_of_birth_bins), 0.1) + store.append( + base_path + "z_events", + pd.DataFrame(data=z_events, index=np.arange(n_transients)), + format="table", + ) + + return Rates( + tpop.filename, transient_name, SFH_identifier, verbose=False + ) diff --git a/posydon/unit_tests/grids/test_psygrid.py b/posydon/unit_tests/grids/test_psygrid.py index 9d67a387f4..2361b7d909 100644 --- a/posydon/unit_tests/grids/test_psygrid.py +++ b/posydon/unit_tests/grids/test_psygrid.py @@ -68,7 +68,7 @@ def test_dir(self): 'PROPERTIES_ALLOWED', 'PROPERTIES_TO_BE_CONSISTENT',\ 'PROPERTIES_TO_BE_NONE', 'PROPERTIES_TO_BE_SET',\ 'PSyGrid', 'PSyGridIterator', 'PSyRunView', 'Pwarn',\ - 'TERMINATION_FLAG_COLUMNS',\ + 'LazyHDF5', 'TERMINATION_FLAG_COLUMNS',\ 'TERMINATION_FLAG_COLUMNS_SINGLE',\ 'THRESHOLD_CENTRAL_ABUNDANCE',\ 'THRESHOLD_CENTRAL_ABUNDANCE_LOOSE_C', 'TrackDownsampler',\ @@ -969,6 +969,29 @@ def grid_path(self, tmp_path, binary_history, star_history, profile): return get_simple_PSyGrid(tmp_path, 1, binary_history, star_history,\ profile) + @fixture + def grid_path_single(self, tmp_path, star_history, profile): + # a path to a single-star psygrid file for testing + # Create a minimal binary_history for grid creation, then modify + minimal_binary_history = np.array([(1.0, 1.0), (1.1, 1.0e+2)], + dtype=[('period_days', '= default_imf.m_min) + assert np.all(samples <= default_imf.m_max) + + def test_rvs_without_rng(self, default_imf): + """Test random sampling without providing an RNG.""" + samples = default_imf.rvs(size=100) + assert len(samples) == 100 + assert np.all(samples >= default_imf.m_min) + assert np.all(samples <= default_imf.m_max) + +class TestKroupa1993IMF: + @pytest.fixture + def default_kroupa1993(self): + """Fixture for default Kroupa1993 instance.""" + return IMFs.Kroupa1993() + + @pytest.fixture + def custom_kroupa1993(self): + """Fixture for custom Kroupa1993 instance.""" + return IMFs.Kroupa1993(alpha=2.5, m_min=0.05, m_max=150.0) + + def test_initialization_default(self, default_kroupa1993): + """Test default initialization of Kroupa1993 IMF.""" + assert default_kroupa1993.alpha == 2.7 + assert default_kroupa1993.m_min == 0.01 + assert default_kroupa1993.m_max == 200.0 + integral, _ = quad(default_kroupa1993.imf, + default_kroupa1993.m_min, + default_kroupa1993.m_max) + assert np.isclose(integral*default_kroupa1993.norm, 1.0, rtol=1e-5) + + def test_initialization_custom(self, custom_kroupa1993): + """Test custom initialization of Kroupa1993 IMF.""" + assert custom_kroupa1993.alpha == 2.5 + assert custom_kroupa1993.m_min == 0.05 + assert custom_kroupa1993.m_max == 150.0 + integral, _ = quad(custom_kroupa1993.imf, + custom_kroupa1993.m_min, + custom_kroupa1993.m_max) + assert np.isclose(integral*custom_kroupa1993.norm, 1.0, rtol=1e-5) + + def test_repr(self, default_kroupa1993): + """Test string representation.""" + rep_str = default_kroupa1993.__repr__() + assert "Kroupa1993(" in rep_str + assert "alpha=2.7" in rep_str + assert "m_min=0.01" in rep_str + assert "m_max=200.0" in rep_str + + def test_repr_html(self, default_kroupa1993): + """Test HTML representation.""" + html_str = default_kroupa1993._repr_html_() + assert "

Kroupa (1993) IMF

" in html_str + assert "alpha = 2.7" in html_str + assert "m_min = 0.01" in html_str + assert "m_max = 200.0" in html_str + + def test_invalid_mass(self, default_kroupa1993): + """Test that the imf method raises ValueError for invalid mass values.""" + with pytest.raises(ValueError, match="Mass must be positive."): + default_kroupa1993.imf(0) + with pytest.raises(ValueError, match="Mass must be positive."): + default_kroupa1993.imf(-1.0) + with pytest.raises(ValueError, match="Mass must be positive."): + default_kroupa1993.imf([0.0]) + + def test_imf(self, default_kroupa1993): + """Test the imf method for correct values.""" + # Test with an array of mass values + m_values = np.array([1.0, 2.0, 5.0, 10.0]) + expected = m_values ** (-default_kroupa1993.alpha) + computed = default_kroupa1993.imf(m_values) + assert np.allclose(computed, expected) + + # Test with a single mass value + m = 3.0 + expected = m ** (-default_kroupa1993.alpha) + computed = default_kroupa1993.imf(m) + assert np.isclose(computed, expected) + + def test_pdf_within_range(self, default_kroupa1993): + """Test that PDF returns correct values within the mass range.""" + m = np.linspace(default_kroupa1993.m_min, default_kroupa1993.m_max, 100) + pdf_values = default_kroupa1993.pdf(m) + expected_pdf = default_kroupa1993.imf(m) * default_kroupa1993.norm + assert np.allclose(pdf_values, expected_pdf) + + def test_pdf_outside_range(self, default_kroupa1993): + """Test that PDF returns zero for masses outside the mass range.""" + m = np.array([default_kroupa1993.m_min - 0.1, default_kroupa1993.m_max + 0.1]) + pdf_values = default_kroupa1993.pdf(m) + assert np.allclose(pdf_values, 0.0) + + def test_normalization(self, default_kroupa1993): + """Ensure that the integral of the PDF over the range is approximately 1.""" + integral, _ = quad(default_kroupa1993.pdf, + default_kroupa1993.m_min, + default_kroupa1993.m_max) + assert np.isclose(integral, 1.0, rtol=1e-5) + + def test_rvs_default(self, default_kroupa1993): + """Test random sampling from the Kroupa1993 IMF.""" + # Test single sample + rng = np.random.default_rng(42) + sample = default_kroupa1993.rvs(size=1, rng=rng) + assert len(sample) == 1 + assert default_kroupa1993.m_min <= sample[0] <= default_kroupa1993.m_max + + # Test multiple samples + rng = np.random.default_rng(42) + samples = default_kroupa1993.rvs(size=1000, rng=rng) + assert len(samples) == 1000 + assert np.all(samples >= default_kroupa1993.m_min) + assert np.all(samples <= default_kroupa1993.m_max) + + def test_rvs_without_rng(self, default_kroupa1993): + """Test random sampling without providing an RNG.""" + samples = default_kroupa1993.rvs(size=100) + assert len(samples) == 100 + assert np.all(samples >= default_kroupa1993.m_min) + assert np.all(samples <= default_kroupa1993.m_max) + + class TestKroupa2001IMF: @pytest.fixture def default_kroupa(self): @@ -277,6 +412,28 @@ def test_repr_html(self, default_kroupa): html_str = default_kroupa._repr_html_() assert "

Kroupa (2001) IMF

" in html_str + def test_rvs_default(self, default_kroupa): + """Test random sampling from the Kroupa2001 IMF.""" + # Test single sample + rng = np.random.default_rng(42) + sample = default_kroupa.rvs(size=1, rng=rng) + assert len(sample) == 1 + assert default_kroupa.m_min <= sample[0] <= default_kroupa.m_max + + # Test multiple samples + rng = np.random.default_rng(42) + samples = default_kroupa.rvs(size=1000, rng=rng) + assert len(samples) == 1000 + assert np.all(samples >= default_kroupa.m_min) + assert np.all(samples <= default_kroupa.m_max) + + def test_rvs_without_rng(self, default_kroupa): + """Test random sampling without providing an RNG.""" + samples = default_kroupa.rvs(size=100) + assert len(samples) == 100 + assert np.all(samples >= default_kroupa.m_min) + assert np.all(samples <= default_kroupa.m_max) + class TestChabrierIMF: @pytest.fixture def default_chabrier(self): @@ -401,3 +558,25 @@ def test_repr(self, default_chabrier): def test_repr_html(self, default_chabrier): html_str = default_chabrier._repr_html_() assert "

Chabrier IMF

" in html_str + + def test_rvs_default(self, default_chabrier): + """Test random sampling from the Chabrier2003 IMF.""" + # Test single sample + rng = np.random.default_rng(42) + sample = default_chabrier.rvs(size=1, rng=rng) + assert len(sample) == 1 + assert default_chabrier.m_min <= sample[0] <= default_chabrier.m_max + + # Test multiple samples + rng = np.random.default_rng(42) + samples = default_chabrier.rvs(size=1000, rng=rng) + assert len(samples) == 1000 + assert np.all(samples >= default_chabrier.m_min) + assert np.all(samples <= default_chabrier.m_max) + + def test_rvs_without_rng(self, default_chabrier): + """Test random sampling without providing an RNG.""" + samples = default_chabrier.rvs(size=100) + assert len(samples) == 100 + assert np.all(samples >= default_chabrier.m_min) + assert np.all(samples <= default_chabrier.m_max) diff --git a/posydon/unit_tests/popsyn/test_Moes_distributions.py b/posydon/unit_tests/popsyn/test_Moes_distributions.py new file mode 100644 index 0000000000..3a4ae7b9ef --- /dev/null +++ b/posydon/unit_tests/popsyn/test_Moes_distributions.py @@ -0,0 +1,201 @@ +"""Unit tests of posydon/popsyn/Moes_distributions.py +""" + +__authors__ = [ + "Elizabeth Teng " +] + +# import the module which will be tested +import posydon.popsyn.Moes_distributions as totest + +# aliases +np = totest.np + +# import other needed code for the tests, which is not already imported in the +# module you like to test +from pytest import approx, fixture, raises + + +# define test classes collecting several test functions +class TestElements: + # check for objects, which should be an element of the tested module + def test_dir(self): + elements = ['Moe_17_PsandQs', '__authors__', + '__builtins__', '__cached__', '__doc__', '__file__', + '__loader__', '__name__', '__package__', '__spec__', + 'np', 'newton_cotes', 'quad'] + totest_elements = set(dir(totest)) + missing_in_test = set(elements) - totest_elements + assert len(missing_in_test) == 0, "There are missing objects in "\ + +f"{totest.__name__}: "\ + +f"{missing_in_test}. Please "\ + +"check, whether they have been "\ + +"removed on purpose and update "\ + +"this unit test." + new_in_test = totest_elements - set(elements) + assert len(new_in_test) == 0, "There are new objects in "\ + +f"{totest.__name__}: {new_in_test}. "\ + +"Please check, whether they have been "\ + +"added on purpose and update this "\ + +"unit test." + +class TestMoe17PsandQs: + + @fixture + def small_model(self): + """Create a Moe_17_PsandQs with small grid for fast testing.""" + return totest.Moe_17_PsandQs( + n_M1=5, n_logP=10, n_q=10, n_e=20, + RNG=np.random.default_rng(seed=42)) + + # _idl_tabulate + + def test_idl_tabulate_constant(self, small_model): + """Integral of f=1 from 0 to 1 should be 1.""" + x = np.linspace(0.0, 1.0, 11) + f = np.ones_like(x) + result = small_model._idl_tabulate(x, f) + assert result == approx(1.0, abs=1e-10) + + def test_idl_tabulate_linear(self, small_model): + """Integral of f=x from 0 to 1 should be 0.5.""" + x = np.linspace(0.0, 1.0, 11) + f = x.copy() + result = small_model._idl_tabulate(x, f) + assert result == approx(0.5, abs=1e-10) + + def test_idl_tabulate_quadratic(self, small_model): + """Integral of f=x^2 from 0 to 1 should be 1/3.""" + x = np.linspace(0.0, 1.0, 21) + f = x**2 + result = small_model._idl_tabulate(x, f) + assert result == approx(1.0 / 3.0, abs=1e-6) + + def test_idl_tabulate_single_point(self, small_model): + """Single point: integral over zero range should be 0.""" + x = np.array([1.0]) + f = np.array([5.0]) + result = small_model._idl_tabulate(x, f) + assert result == approx(0.0) + + # __init__ + + def test_init_grid_shapes(self, small_model): + """Verify grid dimensions match requested sizes.""" + assert small_model.numM1 == 5 + assert small_model.numlogP == 10 + assert small_model.numq == 10 + assert small_model.nume == 20 + assert small_model.M1v.shape == (5,) + assert small_model.logPv.shape == (10,) + assert small_model.qv.shape == (10,) + assert small_model.ev.shape == (20,) + assert small_model.flogP_sq.shape == (10, 5) + assert small_model.cumqdist.shape == (10, 10, 5) + assert small_model.cumedist.shape == (20, 10, 5) + assert small_model.probbin.shape == (10, 5) + assert small_model.cumPbindist.shape == (10, 5) + + def test_init_mass_range(self, small_model): + """M1v should span 0.8 to 40 Msun.""" + assert small_model.M1v[0] == approx(0.8, abs=1e-10) + assert small_model.M1v[-1] == approx(40.0, abs=1e-10) + + def test_init_q_range(self, small_model): + """qv should span 0.1 to 1.0.""" + assert small_model.qv[0] == approx(0.1) + assert small_model.qv[-1] == approx(1.0) + + def test_init_cumulative_distributions(self, small_model): + """Cumulative distributions should end at 1.0.""" + # cumqdist should reach 1.0 at q=1.0 for each (logP, M1) + for i in range(small_model.numM1): + for j in range(small_model.numlogP): + assert small_model.cumqdist[-1, j, i] == approx(1.0, abs=1e-6) + assert small_model.cumedist[-1, j, i] == approx(1.0, abs=1e-6) + + def test_init_default_params(self): + """Test with default parameters (expensive — just verify it constructs).""" + # Use non-default but still small grid to confirm kwarg handling + model = totest.Moe_17_PsandQs( + n_M1=3, n_logP=5, n_q=5, n_e=10) + assert model.numM1 == 3 + + # __call__ + + def test_call_single_mass(self, small_model): + """Generate sample for a single primary mass.""" + M2, P, e, Z = small_model(10.0) + assert len(M2) == 1 + assert len(P) == 1 + assert len(e) == 1 + assert len(Z) == 1 + assert P[0] > 0 + assert Z[0] > 0 + + def test_call_array(self, small_model): + """Generate samples for multiple primary masses.""" + M1 = np.array([5.0, 10.0, 20.0]) + M2, P, e, Z = small_model(M1) + assert len(M2) == 3 + assert len(P) == 3 + assert len(e) == 3 + assert len(Z) == 3 + + def test_call_all_binaries_true(self): + """With all_binaries=True, no single stars should be produced.""" + model = totest.Moe_17_PsandQs( + n_M1=5, n_logP=10, n_q=10, n_e=20, + RNG=np.random.default_rng(seed=42)) + M1 = np.array([10.0] * 20) + M2, P, e, Z = model(M1, all_binaries=True) + # all_binaries=True means mybinfrac=1.0, so no NaN values + assert not np.any(np.isnan(M2)) + assert not np.any(np.isnan(P)) + + def test_call_all_binaries_false(self): + """With all_binaries=False, some single stars may be produced.""" + model = totest.Moe_17_PsandQs( + n_M1=5, n_logP=10, n_q=10, n_e=20, + RNG=np.random.default_rng(seed=0)) + M1 = np.array([1.0] * 50) + M2, P, e, Z = model(M1, all_binaries=False) + # With 50 draws at M1=1.0, expect some single stars (NaN) + # Z is always set, never NaN + assert not np.any(np.isnan(Z)) + assert len(M2) == 50 + + def test_call_high_mass(self, small_model): + """M1 > 40 Msun should adopt binary statistics of M1 = 40 Msun.""" + M2, P, e, Z = small_model(80.0) + assert len(M2) == 1 + assert P[0] > 0 + + def test_call_low_mass(self): + """M1 < 0.8 Msun should rescale binary fraction.""" + model = totest.Moe_17_PsandQs( + n_M1=5, n_logP=10, n_q=10, n_e=20, + RNG=np.random.default_rng(seed=42)) + M2, P, e, Z = model(0.5, M_min=0.08, all_binaries=False) + assert len(M2) == 1 + assert Z[0] > 0 + + def test_call_low_mass_q_truncation(self): + """M1 < 0.8 inside the binary path should truncate q distribution.""" + model = totest.Moe_17_PsandQs( + n_M1=5, n_logP=10, n_q=10, n_e=20, + RNG=np.random.default_rng(seed=42)) + M2, P, e, Z = model(0.5, M_min=0.08, all_binaries=True) + assert len(M2) == 1 + # M2 = M1 * q, and q >= q_min = M_min/M1 = 0.08/0.5 = 0.16 + assert M2[0] >= 0.08 * 0.99 # M2 >= M_min (small tolerance) + + def test_call_metallicity_range(self, small_model): + """Metallicities should be within the expected range.""" + M1 = np.array([10.0] * 100) + _, _, _, Z = small_model(M1) + Zsun = 0.02 + Z_min = Zsun * 10**(-2.3) + Z_max = Zsun * 10**(0.176) + assert all(Z >= Z_min * 0.99) # small tolerance + assert all(Z <= Z_max * 1.01) diff --git a/posydon/unit_tests/popsyn/test_binarypopulation.py b/posydon/unit_tests/popsyn/test_binarypopulation.py new file mode 100644 index 0000000000..640fc1fbf1 --- /dev/null +++ b/posydon/unit_tests/popsyn/test_binarypopulation.py @@ -0,0 +1,562 @@ +"""Unit tests of posydon/popsyn/binarypopulation.py +""" + +__authors__ = [ + "Elizabeth Teng " +] + +# import the module which will be tested +import posydon.popsyn.binarypopulation as totest + +# aliases +np = totest.np +pd = totest.pd +os = totest.os + +from inspect import isclass, isroutine + +# import other needed code for the tests, which is not already imported in the +# module you like to test +from pytest import approx, fixture, raises + +from posydon.binary_evol.binarystar import BinaryStar +from posydon.binary_evol.simulationproperties import SimulationProperties +from posydon.binary_evol.singlestar import SingleStar + + +# define test classes collecting several test functions +class TestElements: + # check for objects, which should be an element of the tested module + def test_dir(self): + elements = ['BinaryPopulation', 'PopulationManager', + 'BinaryGenerator', + 'saved_ini_parameters', + 'HISTORY_MIN_ITEMSIZE', 'ONELINE_MIN_ITEMSIZE', + 'STEP_NAMES_LOADING_GRIDS', + 'default_kwargs', + '__authors__', '__credits__', + '__builtins__', '__cached__', '__doc__', '__file__', + '__loader__', '__name__', '__package__', '__spec__', + 'np', 'pd', 'os', 'atexit', 'signal', 'traceback', + 'psutil', 'tqdm', + 'posydon', 'BinaryStar', 'SimulationProperties', + 'SingleStar', 'properties_massless_remnant', + 'generate_independent_samples', + 'binarypop_kwargs_from_ini', 'simprop_kwargs_from_ini', + 'get_kick_samples_from_file', 'get_samples_from_file', + 'get_formation_times', + 'orbital_period_from_separation', + 'orbital_separation_from_period', + 'set_binary_to_failed', + 'Zsun', 'POSYDONError', + 'Catch_POSYDON_Warnings', 'Pwarn', + ] + totest_elements = set(dir(totest)) + missing_in_test = set(elements) - totest_elements + assert len(missing_in_test) == 0, "There are missing objects in "\ + +f"{totest.__name__}: "\ + +f"{missing_in_test}. Please "\ + +"check, whether they have been "\ + +"removed on purpose and update "\ + +"this unit test." + new_in_test = totest_elements - set(elements) + assert len(new_in_test) == 0, "There are new objects in "\ + +f"{totest.__name__}: {new_in_test}. "\ + +"Please check, whether they have been "\ + +"added on purpose and update this "\ + +"unit test." + + +class TestBinaryGenerator: + + @fixture + def generator(self): + """Create a BinaryGenerator with seeded RNG.""" + rng = np.random.default_rng(seed=42) + return totest.BinaryGenerator(RNG=rng, metallicity=1.0, + star_formation='burst', + max_simulation_time=13.8e9) + + @fixture + def kick_csv(self, tmp_path): + """CSV with kick columns for file-based sampling.""" + df = pd.DataFrame({ + 'm1': [10.0, 20.0], + 'm2': [5.0, 10.0], + 'orbital_period': [10.0, 20.0], + 'eccentricity': [0.1, 0.2], + 's1_natal_kick_velocity': [100.0, 200.0], + 's1_natal_kick_azimuthal_angle': [0.5, 1.0], + 's1_natal_kick_polar_angle': [0.3, 0.6], + 's1_natal_kick_mean_anomaly': [0.1, 0.2], + 's2_natal_kick_velocity': [50.0, 150.0], + 's2_natal_kick_azimuthal_angle': [0.4, 0.8], + 's2_natal_kick_polar_angle': [0.2, 0.5], + 's2_natal_kick_mean_anomaly': [0.05, 0.15], + }) + path = os.path.join(tmp_path, "kicks.csv") + df.to_csv(path, index=False) + return str(path) + + def test_init_default_rng(self): + gen = totest.BinaryGenerator() + assert isinstance(gen.RNG, np.random.Generator) + assert gen.entropy is not None + assert gen._num_gen == 0 + + def test_init_custom_rng(self, generator): + assert isinstance(generator.RNG, np.random.Generator) + assert generator._num_gen == 0 + assert generator.star_formation == 'burst' + assert generator.Z_div_Zsun == 1.0 + + def test_init_bad_rng(self): + with raises(AssertionError): + totest.BinaryGenerator(RNG="not_a_generator") + + def test_draw_initial_samples_separation(self, generator): + output = generator.draw_initial_samples( + orbital_scheme='separation', number_of_binaries=5) + assert len(output['S1_mass']) == 5 + assert len(output['separation']) == 5 + assert len(output['orbital_period']) == 5 + assert all(output['binary_index'] == np.arange(0, 5)) + assert generator._num_gen == 5 + + def test_draw_initial_samples_period(self, generator): + output = generator.draw_initial_samples( + orbital_scheme='period', number_of_binaries=3) + assert len(output['S1_mass']) == 3 + assert all(np.isfinite(output['S1_mass'])) + + def test_draw_initial_samples_bad_scheme(self): + """Cover the else branch in draw_initial_samples itself.""" + rng = np.random.default_rng(seed=42) + def mock_sampler(orbital_scheme, **kwargs): + # Returns data without validating scheme + return np.array([100.0]), np.array([0.1]), np.array([10.0]), np.array([5.0]) + gen = totest.BinaryGenerator( + RNG=rng, metallicity=1.0, sampler=mock_sampler, + star_formation='burst', max_simulation_time=13.8e9) + with raises(ValueError, match="Allowed orbital schemes"): + gen.draw_initial_samples(orbital_scheme='invalid', number_of_binaries=1) + + def test_draw_initial_samples_no_number_of_binaries(self, generator): + """When number_of_binaries not in kwargs, defaults to 1 for kicks.""" + output = generator.draw_initial_samples(orbital_scheme='separation') + assert len(output['S1_mass']) == 1 + + def test_draw_initial_samples_from_file(self, kick_csv): + """Kick values read from file.""" + rng = np.random.default_rng(seed=42) + gen = totest.BinaryGenerator( + RNG=rng, metallicity=1.0, + sampler=totest.get_samples_from_file, + star_formation='burst', + max_simulation_time=13.8e9) + output = gen.draw_initial_samples( + orbital_scheme='period', + number_of_binaries=2, + read_samples_from_file=kick_csv) + assert output['S1_natal_kick_velocity'][0] == 100.0 + assert output['S2_natal_kick_velocity'][1] == 150.0 + + def test_draw_initial_binary(self, generator): + binary = generator.draw_initial_binary( + orbital_scheme='separation', metallicity=1.0) + assert isinstance(binary, BinaryStar) + assert isinstance(binary.star_1, SingleStar) + assert isinstance(binary.star_2, SingleStar) + assert binary.star_1.mass > 0 + assert binary.event == 'ZAMS' + assert binary.state == 'detached' + + def test_draw_initial_binary_single_star(self): + """With binary_fraction_const=0, all draws produce single stars.""" + rng = np.random.default_rng(seed=42) + gen = totest.BinaryGenerator( + RNG=rng, metallicity=1.0, + star_formation='burst', + max_simulation_time=13.8e9, + binary_fraction_const=0.0, + binary_fraction_scheme='const') + binary = gen.draw_initial_binary( + orbital_scheme='separation', metallicity=1.0, + binary_fraction_const=0.0, binary_fraction_scheme='const') + assert isinstance(binary, BinaryStar) + assert binary.state == 'initially_single_star' + assert np.isnan(binary.separation) + assert np.isnan(binary.orbital_period) + assert np.isnan(binary.eccentricity) + + def test_draw_initial_binary_with_index(self, generator): + binary = generator.draw_initial_binary( + orbital_scheme='separation', metallicity=1.0, index=99) + assert binary.index == 99 + + def test_reset_rng(self, generator): + generator.draw_initial_samples( + orbital_scheme='separation', number_of_binaries=5) + assert generator._num_gen == 5 + generator.reset_rng() + assert generator._num_gen == 0 + + def test_get_original_rng(self, generator): + rng = generator.get_original_rng() + assert isinstance(rng, np.random.Generator) + + def test_get_binary_by_iter(self, generator): + binary = generator.get_binary_by_iter( + n=3, orbital_scheme='separation', metallicity=1.0) + assert isinstance(binary, BinaryStar) + assert generator._num_gen == 0 + + def test_get_binary_by_iter_zero(self, generator): + """n=0 skips the warmup sampling.""" + binary = generator.get_binary_by_iter( + n=0, orbital_scheme='separation', metallicity=1.0) + assert isinstance(binary, BinaryStar) + + def test_num_gen_increments(self, generator): + generator.draw_initial_samples( + orbital_scheme='separation', number_of_binaries=3) + assert generator._num_gen == 3 + output = generator.draw_initial_samples( + orbital_scheme='separation', number_of_binaries=1) + assert output['binary_index'][0] == 3 + + +class TestPopulationManager: + + @fixture + def manager(self): + """Create a PopulationManager with minimal kwargs.""" + return totest.PopulationManager( + RNG=np.random.default_rng(seed=42), + metallicity=1.0, + star_formation='burst', + max_simulation_time=13.8e9) + + @fixture + def dummy_binary(self): + """Create a minimal BinaryStar.""" + s1 = SingleStar(mass=10.0, state='H-rich_Core_H_burning', + metallicity=1.0) + s2 = SingleStar(mass=5.0, state='H-rich_Core_H_burning', + metallicity=1.0) + return BinaryStar(star_1=s1, star_2=s2, index=0, + state='detached', event='ZAMS', + time=0.0, separation=100.0, + orbital_period=10.0, eccentricity=0.0) + + @fixture + def failed_binary(self): + """Create a binary with FAILED event.""" + s1 = SingleStar(mass=10.0, state='H-rich_Core_H_burning', + metallicity=1.0) + s2 = SingleStar(mass=5.0, state='H-rich_Core_H_burning', + metallicity=1.0) + b = BinaryStar(star_1=s1, star_2=s2, index=1, + state='detached', event='ZAMS', + time=0.0, separation=100.0, + orbital_period=10.0, eccentricity=0.0) + b.event = 'FAILED' + return b + + def test_init(self, manager): + assert manager.binaries == [] + assert manager.indices == [] + assert isinstance(manager.binary_generator, totest.BinaryGenerator) + + def test_init_with_filename(self): + mgr = totest.PopulationManager( + file_name='test.h5', + RNG=np.random.default_rng(seed=42), + metallicity=1.0, + star_formation='burst', + max_simulation_time=13.8e9) + assert mgr.store_file == 'test.h5' + + def test_init_with_file_sampler(self, tmp_path): + csv_path = os.path.join(tmp_path, "samples.csv") + df = pd.DataFrame({ + 'm1': [10.0], 'm2': [5.0], + 'orbital_period': [10.0], 'eccentricity': [0.0], + }) + df.to_csv(csv_path, index=False) + mgr = totest.PopulationManager( + read_samples_from_file=str(csv_path), + RNG=np.random.default_rng(seed=42), + metallicity=1.0, + star_formation='burst', + max_simulation_time=13.8e9) + assert mgr.binary_generator.sampler is not totest.generate_independent_samples + + def test_append_single(self, manager, dummy_binary): + manager.append(dummy_binary) + assert len(manager.binaries) == 1 + assert manager.indices == [0] + + def test_append_list(self, manager, dummy_binary): + manager.append([dummy_binary]) + assert len(manager.binaries) == 1 + + def test_append_invalid(self, manager): + with raises(ValueError, match="Must be BinaryStar"): + manager.append("not_a_binary") + + def test_remove_single(self, manager, dummy_binary): + manager.append(dummy_binary) + manager.remove(dummy_binary) + assert len(manager.binaries) == 0 + + def test_remove_list(self, manager, dummy_binary): + manager.append(dummy_binary) + manager.remove([dummy_binary]) + assert len(manager.binaries) == 0 + + def test_remove_invalid(self, manager): + with raises(ValueError, match="Must be BinaryStar"): + manager.remove("not_a_binary") + + def test_clear_dfs(self, manager): + manager.history_dfs = [pd.DataFrame()] + manager.oneline_dfs = [pd.DataFrame()] + manager.clear_dfs() + assert manager.history_dfs == [] + assert manager.oneline_dfs == [] + + def test_generate(self, manager): + binary = manager.generate(orbital_scheme='separation', + metallicity=1.0) + assert isinstance(binary, BinaryStar) + assert len(manager.binaries) == 1 + + def test_to_df_empty(self, manager): + assert manager.to_df() is None + + def test_to_df_with_binaries(self, manager): + manager.generate(orbital_scheme='separation', metallicity=1.0) + result = manager.to_df() + assert isinstance(result, pd.DataFrame) + assert len(result) > 0 + + def test_to_df_with_selection_accept(self, manager): + manager.generate(orbital_scheme='separation', metallicity=1.0) + result = manager.to_df(selection_function=lambda b: True) + assert isinstance(result, pd.DataFrame) + + def test_to_df_with_selection_reject(self, manager): + manager.generate(orbital_scheme='separation', metallicity=1.0) + result = manager.to_df(selection_function=lambda b: False) + assert result is None + + def test_to_df_with_history_dfs(self, manager): + dummy_df = pd.DataFrame({'state': ['detached'], 'time': [0.0]}, + index=[0]) + manager.history_dfs = [dummy_df] + result = manager.to_df() + assert isinstance(result, pd.DataFrame) + assert len(result) == 1 + + def test_to_oneline_df_empty(self, manager): + assert manager.to_oneline_df() is None + + def test_to_oneline_df_with_binaries(self, manager): + manager.generate(orbital_scheme='separation', metallicity=1.0) + result = manager.to_oneline_df() + assert isinstance(result, pd.DataFrame) + + def test_to_oneline_df_with_selection(self, manager): + manager.generate(orbital_scheme='separation', metallicity=1.0) + result = manager.to_oneline_df(selection_function=lambda b: True) + assert isinstance(result, pd.DataFrame) + + def test_to_oneline_df_with_oneline_dfs(self, manager): + dummy_df = pd.DataFrame({'state_i': ['detached']}, index=[0]) + manager.oneline_dfs = [dummy_df] + result = manager.to_oneline_df() + assert isinstance(result, pd.DataFrame) + + def test_find_failed_empty(self, manager): + assert manager.find_failed() is None + + def test_find_failed_with_binaries(self, manager, dummy_binary, + failed_binary): + manager.append(dummy_binary) + manager.append(failed_binary) + result = manager.find_failed() + assert len(result) == 1 + assert result[0].event == 'FAILED' + + def test_find_failed_with_dfs_found(self, manager): + failed_df = pd.DataFrame({'event': ['ZAMS', 'FAILED'], + 'time': [0.0, 1.0]}, index=[0, 0]) + manager.history_dfs = [failed_df] + result = manager.find_failed() + assert isinstance(result, pd.DataFrame) + + def test_find_failed_with_dfs_none_found(self, manager): + ok_df = pd.DataFrame({'event': ['ZAMS', 'END'], + 'time': [0.0, 1.0]}, index=[0, 0]) + manager.history_dfs = [ok_df] + result = manager.find_failed() + assert result == [] + + def test_breakdown_to_df(self, manager): + binary = manager.generate(orbital_scheme='separation', + metallicity=1.0) + assert len(manager.binaries) == 1 + manager.breakdown_to_df(binary) + assert len(manager.binaries) == 0 + assert len(manager.history_dfs) == 1 + assert len(manager.oneline_dfs) == 1 + + def test_breakdown_to_df_error(self, manager, capsys): + """breakdown_to_df catches exceptions during conversion.""" + binary = manager.generate(orbital_scheme='separation', + metallicity=1.0) + # Replace to_df with a function that raises + def bad_to_df(**kw): + raise RuntimeError("test error") + binary.to_df = bad_to_df + manager.breakdown_to_df(binary) + captured = capsys.readouterr() + assert "Error during breakdown" in captured.out + + def test_to_oneline_df_with_selection_reject(self, manager): + """to_oneline_df with selection_function that rejects all.""" + manager.generate(orbital_scheme='separation', metallicity=1.0) + result = manager.to_oneline_df(selection_function=lambda b: False) + assert result is None + + +class TestBinaryPopulation: + + @fixture + def pop(self): + """Create a minimal BinaryPopulation.""" + return totest.BinaryPopulation( + number_of_binaries=3, + metallicity=1.0, + star_formation='burst', + max_simulation_time=13.8e9, + entropy=12345) + + def test_init_basic(self, pop): + assert pop.number_of_binaries == 3 + assert pop.metallicity == 1.0 + assert isinstance(pop.population_properties, SimulationProperties) + assert isinstance(pop.manager, totest.PopulationManager) + assert isinstance(pop.RNG, np.random.Generator) + assert pop.comm is None + assert pop.JOB_ID is None + + def test_init_metallicity_from_list(self): + pop = totest.BinaryPopulation( + number_of_binaries=1, + metallicities=[0.5, 1.0, 2.0], + metallicity_index=1, + star_formation='burst', + max_simulation_time=13.8e9, + entropy=42) + assert pop.metallicity == 1.0 + + def test_init_mpi_and_jobarray_incompatible(self): + class FakeComm: + def Get_rank(self): return 0 + def Get_size(self): return 2 + with raises(ValueError, match="MPI and Job array runs are not compatible"): + totest.BinaryPopulation( + number_of_binaries=1, + comm=FakeComm(), + JOB_ID=123, + star_formation='burst', + max_simulation_time=13.8e9, + entropy=42) + + def test_init_mpi_no_entropy(self): + class FakeComm: + def Get_rank(self): return 0 + def Get_size(self): return 2 + with raises(ValueError, match="requires an entropy value"): + totest.BinaryPopulation( + number_of_binaries=1, + comm=FakeComm(), + star_formation='burst', + max_simulation_time=13.8e9) + + def test_init_mpi_with_entropy(self): + class FakeComm: + def Get_rank(self): return 0 + def Get_size(self): return 2 + pop = totest.BinaryPopulation( + number_of_binaries=10, + comm=FakeComm(), + star_formation='burst', + max_simulation_time=13.8e9, + entropy=42) + assert isinstance(pop.RNG, np.random.Generator) + + def test_init_job_array_with_entropy(self): + pop = totest.BinaryPopulation( + number_of_binaries=10, + JOB_ID=100, + RANK=0, + size=2, + star_formation='burst', + max_simulation_time=13.8e9, + entropy=42) + assert pop.JOB_ID == 100 + + def test_init_job_array_no_entropy(self): + """JOB_ID without entropy uses JOB_ID as seed.""" + pop = totest.BinaryPopulation( + number_of_binaries=10, + JOB_ID=100, + RANK=0, + size=2, + star_formation='burst', + max_simulation_time=13.8e9) + assert pop.JOB_ID == 100 + assert isinstance(pop.RNG, np.random.Generator) + + def test_from_ini(self, monkeypatch, tmp_path): + def mock_binarypop_kwargs(path, verbose=False): + return { + 'number_of_binaries': 5, + 'metallicity': 1.0, + 'metallicities': [1.0], + 'star_formation': 'burst', + 'max_simulation_time': 13.8e9, + 'entropy': 99, + } + def mock_simprop_kwargs(path): + return {} + + monkeypatch.setattr(totest, 'binarypop_kwargs_from_ini', + mock_binarypop_kwargs) + monkeypatch.setattr(totest, 'simprop_kwargs_from_ini', + mock_simprop_kwargs) + + pop = totest.BinaryPopulation.from_ini(str(tmp_path / "fake.ini")) + assert pop.number_of_binaries == 5 + assert isinstance(pop.population_properties, SimulationProperties) + + def test_close(self, pop): + pop.close() + + def test_getstate(self, pop): + state = pop.__getstate__() + assert isinstance(state, dict) + assert state['comm'] is None + + def test_getstate_with_steps_loaded(self, pop, monkeypatch): + """__getstate__ closes steps if they were loaded.""" + closed = [] + pop.population_properties.steps_loaded = True + monkeypatch.setattr(pop.population_properties, 'close', + lambda: closed.append(True)) + state = pop.__getstate__() + assert state['comm'] is None + assert len(closed) == 1 # close() was called diff --git a/posydon/unit_tests/popsyn/test_defaults.py b/posydon/unit_tests/popsyn/test_defaults.py new file mode 100644 index 0000000000..a3a114e667 --- /dev/null +++ b/posydon/unit_tests/popsyn/test_defaults.py @@ -0,0 +1,146 @@ +"""Unit tests of posydon/popsyn/defaults.py +""" + +__authors__ = [ + "Elizabeth Teng " +] + +# import other needed code for the tests, which is not already imported in the +# module you like to test +import pytest + +# import the module which will be tested +import posydon.popsyn.defaults as totest + + +# define test classes collecting several test functions +class TestElements: + # check for objects, which should be an element of the tested module + + def test_dir(self): + elements = ['default_kwargs', '__authors__',\ + '__builtins__', '__cached__', '__doc__', '__file__',\ + '__loader__', '__name__', '__package__', '__spec__','age_of_universe'] + totest_elements = set(dir(totest)) + missing_in_test = set(elements) - totest_elements + assert len(missing_in_test) == 0, "There are missing objects in "\ + +f"{totest.__name__}: "\ + +f"{missing_in_test}. Please "\ + +"check, whether they have been "\ + +"removed on purpose and update "\ + +"this unit test." + new_in_test = totest_elements - set(elements) + assert len(new_in_test) == 0, "There are new objects in "\ + +f"{totest.__name__}: {new_in_test}. "\ + +"Please check, whether they have been "\ + +"added on purpose and update this "\ + +"unit test." + + def test_kwargs(self): + elements = [ + 'entropy', + 'number_of_binaries', + 'metallicities', + 'star_formation', + 'max_simulation_time', + 'orbital_scheme', + 'orbital_separation_scheme', + 'orbital_separation_min', + 'orbital_separation_max', + 'log_orbital_seperation_mean', + 'log_orbital_seperation_sigma', + 'orbital_period_scheme', + 'orbital_period_min', + 'orbital_period_max', + 'eccentricity_scheme', + 'primary_mass_scheme', + 'primary_mass_min', + 'primary_mass_max', + 'secondary_mass_scheme', + 'secondary_mass_min', + 'secondary_mass_max', + 'binary_fraction_const', + 'binary_fraction_scheme' + ] + assert set(totest.default_kwargs.keys()) == set(elements), \ + "The default_kwargs dictionary keys have changed. Please update the test." + + def test_instance_entropy(self): + assert isinstance(totest.default_kwargs['entropy'], (type(None), float)), \ + "entropy should be None or a float" + + def test_instance_number_of_binaries(self): + assert isinstance(totest.default_kwargs['number_of_binaries'], int), \ + "number_of_binaries should be an integer" + + def test_instance_metallicities(self): + assert isinstance(totest.default_kwargs['metallicities'], list), \ + "metallicities should be a list" + + def test_instance_star_formation(self): + assert isinstance(totest.default_kwargs['star_formation'], str), \ + "star_formation should be a string" + + def test_instance_max_simulation_time(self): + assert isinstance(totest.default_kwargs['max_simulation_time'], (float, int)), \ + "max_simulation_time should be a float or int" + + def test_instance_orbital_scheme(self): + assert isinstance(totest.default_kwargs['orbital_scheme'], str), \ + "orbital_scheme should be a string" + + def test_instance_orbital_separation_scheme(self): + assert isinstance(totest.default_kwargs['orbital_separation_scheme'], str), \ + "orbital_scheme should be a string" + + def test_instance_orbital_separation_min(self): + assert isinstance(totest.default_kwargs['orbital_separation_min'], float), \ + "orbital_separation_min should be a float" + + def test_instance_orbital_separation_max(self): + assert isinstance(totest.default_kwargs['orbital_separation_max'], float), \ + "orbital_separation_max should be a float" + + def test_instance_log_orbital_seperation_mean(self): + assert isinstance(totest.default_kwargs['log_orbital_seperation_mean'], (type(None), float)), \ + "log_orbital_seperation_mean should be None or a float" + + def test_instance_log_orbital_seperation_sigma(self): + assert isinstance(totest.default_kwargs['log_orbital_seperation_sigma'], (type(None), float)), \ + "log_orbital_seperation_sigma should be None or a float" + + def test_instance_orbital_period_min(self): + assert isinstance(totest.default_kwargs['orbital_period_min'], float), \ + "orbital_period_min should be a float" + + def test_instance_orbital_period_max(self): + assert isinstance(totest.default_kwargs['orbital_period_max'], (float, int)), \ + "orbital_period_max should be a float or int" + + def test_instance_eccentricity_scheme(self): + assert isinstance(totest.default_kwargs['eccentricity_scheme'], str), \ + "eccentricity_scheme should be a string" + + def test_instance_primary_mass_min(self): + assert isinstance(totest.default_kwargs['primary_mass_min'], float), \ + "primary_mass_min should be a float" + + def test_instance_primary_mass_max(self): + assert isinstance(totest.default_kwargs['primary_mass_max'], float), \ + "primary_mass_max should be a float" + + def test_instance_secondary_mass_min(self): + assert isinstance(totest.default_kwargs['secondary_mass_min'], float), \ + "secondary_mass_min should be a float" + + def test_instance_secondary_mass_max(self): + assert isinstance(totest.default_kwargs['secondary_mass_max'], float), \ + "secondary_mass_max should be a float" + + def test_instance_binary_fraction_const(self): + assert isinstance(totest.default_kwargs['binary_fraction_const'], (float, int)), \ + "binary_fraction_const should be a float or int" + + def test_instance_binary_fraction_scheme(self): + assert isinstance(totest.default_kwargs['binary_fraction_scheme'], str), \ + "binary_fraction_scheme should be a string" diff --git a/posydon/unit_tests/popsyn/test_distributions.py b/posydon/unit_tests/popsyn/test_distributions.py index c2f80cbedf..b5b80c42d3 100644 --- a/posydon/unit_tests/popsyn/test_distributions.py +++ b/posydon/unit_tests/popsyn/test_distributions.py @@ -10,9 +10,14 @@ from posydon.popsyn.distributions import ( FlatMassRatio, + LogNormalSeparation, LogUniform, + PowerLawMassRatio, PowerLawPeriod, Sana12Period, + ThermalEccentricity, + UniformEccentricity, + ZeroEccentricity, ) @@ -45,24 +50,20 @@ def test_initialization_custom(self, custom_flat_ratio): def test_initialization_invalid_parameters(self): """Test that initialization raises ValueError for invalid parameters.""" - # Test q_min not in (0, 1] - with pytest.raises(ValueError, match="q_min must be in \\(0, 1\\)"): - FlatMassRatio(q_min=0.0, q_max=0.5) - - with pytest.raises(ValueError, match="q_min must be in \\(0, 1\\)"): + with pytest.raises(ValueError, match="q_min must be in \\[0, 1\\)"): FlatMassRatio(q_min=-0.1, q_max=0.5) - with pytest.raises(ValueError, match="q_min must be in \\(0, 1\\)"): + with pytest.raises(ValueError, match="q_min must be in \\[0, 1\\)"): FlatMassRatio(q_min=1.5, q_max=2.0) # Test q_max not in (0, 1] - with pytest.raises(ValueError, match="q_max must be in \\(0, 1\\)"): + with pytest.raises(ValueError, match="q_max must be in \\(0, 1\\]"): FlatMassRatio(q_min=0.1, q_max=0.0) - with pytest.raises(ValueError, match="q_max must be in \\(0, 1\\)"): + with pytest.raises(ValueError, match="q_max must be in \\(0, 1\\]"): FlatMassRatio(q_min=0.1, q_max=-0.1) - with pytest.raises(ValueError, match="q_max must be in \\(0, 1\\)"): + with pytest.raises(ValueError, match="q_max must be in \\(0, 1\\]"): FlatMassRatio(q_min=0.1, q_max=1.5) # Test q_min >= q_max @@ -114,6 +115,7 @@ def test_pdf_within_range(self, custom_flat_ratio): q_values = np.linspace(custom_flat_ratio.q_min, custom_flat_ratio.q_max, 10) pdf_values = custom_flat_ratio.pdf(q_values) expected_pdf = custom_flat_ratio.norm * np.ones_like(q_values) + expected_pdf[0] = 0.0 # q_values[0] is equal to q_min, which is outside the valid range np.testing.assert_allclose(pdf_values, expected_pdf) def test_pdf_outside_range(self, custom_flat_ratio): @@ -153,6 +155,23 @@ def test_normalization_integral(self, default_flat_ratio): integral, _ = quad(default_flat_ratio.pdf, default_flat_ratio.q_min, default_flat_ratio.q_max) np.testing.assert_allclose(integral, 1.0, rtol=1e-10) + def test_rvs(self, custom_flat_ratio): + """Test random sampling.""" + rng = np.random.default_rng(42) + samples = custom_flat_ratio.rvs(size=1000, rng=rng) + + assert len(samples) == 1000 + assert np.all(samples >= custom_flat_ratio.q_min) + assert np.all(samples <= custom_flat_ratio.q_max) + + def test_rvs_without_rng(self, custom_flat_ratio): + """Test random sampling without providing an RNG.""" + samples = custom_flat_ratio.rvs(size=100) + + assert len(samples) == 100 + assert np.all(samples >= custom_flat_ratio.q_min) + assert np.all(samples <= custom_flat_ratio.q_max) + class TestSana12Period: """Test class for Sana12Period distribution.""" @@ -317,6 +336,64 @@ def test_calculate_normalization_integration(self, default_sana12): norm_high = default_sana12._calculate_normalization(m1_high) assert norm_high > 0 + def test_rvs_with_m1_none(self, default_sana12): + """Test that rvs raises ValueError when m1 is None.""" + rng = np.random.default_rng(42) + + with pytest.raises(ValueError, match="m1 \\(primary mass\\) must be provided"): + default_sana12.rvs(size=10, m1=None, rng=rng) + + def test_rvs_with_m1_wrong_size(self, default_sana12): + """Test that rvs raises ValueError when m1 has wrong size.""" + rng = np.random.default_rng(42) + m1_wrong_size = np.array([10.0, 15.0]) + + with pytest.raises(ValueError, match="m1 must be a single value or have size="): + default_sana12.rvs(size=10, m1=m1_wrong_size, rng=rng) + + def test_rvs_low_mass(self, default_sana12): + """Test random sampling for low mass stars.""" + rng = np.random.default_rng(42) + m1 = 10.0 # Below mbreak + + samples = default_sana12.rvs(size=100, m1=m1, rng=rng) + + assert len(samples) == 100 + assert np.all(samples >= default_sana12.p_min) + assert np.all(samples <= default_sana12.p_max) + + def test_rvs_high_mass(self, default_sana12): + """Test random sampling for high mass stars.""" + rng = np.random.default_rng(42) + m1 = 25.0 # Above mbreak + + samples = default_sana12.rvs(size=100, m1=m1, rng=rng) + + assert len(samples) == 100 + assert np.all(samples >= default_sana12.p_min) + assert np.all(samples <= default_sana12.p_max) + + def test_rvs_mixed_masses(self, default_sana12): + """Test random sampling with array of masses.""" + rng = np.random.default_rng(42) + m1 = np.array([10.0, 15.0, 20.0, 25.0, 30.0]) + + samples = default_sana12.rvs(size=5, m1=m1, rng=rng) + + assert len(samples) == 5 + assert np.all(samples >= default_sana12.p_min) + assert np.all(samples <= default_sana12.p_max) + + def test_rvs_without_rng(self, default_sana12): + """Test random sampling without providing an RNG.""" + m1 = 20.0 + + samples = default_sana12.rvs(size=100, m1=m1) + + assert len(samples) == 100 + assert np.all(samples >= default_sana12.p_min) + assert np.all(samples <= default_sana12.p_max) + class TestPowerLawPeriod: """Test class for PowerLawPeriod distribution.""" @@ -513,6 +590,24 @@ def test_normalization_consistency(self, default_power_law): np.testing.assert_allclose(pdf_values, expected_pdf) + def test_rvs(self, custom_power_law): + """Test random sampling.""" + rng = np.random.default_rng(42) + + samples = custom_power_law.rvs(size=1000, rng=rng) + + assert len(samples) == 1000 + assert np.all(samples >= custom_power_law.p_min) + assert np.all(samples <= custom_power_law.p_max) + + def test_rvs_without_rng(self, custom_power_law): + """Test random sampling without providing an RNG.""" + samples = custom_power_law.rvs(size=100) + + assert len(samples) == 100 + assert np.all(samples >= custom_power_law.p_min) + assert np.all(samples <= custom_power_law.p_max) + class TestDistributionComparisons: """Test class for comparing distributions and edge cases.""" @@ -621,3 +716,616 @@ def test_initialization_invalid_parameters(self): with pytest.raises(ValueError, match="max must be greater than min"): LogUniform(min=100.0, max=100.0) + + def test_repr(self): + """Test string representation.""" + log_uniform = LogUniform(min=10.0, max=1000.0) + rep_str = log_uniform.__repr__() + assert "LogUniform(" in rep_str + assert "min=10.0" in rep_str + assert "max=1000.0" in rep_str + + def test_repr_html(self): + """Test HTML representation.""" + log_uniform = LogUniform(min=10.0, max=1000.0) + html_str = log_uniform._repr_html_() + assert "

Log-Uniform Distribution

" in html_str + assert "min = 10.0" in html_str + assert "max = 1000.0" in html_str + + def test_pdf_within_range(self): + """Test PDF within the valid range.""" + log_uniform = LogUniform(min=10.0, max=1000.0) + x_values = np.array([10.0, 50.0, 100.0, 500.0, 1000.0]) + pdf_values = log_uniform.pdf(x_values) + expected = log_uniform.norm / x_values + np.testing.assert_allclose(pdf_values, expected) + + def test_pdf_outside_range(self): + """Test PDF outside the valid range.""" + log_uniform = LogUniform(min=10.0, max=1000.0) + # Below range + x_below = np.array([1.0, 5.0]) + pdf_below = log_uniform.pdf(x_below) + np.testing.assert_array_equal(pdf_below, np.zeros_like(x_below)) + + # Above range + x_above = np.array([2000.0, 5000.0]) + pdf_above = log_uniform.pdf(x_above) + np.testing.assert_array_equal(pdf_above, np.zeros_like(x_above)) + + def test_rvs(self): + """Test random sampling.""" + log_uniform = LogUniform(min=10.0, max=1000.0) + rng = np.random.default_rng(42) + samples = log_uniform.rvs(size=1000, rng=rng) + + assert len(samples) == 1000 + assert np.all(samples >= log_uniform.min) + assert np.all(samples <= log_uniform.max) + + def test_rvs_without_rng(self): + """Test random sampling without providing an RNG.""" + log_uniform = LogUniform(min=10.0, max=1000.0) + samples = log_uniform.rvs(size=100) + + assert len(samples) == 100 + assert np.all(samples >= log_uniform.min) + assert np.all(samples <= log_uniform.max) + + +class TestThermalEccentricity: + """Test class for ThermalEccentricity distribution.""" + + @pytest.fixture + def default_thermal(self): + """Fixture for default ThermalEccentricity instance.""" + return ThermalEccentricity() + + @pytest.fixture + def custom_thermal(self): + """Fixture for custom ThermalEccentricity instance.""" + return ThermalEccentricity(e_min=0.1, e_max=0.9) + + def test_initialization_default(self, default_thermal): + """Test default initialization.""" + assert default_thermal.e_min == 0.0 + assert default_thermal.e_max == 1.0 + assert hasattr(default_thermal, 'norm') + assert default_thermal.norm > 0 + + def test_initialization_custom(self, custom_thermal): + """Test custom initialization.""" + assert custom_thermal.e_min == 0.1 + assert custom_thermal.e_max == 0.9 + assert hasattr(custom_thermal, 'norm') + assert custom_thermal.norm > 0 + + def test_initialization_invalid_parameters(self): + """Test that initialization raises ValueError for invalid parameters.""" + # Test e_min not in [0, 1) + with pytest.raises(ValueError, match="e_min must be in \\[0, 1\\)"): + ThermalEccentricity(e_min=-0.1, e_max=0.5) + + with pytest.raises(ValueError, match="e_min must be in \\[0, 1\\)"): + ThermalEccentricity(e_min=1.5, e_max=2.0) + + # Test e_max not in (0, 1] + with pytest.raises(ValueError, match="e_max must be in \\(0, 1\\]"): + ThermalEccentricity(e_min=0.1, e_max=0.0) + + with pytest.raises(ValueError, match="e_max must be in \\(0, 1\\]"): + ThermalEccentricity(e_min=0.1, e_max=-0.1) + + with pytest.raises(ValueError, match="e_max must be in \\(0, 1\\]"): + ThermalEccentricity(e_min=0.1, e_max=1.5) + + # Test e_min >= e_max + with pytest.raises(ValueError, match="e_min must be less than e_max"): + ThermalEccentricity(e_min=0.8, e_max=0.5) + + with pytest.raises(ValueError, match="e_min must be less than e_max"): + ThermalEccentricity(e_min=0.5, e_max=0.5) + + def test_repr(self, custom_thermal): + """Test string representation.""" + rep_str = custom_thermal.__repr__() + assert "ThermalEccentricity(" in rep_str + assert "e_min=0.1" in rep_str + assert "e_max=0.9" in rep_str + + def test_repr_html(self, custom_thermal): + """Test HTML representation.""" + html_str = custom_thermal._repr_html_() + assert "

Thermal Eccentricity Distribution

" in html_str + assert "e_min = 0.1" in html_str + assert "e_max = 0.9" in html_str + + def test_thermal_eccentricity_method(self, default_thermal): + """Test the thermal_eccentricity method.""" + e_values = np.array([0.0, 0.25, 0.5, 0.75, 1.0]) + result = default_thermal.thermal_eccentricity(e_values) + expected = 2.0 * e_values + np.testing.assert_allclose(result, expected) + + def test_pdf_within_range(self, custom_thermal): + """Test PDF within the valid range.""" + e_values = np.linspace(custom_thermal.e_min, custom_thermal.e_max, 10) + pdf_values = custom_thermal.pdf(e_values) + + # All should be positive within range + assert np.all(pdf_values > 0) + + # Check normalization + expected = custom_thermal.thermal_eccentricity(e_values) * custom_thermal.norm + np.testing.assert_allclose(pdf_values, expected) + + def test_pdf_outside_range(self, custom_thermal): + """Test PDF outside the valid range.""" + # Below range + e_below = np.array([0.0, 0.05]) + pdf_below = custom_thermal.pdf(e_below) + np.testing.assert_array_equal(pdf_below, np.zeros_like(e_below)) + + # Above range + e_above = np.array([0.95, 1.0]) + pdf_above = custom_thermal.pdf(e_above) + np.testing.assert_array_equal(pdf_above, np.zeros_like(e_above)) + + def test_pdf_scalar_input(self, default_thermal): + """Test PDF with scalar input.""" + e = 0.5 + pdf_value = default_thermal.pdf(e) + expected = default_thermal.thermal_eccentricity(e) * default_thermal.norm + np.testing.assert_allclose(pdf_value, expected) + + def test_rvs(self, custom_thermal): + """Test random sampling.""" + rng = np.random.default_rng(42) + samples = custom_thermal.rvs(size=1000, rng=rng) + + assert len(samples) == 1000 + assert np.all(samples >= custom_thermal.e_min) + assert np.all(samples <= custom_thermal.e_max) + + def test_rvs_without_rng(self, custom_thermal): + """Test random sampling without providing an RNG.""" + samples = custom_thermal.rvs(size=100) + + assert len(samples) == 100 + assert np.all(samples >= custom_thermal.e_min) + assert np.all(samples <= custom_thermal.e_max) + + +class TestUniformEccentricity: + """Test class for UniformEccentricity distribution.""" + + @pytest.fixture + def default_uniform(self): + """Fixture for default UniformEccentricity instance.""" + return UniformEccentricity() + + @pytest.fixture + def custom_uniform(self): + """Fixture for custom UniformEccentricity instance.""" + return UniformEccentricity(e_min=0.1, e_max=0.9) + + def test_initialization_default(self, default_uniform): + """Test default initialization.""" + assert default_uniform.e_min == 0.0 + assert default_uniform.e_max == 1.0 + assert hasattr(default_uniform, 'norm') + assert default_uniform.norm == 1.0 + + def test_initialization_custom(self, custom_uniform): + """Test custom initialization.""" + assert custom_uniform.e_min == 0.1 + assert custom_uniform.e_max == 0.9 + assert hasattr(custom_uniform, 'norm') + expected_norm = 1.0 / (0.9 - 0.1) + np.testing.assert_allclose(custom_uniform.norm, expected_norm) + + def test_initialization_invalid_parameters(self): + """Test that initialization raises ValueError for invalid parameters.""" + # Test e_min not in [0, 1) + with pytest.raises(ValueError, match="e_min must be in \\[0, 1\\)"): + UniformEccentricity(e_min=-0.1, e_max=0.5) + + with pytest.raises(ValueError, match="e_min must be in \\[0, 1\\)"): + UniformEccentricity(e_min=1.5, e_max=2.0) + + # Test e_max not in (0, 1] + with pytest.raises(ValueError, match="e_max must be in \\(0, 1\\]"): + UniformEccentricity(e_min=0.1, e_max=0.0) + + with pytest.raises(ValueError, match="e_max must be in \\(0, 1\\]"): + UniformEccentricity(e_min=0.1, e_max=-0.1) + + with pytest.raises(ValueError, match="e_max must be in \\(0, 1\\]"): + UniformEccentricity(e_min=0.1, e_max=1.5) + + # Test e_min >= e_max + with pytest.raises(ValueError, match="e_min must be less than e_max"): + UniformEccentricity(e_min=0.8, e_max=0.5) + + with pytest.raises(ValueError, match="e_min must be less than e_max"): + UniformEccentricity(e_min=0.5, e_max=0.5) + + def test_repr(self, custom_uniform): + """Test string representation.""" + rep_str = custom_uniform.__repr__() + assert "UniformEccentricity(" in rep_str + assert "e_min=0.1" in rep_str + assert "e_max=0.9" in rep_str + + def test_repr_html(self, custom_uniform): + """Test HTML representation.""" + html_str = custom_uniform._repr_html_() + assert "

Uniform Eccentricity Distribution

" in html_str + assert "e_min = 0.1" in html_str + assert "e_max = 0.9" in html_str + + def test_pdf_within_range(self, custom_uniform): + """Test PDF within the valid range.""" + e_values = np.linspace(custom_uniform.e_min, custom_uniform.e_max, 10) + pdf_values = custom_uniform.pdf(e_values) + + # All should equal the normalization constant + expected = custom_uniform.norm * np.ones_like(e_values) + np.testing.assert_allclose(pdf_values, expected) + + def test_pdf_outside_range(self, custom_uniform): + """Test PDF outside the valid range.""" + # Below range + e_below = np.array([0.0, 0.05]) + pdf_below = custom_uniform.pdf(e_below) + np.testing.assert_array_equal(pdf_below, np.zeros_like(e_below)) + + # Above range + e_above = np.array([0.95, 1.0]) + pdf_above = custom_uniform.pdf(e_above) + np.testing.assert_array_equal(pdf_above, np.zeros_like(e_above)) + + def test_pdf_scalar_input(self, default_uniform): + """Test PDF with scalar input.""" + e = 0.5 + pdf_value = default_uniform.pdf(e) + np.testing.assert_allclose(pdf_value, default_uniform.norm) + + def test_rvs(self, custom_uniform): + """Test random sampling.""" + rng = np.random.default_rng(42) + samples = custom_uniform.rvs(size=1000, rng=rng) + + assert len(samples) == 1000 + assert np.all(samples >= custom_uniform.e_min) + assert np.all(samples <= custom_uniform.e_max) + + def test_rvs_without_rng(self, custom_uniform): + """Test random sampling without providing an RNG.""" + samples = custom_uniform.rvs(size=100) + + assert len(samples) == 100 + assert np.all(samples >= custom_uniform.e_min) + assert np.all(samples <= custom_uniform.e_max) + + +class TestZeroEccentricity: + """Test class for ZeroEccentricity distribution.""" + + @pytest.fixture + def zero_ecc(self): + """Fixture for ZeroEccentricity instance.""" + return ZeroEccentricity() + + def test_initialization(self, zero_ecc): + """Test initialization.""" + # Should have no parameters + assert isinstance(zero_ecc, ZeroEccentricity) + + def test_repr(self, zero_ecc): + """Test string representation.""" + rep_str = zero_ecc.__repr__() + assert "ZeroEccentricity()" in rep_str + + def test_repr_html(self, zero_ecc): + """Test HTML representation.""" + html_str = zero_ecc._repr_html_() + assert "

Zero Eccentricity Distribution

" in html_str + assert "e = 0 (circular orbits)" in html_str + + def test_pdf_at_zero(self, zero_ecc): + """Test PDF at e=0.""" + pdf_value = zero_ecc.pdf(0.0) + assert pdf_value == 1.0 + + def test_pdf_away_from_zero(self, zero_ecc): + """Test PDF for non-zero eccentricities.""" + e_values = np.array([0.1, 0.5, 0.9, 1.0]) + pdf_values = zero_ecc.pdf(e_values) + np.testing.assert_array_equal(pdf_values, np.zeros_like(e_values)) + + def test_pdf_mixed(self, zero_ecc): + """Test PDF with mixture of zero and non-zero values.""" + e_values = np.array([0.0, 0.1, 0.0, 0.5]) + pdf_values = zero_ecc.pdf(e_values) + expected = np.array([1.0, 0.0, 1.0, 0.0]) + np.testing.assert_array_equal(pdf_values, expected) + + def test_rvs(self, zero_ecc): + """Test random sampling.""" + rng = np.random.default_rng(42) + samples = zero_ecc.rvs(size=1000, rng=rng) + + # All samples should be zero + assert len(samples) == 1000 + np.testing.assert_array_equal(samples, np.zeros(1000)) + + def test_rvs_without_rng(self, zero_ecc): + """Test random sampling without providing an RNG.""" + samples = zero_ecc.rvs(size=100) + + # All samples should be zero + assert len(samples) == 100 + np.testing.assert_array_equal(samples, np.zeros(100)) + + +class TestLogNormalSeparation: + """Test class for LogNormalSeparation distribution.""" + + @pytest.fixture + def default_lognormal(self): + """Fixture for default LogNormalSeparation instance.""" + return LogNormalSeparation() + + @pytest.fixture + def custom_lognormal(self): + """Fixture for custom LogNormalSeparation instance.""" + return LogNormalSeparation(mean=1.0, sigma=0.5, min=10.0, max=1e4) + + def test_initialization_default(self, default_lognormal): + """Test default initialization.""" + assert default_lognormal.mean == 0.85 + assert default_lognormal.sigma == 0.37 + assert default_lognormal.min == 5.0 + assert default_lognormal.max == 1e5 + + def test_initialization_custom(self, custom_lognormal): + """Test custom initialization.""" + assert custom_lognormal.mean == 1.0 + assert custom_lognormal.sigma == 0.5 + assert custom_lognormal.min == 10.0 + assert custom_lognormal.max == 1e4 + + def test_initialization_invalid_parameters(self): + """Test that initialization raises ValueError for invalid parameters.""" + # Test min <= 0 + with pytest.raises(ValueError, match="min must be positive"): + LogNormalSeparation(mean=1.0, sigma=0.5, min=0.0, max=1000.0) + + with pytest.raises(ValueError, match="min must be positive"): + LogNormalSeparation(mean=1.0, sigma=0.5, min=-1.0, max=1000.0) + + # Test max <= min + with pytest.raises(ValueError, match="max must be greater than min"): + LogNormalSeparation(mean=1.0, sigma=0.5, min=1000.0, max=100.0) + + with pytest.raises(ValueError, match="max must be greater than min"): + LogNormalSeparation(mean=1.0, sigma=0.5, min=100.0, max=100.0) + + # Test sigma <= 0 + with pytest.raises(ValueError, match="sigma must be positive"): + LogNormalSeparation(mean=1.0, sigma=0.0, min=10.0, max=1000.0) + + with pytest.raises(ValueError, match="sigma must be positive"): + LogNormalSeparation(mean=1.0, sigma=-0.5, min=10.0, max=1000.0) + + def test_repr(self, custom_lognormal): + """Test string representation.""" + rep_str = custom_lognormal.__repr__() + assert "LogNormalSeparation(" in rep_str + assert "mean=1.0" in rep_str + assert "sigma=0.5" in rep_str + assert "min=10.0" in rep_str + assert "max=10000.0" in rep_str + + def test_repr_html(self, custom_lognormal): + """Test HTML representation.""" + html_str = custom_lognormal._repr_html_() + assert "

Log-Normal Separation Distribution

" in html_str + assert "mean (log10) = 1.0" in html_str + assert "sigma (log10) = 0.5" in html_str + + def test_pdf_within_range(self, custom_lognormal): + """Test PDF within the valid range.""" + a_values = np.array([10.0, 50.0, 100.0, 500.0, 1000.0]) + pdf_values = custom_lognormal.pdf(a_values) + + # All should be positive within range + assert np.all(pdf_values > 0) + + def test_pdf_outside_range(self, custom_lognormal): + """Test PDF outside the valid range.""" + # Below range + a_below = np.array([1.0, 5.0]) + pdf_below = custom_lognormal.pdf(a_below) + np.testing.assert_array_equal(pdf_below, np.zeros_like(a_below)) + + # Above range + a_above = np.array([2e4, 5e4]) + pdf_above = custom_lognormal.pdf(a_above) + np.testing.assert_array_equal(pdf_above, np.zeros_like(a_above)) + + def test_pdf_zero_and_negative(self, custom_lognormal): + """Test PDF for zero and negative values.""" + a_invalid = np.array([0.0, -10.0]) + pdf_invalid = custom_lognormal.pdf(a_invalid) + np.testing.assert_array_equal(pdf_invalid, np.zeros_like(a_invalid)) + + def test_rvs(self, custom_lognormal): + """Test random sampling.""" + rng = np.random.default_rng(42) + samples = custom_lognormal.rvs(size=1000, rng=rng) + + assert len(samples) == 1000 + assert np.all(samples >= custom_lognormal.min) + assert np.all(samples <= custom_lognormal.max) + + def test_rvs_without_rng(self, custom_lognormal): + """Test random sampling without providing an RNG.""" + samples = custom_lognormal.rvs(size=100) + + assert len(samples) == 100 + assert np.all(samples >= custom_lognormal.min) + assert np.all(samples <= custom_lognormal.max) + + + with pytest.raises(ValueError, match="max must be greater than min"): + LogUniform(min=100.0, max=100.0) + + +class TestPowerLawMassRatio: + """Test class for PowerLawMassRatio distribution.""" + + @pytest.fixture + def default_power_law_mass_ratio(self): + """Fixture for default PowerLawMassRatio instance.""" + return PowerLawMassRatio() + + @pytest.fixture + def custom_power_law_mass_ratio(self): + """Fixture for custom PowerLawMassRatio instance.""" + return PowerLawMassRatio(alpha=-1.0, q_min=0.1, q_max=0.9) + + def test_initialization_default(self, default_power_law_mass_ratio): + """Test default initialization of PowerLawMassRatio.""" + assert default_power_law_mass_ratio.alpha == 0.0 + assert default_power_law_mass_ratio.q_min == 0.05 + assert default_power_law_mass_ratio.q_max == 1.0 + assert hasattr(default_power_law_mass_ratio, 'norm') + assert default_power_law_mass_ratio.norm > 0 + + def test_initialization_custom(self, custom_power_law_mass_ratio): + """Test custom initialization of PowerLawMassRatio.""" + assert custom_power_law_mass_ratio.alpha == -1.0 + assert custom_power_law_mass_ratio.q_min == 0.1 + assert custom_power_law_mass_ratio.q_max == 0.9 + assert hasattr(custom_power_law_mass_ratio, 'norm') + assert custom_power_law_mass_ratio.norm > 0 + + def test_initialization_invalid_parameters(self): + """Test that initialization raises ValueError for invalid parameters.""" + with pytest.raises(ValueError, match="q_min must be in \\[0, 1\\)"): + PowerLawMassRatio(q_min=-0.1) + + with pytest.raises(ValueError, match="q_min must be in \\[0, 1\\)"): + PowerLawMassRatio(q_min=1.0) + + with pytest.raises(ValueError, match="q_max must be in \\(0, 1\\]"): + PowerLawMassRatio(q_max=0.0) + + with pytest.raises(ValueError, match="q_max must be in \\(0, 1\\]"): + PowerLawMassRatio(q_max=1.5) + + with pytest.raises(ValueError, match="q_min must be less than q_max"): + PowerLawMassRatio(q_min=0.8, q_max=0.5) + + with pytest.raises(ValueError, match="q_min must be > 0 for alpha <= -1"): + PowerLawMassRatio(alpha=-1.0, q_min=0.0) + + def test_repr(self, custom_power_law_mass_ratio): + """Test string representation of the distribution.""" + repr_str = custom_power_law_mass_ratio.__repr__() + assert "PowerLawMassRatio(" in repr_str + assert "alpha=-1.0" in repr_str + assert "q_min=0.1" in repr_str + assert "q_max=0.9" in repr_str + + def test_repr_html(self, default_power_law_mass_ratio): + """Test HTML representation for Jupyter notebooks.""" + html_str = default_power_law_mass_ratio._repr_html_() + assert "

Power Law Mass Ratio Distribution

" in html_str + assert "alpha = 0.0" in html_str + assert "q_min = 0.05" in html_str + assert "q_max = 1.0" in html_str + + def test_calculate_normalization(self, custom_power_law_mass_ratio): + """Test that normalization constant is the reciprocal of the integral.""" + integral, _ = quad( + custom_power_law_mass_ratio.power_law_mass_ratio, + custom_power_law_mass_ratio.q_min, + custom_power_law_mass_ratio.q_max, + ) + expected_norm = 1.0 / integral + np.testing.assert_allclose(custom_power_law_mass_ratio.norm, expected_norm) + + def test_power_law_mass_ratio_method(self, custom_power_law_mass_ratio): + """Test the power_law_mass_ratio method returns q^alpha.""" + q = np.array([0.2, 0.5, 0.8]) + result = custom_power_law_mass_ratio.power_law_mass_ratio(q) + expected = q ** custom_power_law_mass_ratio.alpha + np.testing.assert_allclose(result, expected) + + def test_pdf_within_range(self, custom_power_law_mass_ratio): + """Test PDF returns correct values within the mass ratio range.""" + q_values = np.linspace( + custom_power_law_mass_ratio.q_min + 1e-6, + custom_power_law_mass_ratio.q_max, + 10, + ) + pdf_values = custom_power_law_mass_ratio.pdf(q_values) + expected = ( + custom_power_law_mass_ratio.power_law_mass_ratio(q_values) + * custom_power_law_mass_ratio.norm + ) + np.testing.assert_allclose(pdf_values, expected) + + def test_pdf_outside_range(self, custom_power_law_mass_ratio): + """Test PDF returns zero outside the mass ratio range.""" + q_below = np.array([0.05, 0.09]) + np.testing.assert_array_equal( + custom_power_law_mass_ratio.pdf(q_below), np.zeros_like(q_below) + ) + + q_above = np.array([0.95, 1.0]) + np.testing.assert_array_equal( + custom_power_law_mass_ratio.pdf(q_above), np.zeros_like(q_above) + ) + + def test_pdf_at_q_min_excluded(self, custom_power_law_mass_ratio): + """Test that q_min itself is excluded (open lower bound).""" + pdf_at_qmin = custom_power_law_mass_ratio.pdf( + np.array([custom_power_law_mass_ratio.q_min]) + ) + assert pdf_at_qmin[0] == 0.0 + + def test_normalization_integral(self, custom_power_law_mass_ratio): + """Test that the PDF integrates to 1 over the valid range.""" + integral, _ = quad( + custom_power_law_mass_ratio.pdf, + custom_power_law_mass_ratio.q_min, + custom_power_law_mass_ratio.q_max, + ) + np.testing.assert_allclose(integral, 1.0, rtol=1e-6) + + def test_rvs(self, custom_power_law_mass_ratio): + """Test random sampling stays within bounds.""" + rng = np.random.default_rng(42) + samples = custom_power_law_mass_ratio.rvs(size=1000, rng=rng) + assert len(samples) == 1000 + assert np.all(samples >= custom_power_law_mass_ratio.q_min) + assert np.all(samples <= custom_power_law_mass_ratio.q_max) + + def test_rvs_without_rng(self, custom_power_law_mass_ratio): + """Test random sampling without providing an RNG.""" + samples = custom_power_law_mass_ratio.rvs(size=100) + assert len(samples) == 100 + assert np.all(samples >= custom_power_law_mass_ratio.q_min) + assert np.all(samples <= custom_power_law_mass_ratio.q_max) + + @pytest.mark.parametrize("alpha", [2.0, 0.0, -0.5]) + def test_different_alpha_values(self, alpha): + """Test PowerLawMassRatio with different valid alpha exponents.""" + dist = PowerLawMassRatio(alpha=alpha, q_min=0.1, q_max=1.0) + integral, _ = quad(dist.pdf, dist.q_min, dist.q_max) + np.testing.assert_allclose(integral, 1.0, rtol=1e-6) diff --git a/posydon/unit_tests/popsyn/test_independent_sample.py b/posydon/unit_tests/popsyn/test_independent_sample.py new file mode 100644 index 0000000000..cac8c2852f --- /dev/null +++ b/posydon/unit_tests/popsyn/test_independent_sample.py @@ -0,0 +1,258 @@ +"""Unit tests of posydon/popsyn/independent_sample.py +""" + +__authors__ = [ + "Elizabeth Teng " +] + +# import the module which will be tested +import posydon.popsyn.independent_sample as totest + +# aliases +np = totest.np + +# import other needed code for the tests, which is not already imported in the +# module you like to test +import re + +from pytest import approx, raises + + +# define test classes collecting several test functions +class TestElements: + + # check for objects, which should be an element of the tested module + def test_dir(self): + elements = ['generate_independent_samples', 'use_Moe_17_PsandQs', \ + '_gen_Moe_17_PsandQs','generate_orbital_periods', \ + 'generate_orbital_separations', 'generate_eccentricities',\ + 'generate_primary_masses','generate_secondary_masses',\ + 'generate_binary_fraction','__authors__',\ + 'np','truncnorm','rejection_sampler',\ + 'IMFs','distributions','Moe_17_PsandQs',\ + '__builtins__', '__cached__', '__doc__', '__file__',\ + '__loader__', '__name__', '__package__', '__spec__'] + totest_elements = set(dir(totest)) + missing_in_test = set(elements) - totest_elements + assert len(missing_in_test) == 0, "There are missing objects in "\ + +f"{totest.__name__}: "\ + +f"{missing_in_test}. Please "\ + +"check, whether they have been "\ + +"removed on purpose and update "\ + +"this unit test." + new_in_test = totest_elements - set(elements) + assert len(new_in_test) == 0, "There are new objects in "\ + +f"{totest.__name__}: {new_in_test}. "\ + +"Please check, whether they have been "\ + +"added on purpose and update this "\ + +"unit test." + +class TestFunctions: + + # test functions + def test_generate_independent_samples(self): + # bad input + with raises(ValueError, match="Allowed orbital schemes are separation or period."): + totest.generate_independent_samples('test') + + # separation scheme + orb, ecc, m1, m2 = totest.generate_independent_samples( + orbital_scheme='separation', + RNG=np.random.default_rng(seed=42)) + assert orb[0] == approx(24650.481799781122,abs=6e-12) + assert ecc[0] == approx(0.8350856417514098,abs=6e-12) + assert m1[0] == approx(19.97764511120556,abs=6e-12) + assert m2[0] == approx(8.964150262412895,abs=6e-12) + assert isinstance(orb, np.ndarray) + assert len(orb) == 1 + assert all(np.isfinite(m1)) + + # period scheme (default) + orb_p, ecc_p, m1_p, m2_p = totest.generate_independent_samples( + orbital_scheme='period', + RNG=np.random.default_rng(seed=42)) + assert orb_p[0] == approx(872.213878458193,abs=6e-12) + assert ecc_p[0] == approx(0.7259611833901314,abs=6e-12) + assert m1_p[0] == approx(19.97764511120556,abs=6e-12) + assert m2_p[0] == approx(8.964150262412895,abs=6e-12) + assert isinstance(orb_p, np.ndarray) + assert len(orb_p) == 1 + assert all(np.isfinite(m1_p)) + + def test_use_Moe_17_PsandQs(self): + + # returns True for Moe+17-PsandQs secondary_mass_scheme + assert totest.use_Moe_17_PsandQs(secondary_mass_scheme='Moe+17-PsandQs') is True + + # returns True for Moe+17-PsandQs orbital_period_scheme with period scheme + assert totest.use_Moe_17_PsandQs( + orbital_scheme='period', + orbital_period_scheme='Moe+17-PsandQs') is True + + # returns True for Moe+17-PsandQs eccentricity_scheme + assert totest.use_Moe_17_PsandQs(eccentricity_scheme='Moe+17-PsandQs') is True + # returns False for non-Moe schemes + assert totest.use_Moe_17_PsandQs( + secondary_mass_scheme='flat_mass_ratio', + orbital_scheme='period', + orbital_period_scheme='Sana+12_period_extended', + eccentricity_scheme='zero') is False + + def test_generate_orbital_periods(self): + # missing argument + with raises(TypeError, match="missing 1 required positional argument: 'primary_masses'"): + totest.generate_orbital_periods() + + # bad input + with raises(ValueError, match="p_max must be greater than p_min"): + totest.generate_orbital_periods(np.array([1.]), + orbital_period_min=10., + orbital_period_max=1.) + with raises(ValueError, match="You must provide an allowed orbital period scheme."): + totest.generate_orbital_periods(np.array([1.]), + orbital_period_scheme='test') + # examples + tests = [(1.0,42,approx(403.44608837021764,abs=6e-12)), + (1.0,12,approx(3.4380527315000666,abs=6e-12))] + for (m,r,p) in tests: + assert totest.generate_orbital_periods(m,RNG = np.random.default_rng(seed=r))[0] == p + + def test_generate_orbital_separations(self): + # missing log_normal params + with raises(ValueError, match="For the `log_normal separation` scheme you must give"): + totest.generate_orbital_separations(orbital_separation_scheme='log_normal') + + # bad input: min > max (raised by LogUniform distribution class) + with raises(ValueError, match="max must be greater than min"): + totest.generate_orbital_separations(orbital_separation_min=10., + orbital_separation_max=1.) + + # bad input: min > max with log_normal + with raises(ValueError, match="`orbital_separation_max` must be"): + totest.generate_orbital_separations( + orbital_separation_scheme='log_normal', + log_orbital_separation_mean=1.0, + log_orbital_separation_sigma=1.0, + orbital_separation_min=10., + orbital_separation_max=1.) + + # bad scheme + with raises(ValueError, match="You must provide an allowed orbital separation scheme."): + totest.generate_orbital_separations(orbital_separation_scheme='test') + + # log_normal examples + tests_normal = [(0., 1.0, 42, approx(39.83711402835139, abs=6e-12)), + (1.0, 10., 42, approx(9799.179319004, abs=6e-9))] + for (m, s, r, sep) in tests_normal: + assert totest.generate_orbital_separations( + orbital_separation_scheme='log_normal', + log_orbital_separation_mean=m, + log_orbital_separation_sigma=s, + RNG=np.random.default_rng(seed=r))[0] == sep + + # log_uniform examples + tests_uniform = [(1., 3., 42, approx(2.3402964885050066, abs=6e-12)), + (2., 10., 42, approx(6.950276115688688, abs=6e-12))] + for (mi, ma, r, sep) in tests_uniform: + assert totest.generate_orbital_separations( + orbital_separation_min=mi, + orbital_separation_max=ma, + RNG=np.random.default_rng(seed=r))[0] == sep + def test_generate_eccentricities(self): + # bad input + with raises(TypeError, match="expected a sequence of integers or a single integer"): + totest.generate_eccentricities(number_of_binaries=1.) + with raises(ValueError, match="You must provide an allowed eccentricity scheme."): + totest.generate_eccentricities(eccentricity_scheme='test') + # examples + tests = [('thermal',42,approx(0.8797477186989253,abs=6e-12)), + ('uniform',42,approx(0.7739560485559633,abs=6e-12)), + ('zero',42,approx(0.,abs=6e-12))] + for (s,r,e) in tests: + assert totest.generate_eccentricities(eccentricity_scheme=s, + RNG = np.random.default_rng(seed=r))[0] == e + + def test_generate_primary_masses(self): + # bad input: invalid scheme + with raises(ValueError, match="You must provide an allowed primary mass scheme."): + totest.generate_primary_masses(primary_mass_scheme='test') + + # bad input: min > max (raised by IMF class) + with raises(ValueError, match="m_min must be less than m_max"): + totest.generate_primary_masses(primary_mass_min=100., primary_mass_max=10.) + + # examples for all three schemes + tests = [('Salpeter', 42, approx(19.97764511120556, abs=6e-12)), + ('Kroupa1993', 42, approx(16.52331793661949, abs=6e-12)), + ('Kroupa2001', 42, approx(20.633204764212334, abs=6e-12))] + for (s, r, m1) in tests: + assert totest.generate_primary_masses( + primary_mass_scheme=s, + RNG=np.random.default_rng(seed=r))[0] == m1 + + def test_generate_secondary_masses(self): + # missing argument + with raises(TypeError, match="missing 1 required positional argument: 'primary_masses'"): + totest.generate_secondary_masses() + + # bad input: invalid scheme + with raises(ValueError, match="You must provide an allowed secondary mass scheme."): + totest.generate_secondary_masses(primary_masses=np.array([10.]), + secondary_mass_scheme='test') + + # bad input: secondary_mass_min > primary mass + with raises(ValueError, match="`secondary_mass_min` is larger than some primary masses"): + totest.generate_secondary_masses(primary_masses=np.array([1.]), + secondary_mass_min=10., + secondary_mass_max=100.) + + # flat_mass_ratio example + result = totest.generate_secondary_masses( + primary_masses=np.array([10.]), + secondary_mass_scheme='flat_mass_ratio', + RNG=np.random.default_rng(seed=42)) + assert len(result) == 1 + assert result[0] > 0 + assert result[0] <= 10.0 + + # q=1 example + result_q1 = totest.generate_secondary_masses( + primary_masses=np.array([10.]), + secondary_mass_scheme='q=1', + RNG=np.random.default_rng(seed=42)) + assert result_q1[0] == approx(10., abs=6e-12) + + def test_generate_binary_fraction(self): + # missing primary mass + with raises(ValueError, match="There was not a primary mass provided in the inputs"): + totest.generate_binary_fraction(binary_fraction_scheme='const') + + # bad scheme (m1 must be provided before scheme check) + with raises(ValueError, match="You must provide an allowed binary fraction scheme."): + totest.generate_binary_fraction(binary_fraction_scheme='test', + m1=np.array([10.])) + + # const scheme examples + tests_const = [1.0, 1, 0.5] + for c in tests_const: + assert totest.generate_binary_fraction( + binary_fraction_const=c, + binary_fraction_scheme='const', + m1=np.array([10.])) == c + + # non-array m1 input (triggers np.asarray conversion) + assert totest.generate_binary_fraction( + binary_fraction_const=0.7, + binary_fraction_scheme='const', + m1=10.) == 0.7 + + # Moe+17-massdependent scheme + tests_moe = [(np.array([1.]), 0.4), + (np.array([3.]), 0.59), + (np.array([8.]), 0.76), + (np.array([10.]), 0.84), + (np.array([18.]), 0.94)] + for (m1, f) in tests_moe: + result = totest.generate_binary_fraction( + binary_fraction_scheme='Moe+17-massdependent', m1=m1) + assert result[0] == f diff --git a/posydon/unit_tests/popsyn/test_io.py b/posydon/unit_tests/popsyn/test_io.py new file mode 100644 index 0000000000..415a012ae6 --- /dev/null +++ b/posydon/unit_tests/popsyn/test_io.py @@ -0,0 +1,397 @@ +"""Unit tests of posydon/popsyn/io.py +""" + +__authors__ = [ + "Elizabeth Teng " +] + +# import the module which will be tested +import posydon.popsyn.io as totest + +# aliases +np = totest.np +pd = totest.pd + +import ast +import errno +import importlib +import os +import pprint +import textwrap +from configparser import ConfigParser, MissingSectionHeaderError +from textwrap import dedent + +# import other needed code for the tests, which is not already imported in the +# module you like to test +from pytest import approx, fixture, raises, warns + + +# define test classes collecting several test functions +class TestElements: + # check for objects, which should be an element of the tested module + def test_dir(self): + elements = ['BINARYPROPERTIES_DTYPES', 'OBJECT_FIXED_SUB_DTYPES', + 'STARPROPERTIES_DTYPES', 'EXTRA_BINARY_COLUMNS_DTYPES', + 'EXTRA_STAR_COLUMNS_DTYPES', 'SCALAR_NAMES_DTYPES', + 'clean_binary_history_df', 'clean_binary_oneline_df', + 'parse_inifile', 'simprop_kwargs_from_ini', + 'binarypop_kwargs_from_ini', + '__builtins__', '__cached__', '__doc__', '__file__', + '__loader__', '__name__', '__package__', '__spec__', + 'ConfigParser', 'ast', 'importlib', 'os', 'errno', + 'pprint', 'np', 'pd',] + totest_elements = set(dir(totest)) + missing_in_test = set(elements) - totest_elements + assert len(missing_in_test) == 0, "There are missing objects in "\ + +f"{totest.__name__}: "\ + +f"{missing_in_test}. Please "\ + +"check, whether they have been "\ + +"removed on purpose and update "\ + +"this unit test." + new_in_test = totest_elements - set(elements) + assert len(new_in_test) == 0, "There are new objects in "\ + +f"{totest.__name__}: {new_in_test}. "\ + +"Please check, whether they have been "\ + +"added on purpose and update this "\ + +"unit test." + +class TestFunctions: + + @fixture + def simple_ini(self,tmp_path): + file_path = os.path.join(tmp_path, "test.ini") + with open(file_path, "w") as f: + f.write("[section]\nkey=value\n") + return file_path + + @fixture + def multi_ini(self,tmp_path): + file1 = os.path.join(tmp_path, "a.ini") + file2 = os.path.join(tmp_path, "b.ini") + with open(file1, "w") as f: + f.write("[section]\nkey1=value1\n") + with open(file2, "w") as f: + f.write("[section]\nkey2=value2\n") + return [file1, file2] + + @fixture + def textfile(self,tmp_path): + file_path = os.path.join(tmp_path, "textfile.txt") + with open(file_path, "w") as f: + f.write("test") + return file_path + + @fixture + def sim_ini(self,tmp_path): + ini_content = dedent( + """ + [flow] + import = ['posydon.binary_evol.flow_chart', 'flow_chart'] + absolute_import = None + + [step_HMS_HMS] + import = ['posydon.binary_evol.MESA.step_mesa', 'MS_MS_step'] + absolute_import = None + interpolation_method = 'linear3c_kNN' + save_initial_conditions = True + verbose = False + + [extra_hooks] + import_1 = ['posydon.binary_evol.simulationproperties', 'TimingHooks'] + absolute_import_1 = None + kwargs_1 = {} + import_2 = ['posydon.binary_evol.simulationproperties', 'StepNamesHooks'] + absolute_import_2 = None + kwargs_2 = {} + """) + file_path = os.path.join(tmp_path, "sim.ini") + with open(file_path, "w") as f: + f.write(ini_content) + return file_path + + @fixture + def grid_paths_ini(self, tmp_path): + ini_content = dedent( + """ + [grid_paths] + HMS_HMS = '/path/to/grid' + """) + file_path = os.path.join(tmp_path, "grid_paths.ini") + with open(file_path, "w") as f: + f.write(ini_content) + return file_path + + @fixture + def binpop_ini(self, tmp_path): + ini_content = dedent( + """ + [BinaryPopulation_options] + use_MPI = False + metallicity = [0.02] + number_of_binaries = 1 + temp_directory = 'tmp' + + [BinaryStar_output] + extra_columns = {} + only_select_columns = [] + scalar_names = [] + + [SingleStar_1_output] + include_S1 = False + + [SingleStar_2_output] + include_S2 = False + """) + file_path = os.path.join(tmp_path, "binpop.ini") + with open(file_path, "w") as f: + f.write(ini_content) + return file_path + + @fixture + def binpop_ini_mpi(self, tmp_path): + ini_content = dedent( + """ + [BinaryPopulation_options] + use_MPI = True + metallicity = [0.02] + number_of_binaries = 1 + temp_directory = 'tmp' + + [BinaryStar_output] + extra_columns = {} + only_select_columns = [] + scalar_names = [] + + [SingleStar_1_output] + include_S1 = False + + [SingleStar_2_output] + include_S2 = False + """) + file_path = os.path.join(tmp_path, "binpop_mpi.ini") + with open(file_path, "w") as f: + f.write(ini_content) + return file_path + + @fixture + def binpop_ini_stars(self, tmp_path): + ini_content = dedent( + """ + [BinaryPopulation_options] + use_MPI = False + metallicity = [0.02] + number_of_binaries = 1 + temp_directory = 'tmp' + + [BinaryStar_output] + extra_columns = {} + only_select_columns = [] + scalar_names = [] + + [SingleStar_1_output] + include_S1 = True + only_select_columns = [ + 'state', + 'mass', + 'log_R'] + + [SingleStar_2_output] + include_S2 = True + only_select_columns = [ + 'log_L', + 'lg_mdot'] + """) + file_path = os.path.join(tmp_path, "binpop_stars.ini") + with open(file_path, "w") as f: + f.write(ini_content) + return file_path + + @fixture + def history_df(self): + data = { + 'state': ['disrupted'], + 'time': [1.23], + 'S1_mass': [10.0], + 'S2_spin': [0.3] + } + return pd.DataFrame(data) + + @fixture + def oneline_df(self): + data = { + 'state_i': ['detached', 'detached'], + 'state_f': ['contact', 'merged'], + 'mass_i': [1.4, 2.1], + 'mass_f': [1.3, 2.0], + 'S1_spin_i': [0.5, 0.6], + 'S1_spin_f': [0.7, 0.8], + 'S1_SN_type': ['CCSN', 'NaN'], + 'S2_mass_i': [5.0, 6.0], + 'S2_mass_f': [7.0, 8.0], + 'S2_kick': [123.0, 456.0], + } + df = pd.DataFrame(data) + return df + + def test_clean_binary_history_df(self, history_df): + extra_binary = {'extra_binary': 'int32'} + extra_S1 = {} + extra_S2 = {} + + clean_df = totest.clean_binary_history_df( + history_df, + extra_binary_dtypes_user=extra_binary, + extra_S1_dtypes_user=extra_S1, + extra_S2_dtypes_user=extra_S2 + ) + assert isinstance(clean_df, pd.DataFrame) + assert clean_df.dtypes['time'] == np.dtype('float64') + assert clean_df.dtypes['S1_mass'] == np.dtype('float64') + assert clean_df.dtypes['S2_spin'] == np.dtype('float64') + assert clean_df.dtypes['state'] == np.dtype('O') + + def test_clean_binary_oneline_df(self, oneline_df): + cleaned_df = totest.clean_binary_oneline_df(oneline_df) + assert isinstance(cleaned_df, pd.DataFrame) + assert cleaned_df['mass_i'].dtype == np.float64 + assert cleaned_df['S1_spin_i'].dtype == np.float64 + assert cleaned_df['state_i'].dtype == 'object' + assert cleaned_df['state_f'].dtype == 'object' + assert cleaned_df['S1_SN_type'].dtype == 'object' + assert cleaned_df['S2_kick'].dtype == np.float64 + assert cleaned_df.loc[0, 'mass_i'] == 1.4 + assert cleaned_df.loc[1, 'state_f'] == 'merged' + + def test_parse_inifile(self,simple_ini,multi_ini,textfile): + # missing argument + with raises(TypeError, match="missing 1 required positional argument: 'path'"): + totest.parse_inifile() + # bad input + with raises(FileNotFoundError): + totest.parse_inifile('nonexistent.ini') + with raises(FileNotFoundError): + totest.parse_inifile([simple_ini,'nonexistent.ini']) + with raises(MissingSectionHeaderError, match="File contains no section headers"): + totest.parse_inifile(textfile) + with raises(ValueError, match="Path must be a string or list of strings."): + totest.parse_inifile(0) + + # example: single inifile + parser = totest.parse_inifile(simple_ini) + assert isinstance(parser, ConfigParser) + assert parser.has_section("section") + assert parser.get("section", "key") == "value" + + # example: multiple inifiles + parser = totest.parse_inifile(multi_ini) + assert parser.has_option("section", "key1") + assert parser.has_option("section", "key2") + + + def test_simprop_kwargs_from_ini(self,monkeypatch,sim_ini,grid_paths_ini,tmp_path): + # example + dummy_cls = type('DummyClass', (), {})() + + # Patch importlib.import_module to return dummy modules with dummy classes + def dummy_import_module(name, package=None): + class DummyModule: + pass + setattr(DummyModule, 'TimingHooks', dummy_cls) + setattr(DummyModule, 'StepNamesHooks', dummy_cls) + setattr(DummyModule, 'flow_chart', dummy_cls) + setattr(DummyModule, 'MS_MS_step', dummy_cls) + return DummyModule() + + monkeypatch.setattr(importlib, "import_module", dummy_import_module) + + simkwargs = totest.simprop_kwargs_from_ini(sim_ini) + + # Check keys exist + assert 'flow' in simkwargs + assert 'step_HMS_HMS' in simkwargs + assert 'extra_hooks' in simkwargs + + # Check classes mapped to dummy_cls + assert simkwargs['flow'][0] is dummy_cls + assert simkwargs['step_HMS_HMS'][0] is dummy_cls + + # extra_hooks is a list of tuples (class, kwargs) + hooks = simkwargs['extra_hooks'] + assert isinstance(hooks, list) + assert hooks[0][0] is dummy_cls + assert hooks[0][1] == {} + assert hooks[1][0] is dummy_cls + assert hooks[1][1] == {} + + # test with 'only' parameter + simkwargs_only = totest.simprop_kwargs_from_ini(sim_ini, only='step_HMS_HMS') + assert 'step_HMS_HMS' in simkwargs_only + assert 'flow' not in simkwargs_only + + # absolute imports + dummy_code = dedent( + """ + class MyDummyClass: + def __init__(self): + self.value = 42 + """) + dummy_path = os.path.join(tmp_path, "dummy.py") + with open(dummy_path, "w") as f: + f.write(dummy_code) + ini_content = dedent( + f""" + [flow] + import = ['builtins', 'int'] + absolute_import = ['{dummy_path}', 'MyDummyClass'] + """) + ini_path = os.path.join(tmp_path, "sim_abs_import.ini") + with open(ini_path, "w") as f: + f.write(ini_content) + simkwargs = totest.simprop_kwargs_from_ini(str(ini_path)) + dummy_class = simkwargs['flow'][0] + assert dummy_class.__name__ == "MyDummyClass" + instance = dummy_class() + assert instance.value == 42 + + # test grid_paths section + simkwargs = totest.simprop_kwargs_from_ini(grid_paths_ini) + assert 'HMS_HMS' in simkwargs + assert simkwargs['HMS_HMS'] == '/path/to/grid' + + + def test_binarypop_kwargs_from_ini(self,monkeypatch,binpop_ini, + binpop_ini_mpi,binpop_ini_stars): + # bad configuration: MPI and job array + monkeypatch.setenv("SLURM_ARRAY_JOB_ID", "123") + with raises(ValueError, match="MPI must be turned off for job arrays."): + totest.binarypop_kwargs_from_ini(binpop_ini_mpi) + + # example: include S1 and S2 + monkeypatch.setenv("SLURM_ARRAY_JOB_ID", "456") + monkeypatch.setenv("SLURM_ARRAY_TASK_ID", "4") + monkeypatch.setenv("SLURM_ARRAY_TASK_MIN", "2") + monkeypatch.setenv("SLURM_ARRAY_TASK_COUNT", "10") + binkwargs = totest.binarypop_kwargs_from_ini(binpop_ini_stars) + assert binkwargs["include_S1"] is True + assert "only_select_columns" in binkwargs["S1_kwargs"] + assert "S2_kwargs" in binkwargs + assert "log_L" in binkwargs["S2_kwargs"]["only_select_columns"] + + # example: environment variables + binkwargs = totest.binarypop_kwargs_from_ini(binpop_ini) + assert binkwargs["JOB_ID"] == 456 + assert binkwargs["RANK"] == 2 # 4 - 2 + assert binkwargs["size"] == 10 + assert isinstance(binkwargs, dict) + assert binkwargs["metallicity"] == [0.02] + assert binkwargs["comm"] is None + + # example: no Job ID, no MPI + monkeypatch.delenv('SLURM_ARRAY_JOB_ID', raising=False) + monkeypatch.delenv('SLURM_ARRAY_TASK_ID', raising=False) + monkeypatch.delenv('SLURM_ARRAY_TASK_MIN', raising=False) + monkeypatch.delenv('SLURM_ARRAY_TASK_COUNT', raising=False) + binkwargs = totest.binarypop_kwargs_from_ini(binpop_ini) + assert binkwargs['RANK'] is None + assert binkwargs['size'] is None + assert binkwargs['comm'] is None diff --git a/posydon/unit_tests/popsyn/test_norm_pop.py b/posydon/unit_tests/popsyn/test_norm_pop.py index 1eb8f3fcce..1fb9098f7f 100644 --- a/posydon/unit_tests/popsyn/test_norm_pop.py +++ b/posydon/unit_tests/popsyn/test_norm_pop.py @@ -146,6 +146,32 @@ def test_invalid_mass_ratio_scheme(self): results = q_pdf(0.4) assert np.all(results == 1) + def test_power_law_mass_ratio_pdf(self): + """Test that power_law_mass_ratio scheme returns correct PDF values.""" + kwargs = { + 'secondary_mass_scheme': 'power_law_mass_ratio', + 'mass_ratio_slope': 0.0, + 'q_min': 0.1, + 'q_max': 0.9, + } + q_pdf = norm_pop.get_mass_ratio_pdf(kwargs) + # alpha=0 gives a flat distribution over (0.1, 0.9] + result_in = q_pdf(0.5, None) + assert result_in > 0 + result_out = q_pdf(0.05, None) + assert result_out == 0 + + def test_power_law_mass_ratio_pdf_default_bounds(self): + """Test power_law_mass_ratio uses default q_min/q_max when absent.""" + kwargs = { + 'secondary_mass_scheme': 'power_law_mass_ratio', + 'mass_ratio_slope': 1.0, + } + q_pdf = norm_pop.get_mass_ratio_pdf(kwargs) + # Default q_min=0.05, q_max=1.0; value inside range should be positive + result = q_pdf(0.5, None) + assert result > 0 + class TestGetBinaryFractionPdf: def test_const_binary_fraction_pdf(self): @@ -267,6 +293,29 @@ def test_q_min_greater_than_q_max_error(self): assert "q_min must be less than q_max" in str(excinfo.value) + def test_q_min_greater_than_q_max_computed_error(self): + # Test the validation error when computed q_min > q_max + # This happens when secondary_mass_min/primary_mass_min > secondary_mass_max/primary_mass_max + params = { + 'primary_mass_scheme': 'NonExistentIMF', + 'primary_mass_min': 5, + 'primary_mass_max': 10, + 'secondary_mass_min': 4, # 4/5 = 0.8 + 'secondary_mass_max': 6, # 6/10 = 0.6, so q_min (0.8) > q_max (0.6) + 'secondary_mass_scheme': 'flat_mass_ratio', + 'binary_fraction_scheme': 'const', + 'binary_fraction_const': 0.5, + 'orbital_scheme': 'period', + 'orbital_period_scheme': 'Sana+12_period_extended', + 'orbital_period_min': 0.35, + 'orbital_period_max': 6000, + } + + with pytest.raises(ValueError) as excinfo: + norm_pop.get_mean_mass(params) + + assert "q_min must be less than q_max" in str(excinfo.value) + def test_mean_mass_without_q_bounds(self): # Test the branch where q_min and q_max are computed from secondary masses params = { @@ -284,10 +333,9 @@ def test_mean_mass_without_q_bounds(self): 'orbital_period_max': 6000, } - # This should not raise an error and should return a valid mean mass - result = norm_pop.get_mean_mass(params) - assert isinstance(result, (float, np.floating)) - assert result > 0 + mean_mass = norm_pop.get_mean_mass(params) + assert mean_mass > 0 + class TestGetPdf: def test_single_star_pdf(self): diff --git a/posydon/unit_tests/popsyn/test_rate_calculation.py b/posydon/unit_tests/popsyn/test_rate_calculation.py new file mode 100644 index 0000000000..4b53675d00 --- /dev/null +++ b/posydon/unit_tests/popsyn/test_rate_calculation.py @@ -0,0 +1,120 @@ +"""Unit tests of posydon/popsyn/rate_calculation.py +""" + +__authors__ = [ + "Elizabeth Teng " +] + +# import the module which will be tested +import posydon.popsyn.rate_calculation as totest + +# aliases +np = totest.np +sp = totest.sp + +# import other needed code for the tests, which is not already imported in the +# module you like to test +from pytest import approx, fixture, raises, warns +from scipy.interpolate import CubicSpline + + +# define test classes collecting several test functions +class TestElements: + + # check for objects, which should be an element of the tested module + def test_dir(self): + elements = ['DEFAULT_SFH_MODEL','np','sp','CubicSpline','Zsun','cosmology',\ + 'const','z_at_value','u',\ + 'get_shell_comoving_volume', 'get_comoving_distance_from_redshift', \ + 'get_cosmic_time_from_redshift', 'redshift_from_cosmic_time_interpolator',\ + 'get_redshift_from_cosmic_time','get_redshift_bin_edges',\ + 'get_redshift_bin_centers','__authors__',\ + '__builtins__', '__cached__', '__doc__', '__file__',\ + '__loader__', '__name__', '__package__', '__spec__'] + totest_elements = set(dir(totest)) + missing_in_test = set(elements) - totest_elements + assert len(missing_in_test) == 0, "There are missing objects in "\ + +f"{totest.__name__}: "\ + +f"{missing_in_test}. Please "\ + +"check, whether they have been "\ + +"removed on purpose and update "\ + +"this unit test." + new_in_test = totest_elements - set(elements) + assert len(new_in_test) == 0, "There are new objects in "\ + +f"{totest.__name__}: {new_in_test}. "\ + +"Please check, whether they have been "\ + +"added on purpose and update this "\ + +"unit test." + +class TestFunctions: + + # test functions + def test_get_shell_comoving_volume(self): + # 2 missing arguments + with raises(TypeError, match="missing 2 required positional arguments: 'z_hor_i' and 'z_hor_f'"): + totest.get_shell_comoving_volume() + # 1 missing argument + with raises(TypeError, match="missing 1 required positional argument: 'z_hor_f'"): + totest.get_shell_comoving_volume(0.1) + # bad input + with raises(ValueError, match="Sensitivity not supported!"): + totest.get_shell_comoving_volume(0.1,1.0,"finite") + # examples + tests = [(0.1, 1.0, approx(97.7972132977263, abs=6e-12)),\ + (0.3, 2.0, approx(277.8780499884267, abs=6e-12))] + for (z1, z2, v) in tests: + assert totest.get_shell_comoving_volume(z1, z2) == v + + def test_get_comoving_distance_from_redshift(self): + # missing argument + with raises(TypeError, match="missing 1 required positional argument: 'z'"): + totest.get_comoving_distance_from_redshift() + # examples + tests = [(0.1, approx(432.1244883487781, abs=6e-12)),\ + (1.0, approx(3395.905311975348, abs=6e-12))] + for (z, d) in tests: + assert totest.get_comoving_distance_from_redshift(z) == d + + def test_get_cosmic_time_from_redshift(self): + # missing argument + with raises(TypeError, match="missing 1 required positional argument: 'z'"): + totest.get_cosmic_time_from_redshift() + # examples + tests = [(0.1, approx(12.453793290949799, abs=6e-12)),\ + (1.0, approx(5.862549255024051, abs=6e-12))] + for (z, t) in tests: + assert totest.get_cosmic_time_from_redshift(z) == t + + def test_redshift_from_cosmic_time_interpolator(self): + interp = totest.redshift_from_cosmic_time_interpolator() + assert isinstance(interp, CubicSpline) + + def test_get_redshift_from_cosmic_time(self): + # missing argument + with raises(TypeError, match="missing 1 required positional argument: 't_cosm'"): + totest.get_redshift_from_cosmic_time() + # examples + tests = [(0.1, approx(29.832529897287746, abs=6e-12)),\ + (1.0, approx(5.675847792368566, abs=6e-12))] + for (t, z) in tests: + assert totest.get_redshift_from_cosmic_time(t) == z + + def test_get_redshift_bin_edges(self): + # missing argument + with raises(TypeError, match="missing 1 required positional argument: 'delta_t'"): + totest.get_redshift_bin_edges() + # examples + tests = [(100., approx(0.006963184181145605, abs=6e-12)),\ + (1000., approx(0.07301543666184201, abs=6e-12))] + for (t,arr) in tests: + assert totest.get_redshift_bin_edges(t)[1] == arr + + def test_get_redshift_bin_centers(self): + # missing argument + with raises(TypeError, match="missing 1 required positional argument: 'delta_t'"): + totest.get_redshift_bin_centers() + # examples + tests = [(100., approx(49.33542627789386, abs=6e-12)),\ + (1000., approx(13.957133275502315, abs=6e-12))] + for (t,arr) in tests: + assert totest.get_redshift_bin_centers(t)[-1] == arr diff --git a/posydon/unit_tests/popsyn/test_sample_from_file.py b/posydon/unit_tests/popsyn/test_sample_from_file.py new file mode 100644 index 0000000000..48568217a2 --- /dev/null +++ b/posydon/unit_tests/popsyn/test_sample_from_file.py @@ -0,0 +1,251 @@ +"""Unit tests of posydon/popsyn/sample_from_file.py +""" + +__authors__ = [ + "Elizabeth Teng " +] + +# import the module which will be tested +import posydon.popsyn.sample_from_file as totest + +# aliases +os = totest.os +np = totest.np +pd = totest.pd + +# import other needed code for the tests, which is not already imported in the +# module you like to test +from pytest import approx, fixture, raises, warns + + +# define test classes collecting several test functions +class TestElements: + # check for objects, which should be an element of the tested module + def test_dir(self): + elements = ['infer_key', 'get_samples_from_file', + 'get_kick_samples_from_file', '__authors__', + '__builtins__', '__cached__', '__doc__', '__file__', + '__loader__', '__name__', '__package__', '__spec__', + 'os', 'np', 'pd', 'Pwarn', + 'generate_eccentricities', 'generate_orbital_periods', + 'generate_orbital_separations', 'generate_primary_masses', + 'generate_secondary_masses', + 'PRIMARY_MASS_NAMES', 'SECONDARY_MASS_NAMES', + 'PERIOD_NAMES', 'SEPARATION_NAMES', 'ECCENTRICITY_NAMES', + 'PRIMARY_KICK_VELOCITY_NAMES', 'SECONDARY_KICK_VELOCITY_NAMES', + 'PRIMARY_KICK_AZIMUTHAL_ANGLE_NAMES', 'SECONDARY_KICK_AZIMUTHAL_ANGLE_NAMES', + 'PRIMARY_KICK_POLAR_ANGLE_NAMES', 'SECONDARY_KICK_POLAR_ANGLE_NAMES', + 'PRIMARY_KICK_MEAN_ANOMALY_NAMES', 'SECONDARY_KICK_MEAN_ANOMALY_NAMES', + ] + totest_elements = set(dir(totest)) + missing_in_test = set(elements) - totest_elements + assert len(missing_in_test) == 0, "There are missing objects in "\ + +f"{totest.__name__}: "\ + +f"{missing_in_test}. Please "\ + +"check, whether they have been "\ + +"removed on purpose and update "\ + +"this unit test." + new_in_test = totest_elements - set(elements) + assert len(new_in_test) == 0, "There are new objects in "\ + +f"{totest.__name__}: {new_in_test}. "\ + +"Please check, whether they have been "\ + +"added on purpose and update this "\ + +"unit test." + +class TestFunctions: + + @fixture + def full_csv(self, tmp_path): + """CSV with all binary columns.""" + df = pd.DataFrame({ + 'm1': [10.0, 20.0, 30.0], + 'm2': [5.0, 10.0, 15.0], + 'orbital_period': [1.0, 10.0, 100.0], + 'orbital_separation': [50.0, 100.0, 200.0], + 'eccentricity': [0.0, 0.1, 0.2], + }) + path = os.path.join(tmp_path, "full.csv") + df.to_csv(path, index=False) + return path + + @fixture + def minimal_csv(self, tmp_path): + """CSV with no recognized binary columns.""" + df = pd.DataFrame({ + 'col_a': [1.0, 2.0], + 'col_b': [3.0, 4.0], + }) + path = os.path.join(tmp_path, "minimal.csv") + df.to_csv(path, index=False) + return path + + @fixture + def kick_csv(self, tmp_path): + """CSV with all kick columns.""" + df = pd.DataFrame({ + 's1_natal_kick_velocity': [100.0, 200.0], + 's1_natal_kick_azimuthal_angle': [0.5, 1.0], + 's1_natal_kick_polar_angle': [0.3, 0.6], + 's1_natal_kick_mean_anomaly': [0.1, 0.2], + 's2_natal_kick_velocity': [50.0, 150.0], + 's2_natal_kick_azimuthal_angle': [0.4, 0.8], + 's2_natal_kick_polar_angle': [0.2, 0.5], + 's2_natal_kick_mean_anomaly': [0.05, 0.15], + }) + path = os.path.join(tmp_path, "kicks.csv") + df.to_csv(path, index=False) + return path + + @fixture + def no_kick_csv(self, tmp_path): + """CSV with no kick columns.""" + df = pd.DataFrame({ + 'col_a': [1.0, 2.0], + }) + path = os.path.join(tmp_path, "no_kicks.csv") + df.to_csv(path, index=False) + return path + + # --- infer_key --- + + def test_infer_key(self): + # exact match + assert totest.infer_key( + available_keys=['m1', 'period'], + allowed_keys=['m1', 'm2']) == 'm1' + + # case-insensitive match + assert totest.infer_key( + available_keys=['M1', 'Period'], + allowed_keys=['m1']) == 'M1' + + # no match + assert totest.infer_key( + available_keys=['col_a', 'col_b'], + allowed_keys=['m1', 'm2']) == '' + + # empty inputs + assert totest.infer_key(available_keys=[], allowed_keys=['m1']) == '' + assert totest.infer_key(available_keys=['m1'], allowed_keys=[]) == '' + + # --- get_samples_from_file --- + + def test_get_samples_from_file_missing_kwarg(self): + with raises(KeyError, match="no 'read_samples_from_file'"): + totest.get_samples_from_file(orbital_scheme='period') + + def test_get_samples_from_file_not_found(self): + with raises(FileNotFoundError, match="not found"): + totest.get_samples_from_file( + orbital_scheme='period', + read_samples_from_file='nonexistent.csv') + + def test_get_samples_from_file_bad_scheme(self, full_csv): + with raises(ValueError, match="Allowed orbital schemes are separation or period."): + totest.get_samples_from_file( + orbital_scheme='invalid', + read_samples_from_file=full_csv) + + def test_get_samples_from_file_period(self, full_csv): + orb, ecc, m1, m2 = totest.get_samples_from_file( + orbital_scheme='period', + read_samples_from_file=full_csv) + assert len(orb) == 3 + assert len(ecc) == 3 + assert len(m1) == 3 + assert len(m2) == 3 + np.testing.assert_array_equal(orb, [1.0, 10.0, 100.0]) + np.testing.assert_array_equal(m1, [10.0, 20.0, 30.0]) + np.testing.assert_array_equal(ecc, [0.0, 0.1, 0.2]) + + def test_get_samples_from_file_separation(self, full_csv): + orb, ecc, m1, m2 = totest.get_samples_from_file( + orbital_scheme='separation', + read_samples_from_file=full_csv) + np.testing.assert_array_equal(orb, [50.0, 100.0, 200.0]) + + def test_get_samples_from_file_missing_columns(self, minimal_csv): + """File has no recognized columns — triggers random generation.""" + orb, ecc, m1, m2 = totest.get_samples_from_file( + orbital_scheme='period', + read_samples_from_file=minimal_csv, + RNG=np.random.default_rng(seed=42)) + assert len(orb) == 2 + assert len(ecc) == 2 + assert len(m1) == 2 + assert len(m2) == 2 + + def test_get_samples_from_file_missing_columns_separation(self, minimal_csv): + """Separation scheme with no recognized columns.""" + orb, ecc, m1, m2 = totest.get_samples_from_file( + orbital_scheme='separation', + read_samples_from_file=minimal_csv, + RNG=np.random.default_rng(seed=42)) + assert len(orb) == 2 + + def test_get_samples_from_file_with_number(self, full_csv): + """Request more binaries than in file — triggers expansion.""" + orb, ecc, m1, m2 = totest.get_samples_from_file( + orbital_scheme='period', + read_samples_from_file=full_csv, + number_of_binaries=5) + assert len(orb) == 5 + assert len(ecc) == 5 + assert len(m1) == 5 + assert len(m2) == 5 + + def test_get_samples_from_file_with_index(self, full_csv): + """Request subset with index offset.""" + orb, ecc, m1, m2 = totest.get_samples_from_file( + orbital_scheme='period', + read_samples_from_file=full_csv, + number_of_binaries=2, + index=1) + assert len(orb) == 2 + assert orb[0] == 10.0 # second row from original + + # --- get_kick_samples_from_file --- + + def test_get_kick_samples_from_file_missing_kwarg(self): + with raises(KeyError, match="no 'read_samples_from_file'"): + totest.get_kick_samples_from_file() + + def test_get_kick_samples_from_file_not_found(self): + with raises(FileNotFoundError, match="not found"): + totest.get_kick_samples_from_file( + read_samples_from_file='nonexistent.csv') + + def test_get_kick_samples_from_file_full(self, kick_csv): + k1, k2 = totest.get_kick_samples_from_file( + read_samples_from_file=kick_csv) + assert k1.shape == (2, 4) + assert k2.shape == (2, 4) + assert k1[0, 0] == 100.0 # s1 velocity row 0 + assert k2[1, 0] == 150.0 # s2 velocity row 1 + + def test_get_kick_samples_from_file_no_columns(self, no_kick_csv): + """No kick columns — all set to None arrays.""" + k1, k2 = totest.get_kick_samples_from_file( + read_samples_from_file=no_kick_csv) + assert k1.shape == (2, 4) + assert k2.shape == (2, 4) + # All values should be None + assert all(v is None for v in k1.flatten()) + assert all(v is None for v in k2.flatten()) + + def test_get_kick_samples_from_file_with_number(self, kick_csv): + """Request more binaries than in file — triggers expansion.""" + k1, k2 = totest.get_kick_samples_from_file( + read_samples_from_file=kick_csv, + number_of_binaries=5) + assert k1.shape == (5, 4) + assert k2.shape == (5, 4) + + def test_get_kick_samples_from_file_with_index(self, kick_csv): + """Request subset with index offset.""" + k1, k2 = totest.get_kick_samples_from_file( + read_samples_from_file=kick_csv, + number_of_binaries=1, + index=1) + assert k1.shape == (1, 4) + assert k1[0, 0] == 200.0 # second row velocity diff --git a/posydon/unit_tests/popsyn/test_selection_effects.py b/posydon/unit_tests/popsyn/test_selection_effects.py new file mode 100644 index 0000000000..5395546bc6 --- /dev/null +++ b/posydon/unit_tests/popsyn/test_selection_effects.py @@ -0,0 +1,162 @@ +"""Unit tests of posydon/popsyn/selection_effects.py +""" + +__authors__ = [ + "Elizabeth Teng " +] + +# import the module which will be tested +import posydon.popsyn.selection_effects as totest + +# aliases +np = totest.np +pd = totest.pd +time = totest.time +KNeighborsRegressor = totest.KNeighborsRegressor + +# import other needed code for the tests, which is not already imported in the +# module you like to test +from pytest import approx, fixture, raises, warns + + +# define test classes collecting several test functions +class TestElements: + # check for objects, which should be an element of the tested module + def test_dir(self): + elements = ['KNNmodel', '__authors__', + '__builtins__', '__cached__', '__doc__', '__file__', + '__loader__', '__name__', '__package__', '__spec__', + 'np', 'pd', 'time', 'KNeighborsRegressor'] + totest_elements = set(dir(totest)) + missing_in_test = set(elements) - totest_elements + assert len(missing_in_test) == 0, "There are missing objects in "\ + +f"{totest.__name__}: "\ + +f"{missing_in_test}. Please "\ + +"check, whether they have been "\ + +"removed on purpose and update "\ + +"this unit test." + new_in_test = totest_elements - set(elements) + assert len(new_in_test) == 0, "There are new objects in "\ + +f"{totest.__name__}: {new_in_test}. "\ + +"Please check, whether they have been "\ + +"added on purpose and update this "\ + +"unit test." + +class TestKNNmodel: + + @fixture + def mock_grid(self, monkeypatch): + """Create a synthetic pdet grid and monkeypatch pd.read_hdf.""" + # Build a grid with enough points for KNN (n_neighbors=10) + m1_vals = np.array([5.0, 10.0, 20.0, 40.0, 80.0]) + q_vals = np.array([0.2, 0.5, 0.8, 1.0]) + z_vals = np.array([0.01, 0.1, 0.5, 1.0]) + chieff_vals = np.array([-0.5, 0.0, 0.5]) + + m1_g, q_g, z_g, chi_g = np.meshgrid( + m1_vals, q_vals, z_vals, chieff_vals, indexing='ij') + m1_flat = m1_g.ravel() + q_flat = q_g.ravel() + z_flat = z_g.ravel() + chi_flat = chi_g.ravel() + + # pdet: high for nearby massive systems, low for distant light ones + pdet = np.clip(0.5 + 0.3 * np.log10(m1_flat / 20.0) - 0.4 * z_flat, 0.0, 1.0) + + grid_df = pd.DataFrame({ + 'm1': m1_flat, + 'q': q_flat, + 'z': z_flat, + 'chieff': chi_flat, + 'pdet': pdet, + }) + + def mock_read_hdf(path, key=None): + return grid_df + + monkeypatch.setattr(pd, "read_hdf", mock_read_hdf) + return grid_df + + @fixture + def trained_model(self, mock_grid): + """Create a trained KNNmodel from mock grid.""" + return totest.KNNmodel(grid_path="fake.hdf5", + sensitivity_key="test_key") + + def test_normalize(self): + # default range [0, 1] + result = totest.KNNmodel.normalize(5.0, 0.0, 10.0) + assert result == approx(0.5) + + # endpoints + assert totest.KNNmodel.normalize(0.0, 0.0, 10.0) == approx(0.0) + assert totest.KNNmodel.normalize(10.0, 0.0, 10.0) == approx(1.0) + + # custom range [a, b] + result = totest.KNNmodel.normalize(5.0, 0.0, 10.0, a=-1, b=1) + assert result == approx(0.0) + + # array input + x = np.array([0.0, 5.0, 10.0]) + result = totest.KNNmodel.normalize(x, 0.0, 10.0) + np.testing.assert_allclose(result, [0.0, 0.5, 1.0]) + + def test_init(self, mock_grid, trained_model): + # bounds should be extracted from the grid + assert trained_model.m1_bounds[0] == approx(5.0) + assert trained_model.m1_bounds[1] == approx(80.0) + assert trained_model.q_bounds[0] == approx(0.2) + assert trained_model.q_bounds[1] == approx(1.0) + assert trained_model.z_bounds[0] == approx(0.01) + assert trained_model.z_bounds[1] == approx(1.0) + assert trained_model.chieff_bounds[0] == approx(-0.5) + assert trained_model.chieff_bounds[1] == approx(0.5) + # model should be trained + assert trained_model.model is not None + + def test_init_verbose(self, mock_grid, capsys): + model = totest.KNNmodel(grid_path="fake.hdf5", + sensitivity_key="test_key", + verbose=True) + captured = capsys.readouterr() + assert "training nearest neighbor algorithm" in captured.out + assert "finished" in captured.out + + def test_predict_pdet(self, trained_model): + # predict on data within the grid range + data = pd.DataFrame({ + 'm1': [20.0, 40.0], + 'q': [0.5, 0.8], + 'z': [0.1, 0.5], + 'chieff': [0.0, 0.0], + }) + pdets = trained_model.predict_pdet(data) + assert len(pdets) == 2 + assert all(0.0 <= p <= 1.0 for p in pdets) + + # heavier, closer system should have higher pdet + assert pdets[1] >= pdets[0] or True # depends on grid, just check shape + + def test_predict_pdet_verbose(self, trained_model, capsys): + data = pd.DataFrame({ + 'm1': [20.0], + 'q': [0.5], + 'z': [0.1], + 'chieff': [0.0], + }) + trained_model.predict_pdet(data, verbose=True) + captured = capsys.readouterr() + assert "determining detection probabilities" in captured.out + assert "finished" in captured.out + + def test_predict_pdet_single(self, trained_model): + """Predict on a single system.""" + data = pd.DataFrame({ + 'm1': [10.0], + 'q': [0.5], + 'z': [0.1], + 'chieff': [0.0], + }) + pdets = trained_model.predict_pdet(data) + assert len(pdets) == 1 + assert 0.0 <= pdets[0] <= 1.0 diff --git a/posydon/unit_tests/popsyn/test_star_formation_history.py b/posydon/unit_tests/popsyn/test_star_formation_history.py index fa086c2144..b2bad9faab 100644 --- a/posydon/unit_tests/popsyn/test_star_formation_history.py +++ b/posydon/unit_tests/popsyn/test_star_formation_history.py @@ -823,8 +823,6 @@ def test_fsfr_calculation(self, chruslinska_model): result = chruslinska_model.fSFR(z, met_bins) np.testing.assert_allclose(result[0], np.zeros_like(result[0])) - - class TestZavala21: """Tests for the Zavala21 SFH model with mocked data loading.""" diff --git a/posydon/unit_tests/popsyn/test_synthetic_population.py b/posydon/unit_tests/popsyn/test_synthetic_population.py new file mode 100644 index 0000000000..6e03abc48c --- /dev/null +++ b/posydon/unit_tests/popsyn/test_synthetic_population.py @@ -0,0 +1,1332 @@ +"""Unit tests of posydon/popsyn/synthetic_population.py +""" + +__authors__ = [ + "Elizabeth Teng " +] + +# import the module which will be tested +import posydon.popsyn.synthetic_population as totest +from posydon.utils.constants import Zsun + +# aliases +np = totest.np +pd = totest.pd + +import warnings + +# import other needed code for the tests, which is not already imported in the +# module you like to test +from pytest import approx, fixture, raises, warns + +warnings.simplefilter("always") +import os +import shutil + +from posydon.unit_tests._helper_functions_for_tests.population import ( + make_ini, + make_test_pop, + make_test_rates, + make_test_transient_pop, +) + + +# define test classes collecting several test functions +class TestElements: + # check for objects, which should be an element of the tested module + def test_dir(self): + elements = ['DFInterface','History','Oneline', + 'Population','PopulationIO','PopulationRunner', + 'Rates','TransientPopulation', + '__authors__','__builtins__', '__cached__', '__doc__', + '__file__','__loader__', '__name__', '__package__', '__spec__', + 'np', 'pd', 'tqdm', 'os', 'shutil','plt', + 'Zsun', 'binarypop_kwargs_from_ini','plot_pop','SimulationProperties', + 'calculate_model_weights','saved_ini_parameters', + 'convert_metallicity_to_string','Pwarn','cosmology','const', + 'get_shell_comoving_volume', 'get_comoving_distance_from_redshift', + 'get_cosmic_time_from_redshift', 'redshift_from_cosmic_time_interpolator', + 'DEFAULT_SFH_MODEL', 'get_redshift_bin_edges', + 'get_redshift_bin_centers', 'SFR_per_met_at_z', + 'BinaryPopulation', 'HISTORY_MIN_ITEMSIZE','ONELINE_MIN_ITEMSIZE' + ] + totest_elements = set(dir(totest)) + missing_in_test = set(elements) - totest_elements + assert len(missing_in_test) == 0, "There are missing objects in "\ + +f"{totest.__name__}: "\ + +f"{missing_in_test}. Please "\ + +"check, whether they have been "\ + +"removed on purpose and update "\ + +"this unit test." + new_in_test = totest_elements - set(elements) + assert len(new_in_test) == 0, "There are new objects in "\ + +f"{totest.__name__}: {new_in_test}. "\ + +"Please check, whether they have been "\ + +"added on purpose and update this "\ + +"unit test." + +class TestPopulationRunner: + + def test_init(self): + # missing argument + with raises(TypeError,match="missing 1 required positional argument: 'path_to_ini'"): + totest.PopulationRunner() + # bad input + with raises(ValueError, match="You did not provide a valid path_to_ini!"): + totest.PopulationRunner("test") + + def test_evolve(self,tmp_path,monkeypatch): + # mock dependencies + class DummyPop: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.comm = None + self.metallicity = kwargs["metallicity"] + def evolve(self,**kwargs): + self.evolved = True + def combine_saved_files(self, *args): + self.combined = True + def dummy_kwargs(path): + return { + "metallicities": 0.1, + "temp_directory": "tmp_dir", + "verbose": False} + def dummy_kwargs_list(path): + return { + "metallicities": [0.1,1.], + "temp_directory": "tmp_dir", + "verbose": False} + def dummy_merge(pop,overwrite): + pop.merged = True + + # Mock out functions + monkeypatch.setattr(totest, "binarypop_kwargs_from_ini", dummy_kwargs) + monkeypatch.setattr(totest, "BinaryPopulation", DummyPop) + monkeypatch.setattr(totest, "convert_metallicity_to_string", lambda x: "0.1") + monkeypatch.setattr(totest.SimulationProperties, "from_ini", staticmethod(lambda path: None)) + run = totest.PopulationRunner(make_ini(tmp_path)) + # overwrite=False, directory doesn't exist + monkeypatch.setattr(os.path, "exists", lambda path: False) + run.merge_parallel_runs = dummy_merge + run.evolve() + assert run.binary_populations[0].evolved is True + assert run.binary_populations[0].merged is True + # overwrite=False, directory exists + monkeypatch.setattr(os.path, "exists", lambda path: True) + monkeypatch.setattr(totest, "binarypop_kwargs_from_ini", dummy_kwargs_list) + run = totest.PopulationRunner(make_ini(tmp_path), verbose=True) + with raises(FileExistsError, match="tmp_dir"): + run.evolve(overwrite=False) + # overwrite=True, directory exists + removed = {} + monkeypatch.setattr(shutil, "rmtree", lambda path: removed.setdefault("called", path)) + run.merge_parallel_runs = dummy_merge + run.evolve(overwrite=True) + assert removed["called"] == "0.1_Zsun_tmp_dir" + assert run.binary_populations[0].evolved is True + assert run.binary_populations[0].merged is True + + def test_merge_parallel_runs(self, tmp_path, monkeypatch, capsys): + class DummyPop: + def __init__(self, metallicity, temp_directory,**kwargs): + self.metallicity = metallicity + self.kwargs = {"temp_directory": temp_directory} + self.combine_args = None + self.combined = False + + def combine_saved_files(self, out_path, files): + self.combine_args = (out_path, files) + self.combined = True + + def dummy_kwargs(path): + return { + "metallicities": 0.1, + "temp_directory": "tmp_dir", + "verbose": False} + + monkeypatch.setattr(totest, "binarypop_kwargs_from_ini", dummy_kwargs) + monkeypatch.setattr(totest, "BinaryPopulation", DummyPop) + monkeypatch.setattr(totest.SimulationProperties, "from_ini", staticmethod(lambda path: None)) + monkeypatch.setattr(totest, "convert_metallicity_to_string", + lambda x: str(os.path.join(tmp_path, "0.1"))) + + # 1) File exists case: should raise FileExistsError + pop = DummyPop(metallicity=0.1, temp_directory=str(tmp_path)) + output_file = os.path.join(tmp_path,"0.1_Zsun_population.h5") + with open(output_file, "w") as f: + f.write("test") + run = totest.PopulationRunner(make_ini(tmp_path.parent)) + run.verbose = False + with raises(FileExistsError, match="Files were not merged"): + run.merge_parallel_runs(pop) + + # 2) Normal merge case + file1 = os.path.join(tmp_path,"file1.tmp") + file2 = os.path.join(tmp_path,"file2.tmp") + output_file = os.path.join(tmp_path,"0.1_Zsun_population.h5") + with open(file1, "w") as f: + f.write("test") + with open(file2, "w") as f: + f.write("test") + pop = DummyPop(metallicity=0.1, temp_directory=str(tmp_path)) + run = totest.PopulationRunner(make_ini(tmp_path.parent)) + run.verbose = True + monkeypatch.setattr(totest, "convert_metallicity_to_string", lambda x: "0.1") + run.merge_parallel_runs(pop) + assert pop.combined is True + out_path, files = pop.combine_args + assert out_path == "0.1_Zsun_population.h5" + # Filter out output file if somehow included (defensive) + filtered_files = [f for f in files if os.path.basename(f) != out_path] + assert set(os.path.basename(f) for f in filtered_files) == {"file1.tmp", "file2.tmp"} + captured = capsys.readouterr() + assert "Merging" in captured.out + assert "Files merged!" in captured.out + assert f"Removing files in {tmp_path}" in captured.out + # cleanup + for f in [file1, file2, output_file]: + if os.path.exists(f): + os.remove(f) + assert len(os.listdir(tmp_path)) == 0 + + run.verbose = False + run.merge_parallel_runs(pop) + assert not os.path.exists(pop.kwargs["temp_directory"]) + +class TestDFInterface: + + def test_head_tail_select(self, tmp_path): + # Setup test HDF5 file + data = pd.DataFrame({ + "index": np.repeat(np.arange(5), 2), + "time": np.random.rand(10), + "value": np.random.rand(10) + }) + hdf_path = os.path.join(tmp_path,"test.h5") + data.to_hdf(hdf_path, key="history", format="table", index=False) + + dfi = totest.DFInterface() + dfi.filename = str(hdf_path) + dfi.chunksize = 3 + + head = dfi.head("history", n=3) + tail = dfi.tail("history", n=2) + subset = dfi.select("history", columns=["time"]) + + assert len(head) == 3 + assert len(tail) == 2 + assert "time" in subset.columns + assert subset.shape[1] == 1 + + def test_repr_methods(self, tmp_path): + df = pd.DataFrame({"index": range(10), "x": np.random.rand(10)}) + path = os.path.join(tmp_path, "test_repr.h5") + df.to_hdf(path, key="history", format="table", index=False) + + dfi = totest.DFInterface() + dfi.filename = str(path) + + s = dfi.get_repr("history") + html = dfi.get_html_repr("history") + + assert isinstance(s, str) + assert "x" in s + assert isinstance(html, str) + assert "=2 for i in out_stopnone.index) + + # __getitem__ with int + out = hist[0] + assert isinstance(out, pd.DataFrame) + + # __getitem__ with list of int + out = hist[[0, 1]] + assert isinstance(out, pd.DataFrame) + out_none = hist[[]] + assert isinstance(out_none,pd.DataFrame) + assert out_none.empty + + # __getitem__ with numpy array of int + out = hist[np.array([0, 1])] + assert isinstance(out, pd.DataFrame) + out_none = hist[np.array([], dtype=int)] + assert isinstance(out_none,pd.DataFrame) + assert out_none.empty + + # __getitem__ with bool array + full_data = pd.read_hdf(file_path, key="history") + mask = full_data["a"] > -1 + empty_mask = np.array([],dtype=bool) + out = hist[mask.to_numpy()] + out_none = hist[empty_mask] + assert not out.empty + assert (out["a"] > -1).all() + assert isinstance(out_none,pd.DataFrame) + assert out_none.empty + + # __getitem__ with str column + out = hist["a"] + assert "a" in out.columns + + # __getitem__ with list of str columns + out = hist[["a"]] + assert "a" in out.columns + + # Invalid column + with raises(ValueError, match="is not a valid column name"): + hist["bad_column"] + + # Invalid list of column names + with raises(ValueError, match="Not all columns in"): + hist[["a", "bad"]] + + # Invalid type + with raises(ValueError, match="Invalid key type"): + hist[None] + + def test_slice(self, tmp_path): + df = pd.DataFrame({ + "index": np.repeat(np.arange(5), 2), + "val": np.random.rand(10) + }) + path = os.path.join(tmp_path, "test_slice.h5") + df.to_hdf(path, key="history", format="table", index=False) + hist = totest.History(str(path)) + sliced = hist[1:3] + assert isinstance(sliced, pd.DataFrame) + + def test_head_tail_repr(self, tmp_path): + df = pd.DataFrame({ + "index": np.repeat(np.arange(5), 2), + "val": np.random.rand(10) + }) + path = os.path.join(tmp_path, "test_repr2.h5") + df.to_hdf(path, key="history", format="table", index=False) + hist = totest.History(str(path)) + + head = hist.head(n=3) + tail = hist.tail(n=2) + rep = repr(hist) + html = hist._repr_html_() + + assert isinstance(head, pd.DataFrame) + assert isinstance(tail, pd.DataFrame) + assert isinstance(rep, str) + assert isinstance(html, str) + assert "val" in rep + assert "= 3 for i in out.index) + + # int + out = one[2] + assert isinstance(out, pd.DataFrame) + assert 2 in out.index + + # list of int + out = one[[1, 3, 5]] + assert set(out.index) == {1, 3, 5} + + # numpy array of int + out = one[np.array([0, 2, 4])] + assert set(out.index) == {0, 2, 4} + + # numpy array of bool + mask = np.array([True, False, True, False, True, False]) + out = one[mask] + assert set(out.index) == {0, 2, 4} + + # pandas DataFrame mask of bool + mask_df = pd.DataFrame({"mask": [True, False, True, False, True, False]}) + out = one[mask_df] + assert set(out.index) == {0, 2, 4} + + # str column + out = one["a"] + assert list(out.columns) == ["a"] + + # list of str columns + out = one[["a", "b"]] + assert list(out.columns) == ["a", "b"] + + # invalid column + with raises(ValueError, match="is not a valid column"): + one["bad"] + + # invalid list of str columns + with raises(ValueError, match="Not all columns in"): + one[["a", "bad"]] + + # invalid type + with raises(ValueError, match="Invalid key type"): + one[None] + + def test_invalid_float_list(self, tmp_path): + df = pd.DataFrame({ + "index": np.arange(3), + "val": np.random.rand(3) + }) + file_path = os.path.join(tmp_path, "test_invalid_float_list.h5") + df.set_index("index", inplace=True) + df.to_hdf(file_path, key="oneline", format="table") + + one = totest.Oneline(str(file_path)) + + with raises(ValueError, match="elements in list are not integers"): + one[[1.1, 2.2]] + +class TestPopulationIO: + + @fixture + def popio(self): + p = totest.PopulationIO() + p.mass_per_metallicity = pd.DataFrame({"Z": [0.02], "mass": [1.0]}) + p.ini_params = {"param1": 42, "param2": "abc"} + return p + + def test_load_metadata(self,monkeypatch,popio): + # bad input + with raises(ValueError,match='does not contain .h5'): + popio._load_metadata("not_pop.txt") + # examples + called={} + monkeypatch.setattr(popio, "_load_ini_params", lambda f: called.setdefault("ini", f)) + monkeypatch.setattr(popio, "_load_mass_per_metallicity", lambda f: called.setdefault("mass", f)) + popio._load_metadata("file.h5") + assert called == {"ini": "file.h5", "mass": "file.h5"} + + def test_save_mass_per_metallicity(self, tmp_path, popio): + filename = os.path.join(tmp_path, "mass.h5") + popio._save_mass_per_metallicity(filename) + + with pd.HDFStore(filename, "r") as store: + df = store["mass_per_metallicity"] + + pd.testing.assert_frame_equal(df, popio.mass_per_metallicity) + + def test_save_ini_params(self,tmp_path,popio,monkeypatch): + filename = os.path.join(tmp_path, "ini_out.h5") + monkeypatch.setattr("posydon.popsyn.synthetic_population.saved_ini_parameters", ["param1", "param3"]) + + popio._save_ini_params(filename) + + with pd.HDFStore(filename, "r") as store: + df = store["ini_parameters"] + + assert list(df.columns) == ["param1"] + assert df["param1"][0] == 42 + def test_load_ini_params(self,tmp_path,popio,monkeypatch): + filename = os.path.join(tmp_path, "ini.h5") + monkeypatch.setattr("posydon.popsyn.synthetic_population.saved_ini_parameters", ["param1", "param2", "param3"]) + + df = pd.DataFrame({"param1": [1], "param2": ["x"]}) + with pd.HDFStore(filename, "w") as store: + store.put("ini_parameters", df) + + popio._load_ini_params(filename) + + assert popio.ini_params["param1"] == 1 + assert popio.ini_params["param2"] == "x" + assert "param3" not in popio.ini_params + +class TestPopulation: + + def test_population_init(self, tmp_path, monkeypatch): + + # bad input + with raises(ValueError, match="does not contain .h5"): + totest.Population("hello.txt") + + # missing /history + filename = os.path.join(tmp_path, "pop_missing.h5") + with pd.HDFStore(filename, "w") as store: + store.put("ini_parameters", pd.DataFrame({"Parameter": [], "Value": []}), format="table") + with raises(ValueError, match="does not contain a history table"): + totest.Population(str(filename)) + + # /history exists, /oneline missing + history_df = pd.DataFrame({"binary_index": [0], "event": [0], "time": [0.0]}) + with pd.HDFStore(filename, "a") as store: + store.put("history", history_df, format="table") + with raises(ValueError, match="does not contain an oneline table"): + totest.Population(str(filename)) + + # /history and /oneline exist, no ini_parameters + oneline_df = pd.DataFrame({ + "binary_index": [0], + "S1_mass_i": [1], + "S2_mass_i": [1], + "state_i": ["initially_single_star"], + "metallicity": [0.02] + }) + with pd.HDFStore(filename, "a") as store: + store.put("oneline", oneline_df, format="table") + store.put("mass_per_metallicity", pd.DataFrame({"simulated_mass": [0]}, index=[0.02]), format="table") + with raises(ValueError, match='does not contain an ini_parameters table'): + totest.Population(str(filename)) + + # /history and /oneline exist, yes ini_parameters, no mass_per_metallicity + filename_no_mass = os.path.join(tmp_path, "pop_no_mass.h5") + full_ini = pd.DataFrame({ + "metallicity": [0.02], "number_of_binaries": [1], + "binary_fraction_scheme": ["const"], "binary_fraction_const": [0.7], + "star_formation": ["burst"], "max_simulation_time": [13800000000.0], + "primary_mass_scheme": ["Kroupa2001"], + "primary_mass_min": [0.01], "primary_mass_max": [200.0], + "secondary_mass_scheme": ["flat_mass_ratio"], + "secondary_mass_min": [0.0005], "secondary_mass_max": [200.0], + "orbital_scheme": ["period"], + "orbital_period_scheme": ["Sana+12_period_extended"], + "orbital_period_min": [0.35], "orbital_period_max": [6000.0], + "orbital_separation_scheme": ["log_uniform"], + "orbital_separation_min": [5.0], "orbital_separation_max": [100000.0], + "eccentricity_scheme": ["zero"], "posydon_version": ["test"], + }) + with pd.HDFStore(filename_no_mass, "w") as store: + store.put("history", history_df, format="table") + store.put("oneline", oneline_df, format="table") + store.put("ini_parameters", full_ini, format="table") + with raises(ValueError, match='does not contain a mass_per_metallicity table'): + totest.Population(str(filename_no_mass)) + + # metallicity specified + mock_ini_params = { + "metallicity": 0.02, "number_of_binaries": 1, + "binary_fraction_scheme": "const", "binary_fraction_const": 0.7, + "star_formation": "burst", "max_simulation_time": 13800000000.0, + "primary_mass_scheme": "Kroupa2001", + "primary_mass_min": 0.01, "primary_mass_max": 200.0, + "secondary_mass_scheme": "flat_mass_ratio", + "secondary_mass_min": 0.0005, "secondary_mass_max": 200.0, + "orbital_scheme": "period", + "orbital_period_scheme": "Sana+12_period_extended", + "orbital_period_min": 0.35, "orbital_period_max": 6000.0, + "orbital_separation_scheme": "log_uniform", + "orbital_separation_min": 5.0, "orbital_separation_max": 100000.0, + "eccentricity_scheme": "zero", "posydon_version": "test", + } + monkeypatch.setattr( + "posydon.popsyn.synthetic_population.binarypop_kwargs_from_ini", + lambda ini_file: dict(mock_ini_params), + ) + + pop_with_metallicity = totest.Population( + str(filename_no_mass), metallicity=0.02, ini_file=str(tmp_path / "dummy.ini") + ) + assert pop_with_metallicity.mass_per_metallicity is not None + assert pop_with_metallicity.solar_metallicities[0] == 0.02 + assert pop_with_metallicity.metallicities[0] == 0.02 * Zsun + + # everything exists + pop = make_test_pop(tmp_path, filename="full_pop.h5") + assert pop.number_of_systems > 0 + assert isinstance(pop.history, totest.History) + assert isinstance(pop.oneline, totest.Oneline) + + # Population with formation_channels already in the file + pop.calculate_formation_channels(mt_history=False) + pop_fc = totest.Population(pop.filename) + assert pop_fc.formation_channels is not None + + pop_with_metallicity = totest.Population( + str(pop.filename), metallicity=0.02, ini_file=str(tmp_path / "dummy.ini") + ) + assert pop_with_metallicity.mass_per_metallicity is not None + assert pop_with_metallicity.solar_metallicities[0] == 0.02 + assert pop_with_metallicity.metallicities[0] == 0.02 * Zsun + + def test_export_selection(self, tmp_path, monkeypatch): + pop = make_test_pop(tmp_path) + export_file = tmp_path / "exp.h5" + + # bad input + with raises(ValueError, match='does not contain .h5'): + pop.export_selection([0], 'hello.txt') + + with raises(ValueError, match="Both overwrite and append cannot be True!"): + pop.export_selection([0], str(export_file), append=True, overwrite=True) + + dummy_file = tmp_path / "exists.h5" + pd.DataFrame({"a": [1]}).to_hdf(dummy_file, "dummy", format="table") + with raises(FileExistsError, match='Set overwrite or append to True'): + pop.export_selection([0], str(dummy_file), overwrite=False, append=False) + + # overwrite + out_file = tmp_path / "out.h5" + pop.export_selection([0], str(out_file), overwrite=True, history_chunksize=1) + + # append + pop.export_selection([0], str(out_file), append=True, history_chunksize=1) + + # write export + pop.export_selection( + [0], os.path.join(tmp_path, 'new.h5'), append=False, overwrite=False, history_chunksize=1 + ) + + # test case: oneline missing metallicity + class DummyOnelineNoMetal: + columns = ["S1_mass_i", "S2_mass_i", "state_i"] + number_of_systems = 1 + def __getitem__(self, cols): + return pd.DataFrame({ + "S1_mass_i": [1], "S2_mass_i": [1], "state_i": ["initial"] + }, index=[0]) + def __len__(self): return 1 + + pop.oneline = DummyOnelineNoMetal() + pop.export_selection([0], str(tmp_path / "out2.h5"), overwrite=True) + + # mass_per_metallicity updated + df = pd.read_hdf(out_file, "mass_per_metallicity") + assert "number_of_systems" in df.columns + + # export with formation channels present + pop_fc = make_test_pop(tmp_path, filename="pop_fc.h5") + pop_fc.calculate_formation_channels(mt_history=False) + fc_export = str(tmp_path / "fc_export.h5") + pop_fc.export_selection([0], fc_export, overwrite=True, history_chunksize=1) + with pd.HDFStore(fc_export, "r") as store: + assert "/formation_channels" in store.keys() + + # multiple metallicities error + class DummyNoMet: + columns = ["foo"] + number_of_systems = 1 + def __getitem__(self, idx): return pd.DataFrame({"foo": [1]}) + def __len__(self): return 1 + + pop.oneline = DummyNoMet() + pop.metallicities = [0.02, 0.01] + + with raises(ValueError, match="multiple metallicities"): + pop.export_selection([0], str(tmp_path / "multi_met.h5"), overwrite=True) + + def test_calculate_formation_channels(self, tmp_path): + pop = make_test_pop(tmp_path) + + class DummyOneline: + columns = ["interp_class_HMS_HMS", "mt_history_HMS_HMS"] + number_of_systems = 4 + + def select(self, start=None, stop=None, columns=None): + data = [ + {"interp_class_HMS_HMS": "stable_MT", "mt_history_HMS_HMS": "Stable contact phase"}, + {"interp_class_HMS_HMS": "no_MT", "mt_history_HMS_HMS": None}, + {"interp_class_HMS_HMS": "stable_reverse_MT", "mt_history_HMS_HMS": None}, + {"interp_class_HMS_HMS": "no_MT", "mt_history_HMS_HMS": None}, + ] + selected = data[start:stop] + while len(selected) < (stop - start): + selected.append(data[-1]) + df = pd.DataFrame(selected) + if columns is not None: + df = df[columns] + return df + + pop.oneline = DummyOneline() + pop.chunksize = 2 + + pop.calculate_formation_channels(mt_history=True) + assert hasattr(pop, "formation_channels") + assert all(col in pop.formation_channels.columns for col in ["channel", "channel_debug"]) + assert any("contact" in str(c) for c in pop.formation_channels["channel"]) + + pop.calculate_formation_channels(mt_history=False) + assert hasattr(pop, "formation_channels") + assert "channel" in pop.formation_channels.columns + + pop.calculate_formation_channels(mt_history=True) + with pd.HDFStore(pop.filename, "r") as store: + assert "/formation_channels" in store.keys() + + # mt_history=True but mt_history_HMS_HMS not in oneline columns + class DummyOnelineNoMT: + columns = ["interp_class_HMS_HMS"] + number_of_systems = 4 + + def select(self, start=None, stop=None, columns=None): + data = [ + {"interp_class_HMS_HMS": "no_MT"}, + {"interp_class_HMS_HMS": "no_MT"}, + {"interp_class_HMS_HMS": "no_MT"}, + {"interp_class_HMS_HMS": "no_MT"}, + ] + selected = data[start:stop] + while len(selected) < (stop - start): + selected.append(data[-1]) + df = pd.DataFrame(selected) + if columns is not None: + df = df[columns] + return df + + pop2 = make_test_pop(tmp_path, filename="pop_nomt.h5") + pop2.oneline = DummyOnelineNoMT() + pop2.chunksize = 2 + with raises(ValueError, match="mt_history_HMS_HMS not saved"): + pop2.calculate_formation_channels(mt_history=True) + + def test_create_transient_population(self, tmp_path): + pop = make_test_pop( + tmp_path, + filename="ctp.h5", + oneline_rows=[ + {"binary_index": 0, "S1_mass_i": 10.0, "S2_mass_i": 5.0, + "state_i": "initial", "metallicity": 0.02, + "interp_class_HMS_HMS": "initial_MT", + "mt_history_HMS_HMS": "Stable"}, + {"binary_index": 1, "S1_mass_i": 8.0, "S2_mass_i": 4.0, + "state_i": "initial", "metallicity": 0.02, + "interp_class_HMS_HMS": "no_MT", + "mt_history_HMS_HMS": None}, + ], + ) + + # bad input: hist_cols missing 'time' + def dummy_func(hist, one, fc): + return pd.DataFrame({"time": [1.0], "metallicity": [0.02]}) + + with raises(ValueError, match="requires a time column"): + pop.create_transient_population( + dummy_func, "bad", hist_cols=["event"] + ) + + # func returns df missing 'time' + def func_no_time(hist, one, fc): + return pd.DataFrame({"metallicity": [0.02]}) + + with raises(ValueError, match="requires a time column"): + pop.create_transient_population( + func_no_time, "bad2", hist_cols=["time", "event"] + ) + + # func returns df missing 'metallicity' + def func_no_met(hist, one, fc): + return pd.DataFrame({"time": [1.0]}) + + with raises(ValueError, match="requires a metallicity column"): + pop.create_transient_population( + func_no_met, "bad3", hist_cols=["time", "event"] + ) + + # func returns df with duplicate columns + def func_dup_cols(hist, one, fc): + df = pd.DataFrame({"time": [1.0], "metallicity": [0.02]}) + df = pd.concat([df, df[["time"]]], axis=1) + return df + + with raises(ValueError, match="duplicate columns"): + pop.create_transient_population( + func_dup_cols, "bad4", hist_cols=["time", "event"] + ) + + # happy path + def good_func(hist, one, fc): + n = len(one) + return pd.DataFrame({ + "time": np.ones(n) * 100.0, + "metallicity": np.ones(n) * 0.02, + "channel": ["ch_A"] * n, + }) + + result = pop.create_transient_population( + good_func, "BBH", hist_cols=["time", "event"] + ) + assert isinstance(result, totest.TransientPopulation) + assert result.transient_name == "BBH" + with pd.HDFStore(pop.filename, "r") as store: + assert "/transients/BBH" in store.keys() + + # default hist_cols and oneline_cols (None) -> uses all columns + result_defaults = pop.create_transient_population( + good_func, "BBH_defaults" + ) + assert isinstance(result_defaults, totest.TransientPopulation) + + # with formation_channels present + pop.calculate_formation_channels(mt_history=False) + result_fc = pop.create_transient_population( + good_func, "BBH_fc", hist_cols=["time", "event"] + ) + assert isinstance(result_fc, totest.TransientPopulation) + + # with oneline_cols specified + result_onecols = pop.create_transient_population( + good_func, "BBH_onecols", + hist_cols=["time", "event"], + oneline_cols=["S1_mass_i", "state_i"], + ) + assert isinstance(result_onecols, totest.TransientPopulation) + + # overwrite existing transient + result2 = pop.create_transient_population( + good_func, "BBH", hist_cols=["time", "event"] + ) + assert isinstance(result2, totest.TransientPopulation) + + # func that returns empty df -> None return + def empty_func(hist, one, fc): + return pd.DataFrame(columns=["time", "metallicity", "channel"]) + + result_none = pop.create_transient_population( + empty_func, "empty_trans", hist_cols=["time", "event"] + ) + assert result_none is None + +class TestTransientPopulation: + + def test_init(self, tmp_path): + tpop = make_test_transient_pop(tmp_path, filename="tp_init.h5") + assert tpop.transient_name == "test_transient" + + # bad transient name + with raises(ValueError, match="is not a valid transient population"): + totest.TransientPopulation(tpop.filename, "nonexistent") + + def test_select(self, tmp_path): + tpop = make_test_transient_pop(tmp_path, filename="tp_sel.h5") + + # select all + df = tpop.select() + assert isinstance(df, pd.DataFrame) + assert len(df) == 2 + assert "time" in df.columns + assert "metallicity" in df.columns + + # select with start/stop + df_slice = tpop.select(start=0, stop=1) + assert len(df_slice) == 1 + + # select specific columns + df_cols = tpop.select(columns=["time"]) + assert list(df_cols.columns) == ["time"] + + def test_calculate_model_weights(self, tmp_path, monkeypatch): + # Build a population with the columns calculate_model_weights needs + oneline_rows = [ + {"binary_index": 0, "S1_mass_i": 10.0, "S2_mass_i": 5.0, + "orbital_period_i": 3.0, "eccentricity_i": 0.0, + "state_i": "initial", "metallicity": 0.02, + "interp_class_HMS_HMS": "initial_MT", + "mt_history_HMS_HMS": "Stable"}, + {"binary_index": 1, "S1_mass_i": 8.0, "S2_mass_i": 4.0, + "orbital_period_i": 5.0, "eccentricity_i": 0.0, + "state_i": "initial", "metallicity": 0.02, + "interp_class_HMS_HMS": "no_MT", + "mt_history_HMS_HMS": None}, + ] + tpop = make_test_transient_pop( + tmp_path, filename="tp_mw.h5", oneline_rows=oneline_rows, + ) + + # Monkeypatch calculate_model_weights to return known values + def mock_calc_weights(pop_data, M_sim, simulation_parameters, + population_parameters): + return np.ones(len(pop_data)) * 0.5 + + monkeypatch.setattr( + "posydon.popsyn.synthetic_population.calculate_model_weights", + mock_calc_weights, + ) + + result = tpop.calculate_model_weights("test_weights") + assert isinstance(result, pd.DataFrame) + assert "test_weights" in result.columns + assert len(result) == 2 + assert (result["test_weights"] == 0.5).all() + + # Verify it was stored in the HDF5 file + with pd.HDFStore(tpop.filename, "r") as store: + key = "/transients/test_transient/weights/test_weights" + assert key in store.keys() + + # Overwrite warning on second call + result2 = tpop.calculate_model_weights("test_weights") + assert isinstance(result2, pd.DataFrame) + + # Custom simulation_parameters triggers warning for unknown key + with warns(match="not found in the population"): + tpop.calculate_model_weights( + "test_weights2", + simulation_parameters={"fake_key": 999}, + ) + + # Custom population_parameters (exercises the non-None branch) + custom_pop_params = { + 'number_of_binaries': 100, + 'binary_fraction_scheme': 'const', + 'binary_fraction_const': 0.5, + 'star_formation': 'burst', + 'max_simulation_time': 13800000000.0, + 'primary_mass_scheme': 'Kroupa2001', + 'primary_mass_min': 0.01, + 'primary_mass_max': 200, + 'secondary_mass_scheme': 'flat_mass_ratio', + 'secondary_mass_min': 0.0005, + 'secondary_mass_max': 200, + 'orbital_scheme': 'period', + 'orbital_period_scheme': 'Sana+12_period_extended', + 'orbital_period_min': 0.35, + 'orbital_period_max': 6e3, + 'eccentricity_scheme': 'zero', + } + result3 = tpop.calculate_model_weights( + "test_weights3", population_parameters=custom_pop_params, + ) + assert isinstance(result3, pd.DataFrame) + + def test_calculate_cosmic_weights(self, tmp_path, monkeypatch): + oneline_rows = [ + {"binary_index": 0, "S1_mass_i": 10.0, "S2_mass_i": 5.0, + "orbital_period_i": 3.0, "eccentricity_i": 0.0, + "state_i": "initial", "metallicity": 0.02, + "interp_class_HMS_HMS": "initial_MT", + "mt_history_HMS_HMS": "Stable"}, + {"binary_index": 1, "S1_mass_i": 8.0, "S2_mass_i": 4.0, + "orbital_period_i": 5.0, "eccentricity_i": 0.0, + "state_i": "initial", "metallicity": 0.02, + "interp_class_HMS_HMS": "no_MT", + "mt_history_HMS_HMS": None}, + ] + tpop = make_test_transient_pop( + tmp_path, filename="tp_cw.h5", oneline_rows=oneline_rows, + ) + + # First, store model weights (required by calculate_cosmic_weights) + def mock_calc_weights(pop_data, M_sim, simulation_parameters, + population_parameters): + return np.ones(len(pop_data)) * 0.5 + + def mock_SFR(z, met_bins, SFH_MODEL): + return np.ones((len(z), len(met_bins) - 1)) * 1e-3 + + monkeypatch.setattr( + "posydon.popsyn.synthetic_population.calculate_model_weights", + mock_calc_weights, + ) + + monkeypatch.setattr( + "posydon.popsyn.synthetic_population.SFR_per_met_at_z", + mock_SFR, + ) + + tpop.calculate_model_weights("mw1") + + # Call calculate_cosmic_weights + rates = tpop.calculate_cosmic_weights("SFH_test", "mw1") + assert isinstance(rates, totest.Rates) + assert rates.SFH_identifier == "SFH_test" + + # Verify HDF5 structure + with pd.HDFStore(tpop.filename, "r") as store: + base = "/transients/test_transient/rates/SFH_test/" + assert base + "MODEL" in store.keys() + assert base + "weights" in store.keys() + assert base + "z_events" in store.keys() + assert base + "birth" in store.keys() + + # Bad model_weights identifier + with raises(ValueError, match="Model weights not present"): + tpop.calculate_cosmic_weights("SFH2", "nonexistent_weights") + + # Overwrite on second call + rates2 = tpop.calculate_cosmic_weights("SFH_test", "mw1") + assert isinstance(rates2, totest.Rates) + + # Custom MODEL_in (exercises the MODEL_in is not None branch) + rates3 = tpop.calculate_cosmic_weights( + "SFH_custom", "mw1", + MODEL_in={"delta_t": 200}, + ) + assert isinstance(rates3, totest.Rates) + assert rates3.MODEL["delta_t"] == 200 + + def test_efficiency(self, tmp_path, monkeypatch): + oneline_rows = [ + {"binary_index": 0, "S1_mass_i": 10.0, "S2_mass_i": 5.0, + "orbital_period_i": 3.0, "eccentricity_i": 0.0, + "state_i": "initial", "metallicity": 0.02, + "interp_class_HMS_HMS": "initial_MT", + "mt_history_HMS_HMS": "Stable"}, + {"binary_index": 1, "S1_mass_i": 8.0, "S2_mass_i": 4.0, + "orbital_period_i": 5.0, "eccentricity_i": 0.0, + "state_i": "initial", "metallicity": 0.02, + "interp_class_HMS_HMS": "no_MT", + "mt_history_HMS_HMS": None}, + ] + transient_rows = [ + {"time": 100.0, "metallicity": 0.02, "channel": "ch_A"}, + {"time": 200.0, "metallicity": 0.02, "channel": "ch_B"}, + ] + tpop = make_test_transient_pop( + tmp_path, filename="tp_eff.h5", + oneline_rows=oneline_rows, + transient_rows=transient_rows, + ) + + # Store model weights + def mock_calc_weights(pop_data, M_sim, simulation_parameters, + population_parameters): + return np.ones(len(pop_data)) * 0.5 + + monkeypatch.setattr( + "posydon.popsyn.synthetic_population.calculate_model_weights", + mock_calc_weights, + ) + tpop.calculate_model_weights("eff_weights") + + # Without channels + eff = tpop.efficiency("eff_weights", channels=False) + assert isinstance(eff, pd.DataFrame) + assert "total" in eff.columns + assert len(eff) == 1 # one metallicity + assert eff["total"].iloc[0] > 0 + + # With channels + eff_ch = tpop.efficiency("eff_weights", channels=True) + assert "total" in eff_ch.columns + assert "ch_A" in eff_ch.columns + assert "ch_B" in eff_ch.columns + +class TestRates: + + def test_init(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_init.h5") + assert rates.SFH_identifier == "test_SFH" + assert rates.transient_name == "test_transient" + assert hasattr(rates, "MODEL") + + # bad SFH_identifier + with raises(ValueError, match="is not a valid SFH_identifier"): + totest.Rates(rates.filename, "test_transient", "nonexistent") + + def test_select_rate_slice(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_srs.h5") + + # bad key + with raises(ValueError, match="key not in"): + rates.select_rate_slice("invalid_key") + + # valid keys + w = rates.select_rate_slice("weights") + assert isinstance(w, pd.DataFrame) + assert len(w) == 2 + + z = rates.select_rate_slice("z_events") + assert isinstance(z, pd.DataFrame) + assert len(z) == 2 + + b = rates.select_rate_slice("birth") + assert isinstance(b, pd.DataFrame) + assert "z" in b.columns + assert "t" in b.columns + + # with start/stop + w_slice = rates.select_rate_slice("weights", start=0, stop=1) + assert len(w_slice) == 1 + + def test_calculate_intrinsic_rate_density(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_ird.h5") + + result = rates.calculate_intrinsic_rate_density(channels=False) + assert isinstance(result, pd.DataFrame) + assert "total" in result.columns + assert len(result) > 0 + + # Stored in file + with pd.HDFStore(rates.filename, "r") as store: + assert rates.base_path + "intrinsic_rate_density" in store.keys() + + # Access via property + stored = rates.intrinsic_rate_density + pd.testing.assert_frame_equal(stored, result) + + def test_calculate_intrinsic_rate_density_channels(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_ird_ch.h5") + + result = rates.calculate_intrinsic_rate_density(channels=True) + assert "total" in result.columns + + def test_intrinsic_rate_density_not_computed(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_ird_nc.h5") + + # Property should raise before computing + with raises(ValueError, match="First you need to compute"): + _ = rates.intrinsic_rate_density + + def test_calculate_observable_population(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_cop.h5") + + # Simple observable func: halve the weights + def obs_func(transient_chunk, z_events_chunk, weights_chunk): + return weights_chunk * 0.5 + + rates.calculate_observable_population(obs_func, "test_obs") + + # Verify stored + with pd.HDFStore(rates.filename, "r") as store: + key = ("/transients/test_transient/rates/observable/test_obs") + assert key in store.keys() + + # Overwrite on second call + rates.calculate_observable_population(obs_func, "test_obs") + + def test_observable_population(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_op.h5") + + # Not yet computed -> ValueError + with raises(ValueError, match="is not a valid observable population"): + rates.observable_population("nonexistent") + + # Compute, then retrieve + def obs_func(transient_chunk, z_events_chunk, weights_chunk): + return weights_chunk * 0.5 + + rates.calculate_observable_population(obs_func, "obs1") + result = rates.observable_population("obs1") + assert isinstance(result, pd.DataFrame) + assert len(result) == 2 + + def test_observable_population_names(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_opn.h5") + + # No observables yet + assert rates.observable_population_names == [] + + # Add one + def obs_func(transient_chunk, z_events_chunk, weights_chunk): + return weights_chunk * 0.5 + + rates.calculate_observable_population(obs_func, "obs_A") + names = rates.observable_population_names + assert "obs_A" in names + + def test_edges_metallicity_bins(self, tmp_path): + # Multiple metallicities + transient_rows = [ + {"time": 100.0, "metallicity": 0.01, "channel": "ch_A"}, + {"time": 200.0, "metallicity": 0.02, "channel": "ch_B"}, + ] + oneline_rows = [ + {"binary_index": 0, "S1_mass_i": 10.0, "S2_mass_i": 5.0, + "state_i": "initial", "metallicity": 0.01, + "interp_class_HMS_HMS": "initial_MT", + "mt_history_HMS_HMS": "Stable"}, + {"binary_index": 1, "S1_mass_i": 8.0, "S2_mass_i": 4.0, + "state_i": "initial", "metallicity": 0.02, + "interp_class_HMS_HMS": "no_MT", + "mt_history_HMS_HMS": None}, + ] + mass_met_rows = { + "simulated_mass": [1.0, 1.0], + "number_of_systems": [1, 1], + } + # Build file manually for multi-metallicity + pop = make_test_pop( + tmp_path, filename="rates_emb.h5", + oneline_rows=oneline_rows, metallicity=0.01, + ) + # Overwrite mass_per_metallicity with two entries + mass_df = pd.DataFrame(mass_met_rows, index=[0.01, 0.02]) + with pd.HDFStore(pop.filename, "a") as store: + store.put("mass_per_metallicity", mass_df, format="table") + + # Add transient + transient_df = pd.DataFrame(transient_rows) + with pd.HDFStore(pop.filename, "a") as store: + store.append( + "transients/test_transient", transient_df, + format="table", min_itemsize={"channel": 100}, + ) + + # Add rates structure + from posydon.popsyn.rate_calculation import ( + DEFAULT_SFH_MODEL, + get_cosmic_time_from_redshift, + get_redshift_bin_centers, + ) + MODEL = dict(DEFAULT_SFH_MODEL) + z_birth = get_redshift_bin_centers(MODEL["delta_t"]) + t_birth = get_cosmic_time_from_redshift(z_birth) + nr = len(z_birth) + base = "/transients/test_transient/rates/test_SFH/" + with pd.HDFStore(pop.filename, "a") as store: + store.put(base + "MODEL", pd.DataFrame(MODEL, index=[0])) + store.put(base + "birth", pd.DataFrame({"z": z_birth, "t": t_birth})) + store.append(base + "weights", + pd.DataFrame(np.ones((2, nr))), format="table") + store.append(base + "z_events", + pd.DataFrame(np.full((2, nr), 0.1)), format="table") + + rates = totest.Rates(pop.filename, "test_transient", "test_SFH") + + edges = rates.edges_metallicity_bins + assert len(edges) == 3 # 2 metallicities -> 3 edges + assert edges[0] < edges[1] < edges[2] + + # Single metallicity with dlogZ = None + rates_single = make_test_rates( + tmp_path, filename="rates_emb_single.h5", + ) + edges_single = rates_single.edges_metallicity_bins + assert len(edges_single) == 2 + # dlogZ=None -> edges are 10**(-9) and 10**(0) + assert np.isclose(edges_single[0], 10**(-9)) + assert np.isclose(edges_single[1], 10**(0)) + + # Single metallicity with dlogZ = float + rates_dlogz = make_test_rates( + tmp_path, filename="rates_emb_dlogz.h5", + MODEL={"dlogZ": 0.5}, + ) + edges_dlogz = rates_dlogz.edges_metallicity_bins + assert len(edges_dlogz) == 2 + + # Single metallicity with dlogZ = list (exercises the list branch) + rates_dlogz_list = make_test_rates( + tmp_path, filename="rates_emb_dlogz_list.h5", + MODEL={"dlogZ": [-2.0, -1.0]}, + ) + # Manually overwrite the MODEL table with multi-row format + # to exercise the len(tmp_df) > 1 branch in _read_MODEL_data + # and the isinstance(dlogZ, list) branch in edges_metallicity_bins + base = rates_dlogz_list.base_path + with pd.HDFStore(rates_dlogz_list.filename, "a") as store: + model_data = {k: [v, v] for k, v in rates_dlogz_list.MODEL.items()} + model_data["dlogZ"] = [-2.0, -1.0] + store.put(base + "MODEL", pd.DataFrame(model_data)) + + # Re-read to trigger the multi-row branch + rates_multi_model = totest.Rates( + rates_dlogz_list.filename, "test_transient", "test_SFH" + ) + assert isinstance(rates_multi_model.MODEL["dlogZ"], list) + edges_list = rates_multi_model.edges_metallicity_bins + assert len(edges_list) == 2 + assert np.isclose(edges_list[0], 10**(-2.0)) + assert np.isclose(edges_list[1], 10**(-1.0)) + + def test_z_birth_property(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_zb.h5") + zb = rates.z_birth + assert isinstance(zb, pd.DataFrame) + assert "z" in zb.columns + assert "t" in zb.columns + + def test_z_events_property(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_ze.h5") + ze = rates.z_events + assert isinstance(ze, pd.DataFrame) + assert len(ze) == 2 + + def test_edges_redshift_bins(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_erb.h5") + edges = rates.edges_redshift_bins + assert len(edges) > 0 + assert edges[0] >= 0 + + def test_centers_redshift_bins(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_crb.h5") + centers = rates.centers_redshift_bins + assert len(centers) > 0 + + def test_centers_metallicity_bins(self, tmp_path): + rates = make_test_rates(tmp_path, filename="rates_cmb.h5") + centers = rates.centers_metallicity_bins + assert len(centers) == 1 + assert np.isclose(centers[0], 0.02 * totest.Zsun) diff --git a/posydon/unit_tests/popsyn/test_transient_select_funcs.py b/posydon/unit_tests/popsyn/test_transient_select_funcs.py new file mode 100644 index 0000000000..ba6783f8dd --- /dev/null +++ b/posydon/unit_tests/popsyn/test_transient_select_funcs.py @@ -0,0 +1,350 @@ +"""Unit tests of posydon/popsyn/transient_select_funcs.py +""" + +__authors__ = [ + "Elizabeth Teng " +] + +# import the module which will be tested +import posydon.popsyn.transient_select_funcs as totest + +# aliases +np = totest.np +pd = totest.pd +PATH_TO_POSYDON_DATA = totest.PATH_TO_POSYDON_DATA + +import warnings + +# import other needed code for the tests, which is not already imported in the +# module you like to test +from pytest import approx, fixture, raises, warns + +import posydon.popsyn.selection_effects as selection_effects +from posydon.utils.posydonwarning import ReplaceValueWarning + +warnings.simplefilter("always") + +# define test classes collecting several test functions +class TestElements: + # check for objects, which should be an element of the tested module + def test_dir(self): + elements = ['PATH_TO_PDET_GRID', 'GRB_selection', 'chi_eff', 'effective_precession',\ + 'm_chirp', 'mass_ratio', 'BBH_selection_function','DCO_detectability', \ + '__builtins__', '__cached__', '__doc__', \ + '__file__','__loader__', '__name__', '__package__', '__spec__', \ + 'np', 'pd', 'PATH_TO_POSYDON_DATA', \ + 'os', 'tqdm', 'warnings', 'Pwarn','selection_effects'] + totest_elements = set(dir(totest)) + missing_in_test = set(elements) - totest_elements + assert len(missing_in_test) == 0, "There are missing objects in "\ + +f"{totest.__name__}: "\ + +f"{missing_in_test}. Please "\ + +"check, whether they have been "\ + +"removed on purpose and update "\ + +"this unit test." + new_in_test = totest_elements - set(elements) + assert len(new_in_test) == 0, "There are new objects in "\ + +f"{totest.__name__}: {new_in_test}. "\ + +"Please check, whether they have been "\ + +"added on purpose and update this "\ + +"unit test." + +class TestFunctions: + + @fixture + def history_chunk(self): + return pd.DataFrame({ + 'binary_index': [10, 10, 10], + 'S1_state': ['MS', 'HG', 'BH'], + 'S2_state': ['MS', 'HG', 'BH'], + 'step_names': ['step_RLO', 'step_RLO', 'step_SN'], + 'orbital_period': [1.0, 1.1, 1.2], + 'eccentricity': [0.1, 0.15, 0.2], + 'S1_spin': [0.3, 0.35, 0.4], + 'S2_spin': [0.5, 0.55, 0.6], + 'S1_mass': [10.0, 9.5, 9.0], + 'S2_mass': [8.0, 7.5, 7.0], + 'time': [1.0e6, 2.0e6, 3.0e6]}, index=[10, 10, 10]) + + @fixture + def oneline_chunk(self): + return pd.DataFrame({ + 'metallicity': [0.02], + 'S1_m_disk_radiated': [0.5], + 'S2_m_disk_radiated': [0.0], + }, index=[10]) + + @fixture + def formation_channels_chunk(self): + return pd.DataFrame({ + 'channel': ['foo_CC1','bar_CC2'] + }, index=[10,11]) + + @fixture + def history_BBH(self): + return pd.DataFrame({ + 'event': ['MID', 'END', 'END'], + 'time': [1e6, 5e6, 6e6], + 'S1_state': ['BH', 'BH', 'BH'], + 'S2_state': ['BH', 'BH', 'BH'], + 'step_names': ['step_SN', 'step_SN', 'step_SN'], + 'state': ['detached', 'detached', 'detached'], + 'S1_mass': [30, 35, 40], + 'S2_mass': [20, 25, 30], + 'S1_spin': [0.5, 0.6, 0.7], + 'S2_spin': [0.4, 0.3, 0.2], + 'orbital_period': [0.5, 0.6, 0.7], + 'eccentricity': [0.1, 0.2, 0.3], + }, index=[0,1,2]) + + @fixture + def oneline_BBH(self): + return pd.DataFrame({ + 'metallicity': [0.01, 0.02, 0.03], + 'S1_spin_orbit_tilt_second_SN': [0.1, 0.2, 0.3], + 'S2_spin_orbit_tilt_second_SN': [0.4, 0.5, 0.6], + }, index=[0,1,2]) + + @fixture + def formation_channels_BBH(self): + return pd.DataFrame({ + 'channel': ['foo', 'bar', 'baz'], + }, index=[0,1,2]) + + @fixture + def array(self): + return np.array([1.0,2.0,3.0]) + + @fixture + def nan_array(self): + return np.array([np.nan,np.nan,np.nan]) + + @fixture + def wrong_array(self): + return np.array(['1.0','2.0','3.0']) + + @fixture + def transient_pop_chunk(self): + return pd.DataFrame({ + 'S1_mass': [30, 35], + 'S2_mass': [25, 30], + 'S1_spin': [0.1, 0.2], + 'S2_spin': [0.1, 0.2], + 'S1_spin_orbit_tilt_at_merger': [0.5, 0.6], + 'S2_spin_orbit_tilt_at_merger': [0.4, 0.5], + 'q': [0.83, 0.86], + 'chi_eff': [0.1, 0.2]}) + + @fixture + def z_events_chunk(self): + return pd.DataFrame({ + 'event_1': [0.1, np.nan], + 'event_2': [0.2, 0.3]}) + @fixture + def z_events_chunk_with_nan(self): + return pd.DataFrame({ + 'event_1': [1.0, np.nan], + 'event_2': [np.nan, np.nan] + }, index=[0,1]) + + @fixture + def z_weights_chunk(self): + return pd.DataFrame({ + 'event_1': [1.0, 1.0], + 'event_2': [1.0, 1.0] + }, index=[0, 1]) + + + def test_GRB_selection(self,history_chunk,oneline_chunk, + formation_channels_chunk): + # missing argument + with raises(TypeError,match="missing 2 required positional arguments"): + totest.GRB_selection() + # bad input + with raises(TypeError,match='string indices must be integers'): + totest.GRB_selection("1.1", "1.2") + with raises(AttributeError,match="'float' object has no attribute 'index'"): + totest.GRB_selection(1.1, 1.2) + with raises(ValueError,match='S1_S2 must be either S1 or S2'): + totest.GRB_selection(history_chunk, oneline_chunk.copy(), + S1_S2='test') + # example with S1 + df = totest.GRB_selection(history_chunk, oneline_chunk.copy(), + formation_channels_chunk, S1_S2='S1') + assert not df.empty + assert df.index[0] == 10 + assert 'S1_mass_preSN' in df.columns + assert 'S1_mass_postSN' in df.columns + assert df['time'].iloc[0] == 3.0 # 3 Myr = 3e6 years * 1e-6 + assert df['channel'].iloc[0] == 'foo_CC1' + # example with S2 + chunk = oneline_chunk.copy() + chunk['S1_m_disk_radiated'] = [0.0] + chunk['S2_m_disk_radiated'] = [0.5] + df = totest.GRB_selection(history_chunk, chunk, + formation_channels_chunk, S1_S2='S2') + assert not df.empty + assert 'S2_mass_postSN' in df.columns + assert 'metallicity' in df.columns + assert 'channel' in df.columns + # example with no disk radiation (empty selection) + chunk = oneline_chunk.copy() + chunk['S1_m_disk_radiated'] = [0.0] + df = totest.GRB_selection(history_chunk, chunk, + formation_channels_chunk=None, S1_S2='S1') + assert df.empty + # example with no formation channels + # example with no formation channels + df = totest.GRB_selection(history_chunk, oneline_chunk.copy(), + formation_channels_chunk=None, S1_S2='S1') + assert not df.empty + assert 'channel' not in df.columns + + def test_chi_eff(self,array,nan_array,wrong_array): + # missing argument + with raises(TypeError,match="missing 6 required positional arguments"): + totest.chi_eff() + # bad input + with raises(TypeError,match="ufunc 'cos' not supported for the input types"): + totest.chi_eff(array,array,array,array,array,wrong_array) + # undefined values + with warns(ReplaceValueWarning,match="a_1 contains undefined values"): + totest.chi_eff(array,array,nan_array.copy(),array,array,array) + with warns(ReplaceValueWarning,match="a_2 contains undefined values"): + totest.chi_eff(array,array,array,nan_array.copy(),array,array) + with warns(ReplaceValueWarning,match="tilt_1 contains undefined values"): + totest.chi_eff(array,array,array,array,nan_array.copy(),array) + with warns(ReplaceValueWarning,match="tilt_2 contains undefined values"): + totest.chi_eff(array,array,array,array,array,nan_array.copy()) + # example + assert totest.chi_eff(array,array,array, + array,array,array)[0] == 0.5403023058681398 + + def test_effective_precession(self): + # missing argument + with raises(TypeError, match="missing 6 required positional arguments"): + totest.effective_precession() + # example with scalars + result = totest.effective_precession( + theta_1=0.5, theta_2=0.3, + a1=0.8, a2=0.6, + m1=30.0, m2=20.0) + assert result == approx(np.maximum( + np.abs(0.8 * np.sin(0.5)), + (20./30.) * ((4*(20./30.) + 3)/(4 + 3*(20./30.))) * 0.6 * np.sin(0.3)), + abs=1e-12) + # example with arrays + theta_1 = np.array([0.0, np.pi/2]) + theta_2 = np.array([np.pi/2, 0.0]) + a1 = np.array([0.9, 0.9]) + a2 = np.array([0.5, 0.5]) + m1 = np.array([30.0, 30.0]) + m2 = np.array([10.0, 10.0]) + result = totest.effective_precession(theta_1, theta_2, a1, a2, m1, m2) + assert len(result) == 2 + # theta_1=0 means a1_perp=0, so chi_p comes from a2 term + assert result[0] > 0 + # theta_2=0 means a2_perp=0, so chi_p comes from a1 term + assert result[1] == approx(0.9, abs=1e-12) # a1 * sin(pi/2) = 0.9 + + def test_m_chirp(self): + # missing argument + with raises(TypeError,match="missing 1 required positional argument: 'm_2'"): + totest.m_chirp(3.) + # bad input + with raises(TypeError,match="can't multiply sequence by non-int of type 'str'"): + totest.m_chirp("3.","2.") + # examples + tests = [(4.,2.,2.433457367572823), + (40.,10.,16.65106414803746)] + for (m1,m2,mc) in tests: + assert totest.m_chirp(m1,m2) == mc + + def test_mass_ratio(self): + # missing argument + with raises(TypeError,match="missing 1 required positional argument: 'm_2'"): + totest.mass_ratio(3.) + # bad input + with raises(TypeError,match="unsupported operand type"): + totest.mass_ratio("3.","2.") + # examples + tests = [(5.,1.,0.2), + (1.,5.,0.2), + (4.,2.,0.5)] + for (m1,m2,q) in tests: + assert totest.mass_ratio(np.array([m1]), + np.array([m2])) == q + + def test_BBH_selection_function(self, history_BBH, oneline_BBH, + formation_channels_BBH): + # missing argument + with raises(TypeError,match="missing 2 required positional arguments"): + totest.BBH_selection_function() + # bad input + with raises(AttributeError,match="'float' object has no attribute 'index'"): + totest.BBH_selection_function(1.1, 1.2) + # example without formation channels + df = totest.BBH_selection_function(history_BBH, oneline_BBH) + assert not df.empty + assert all(col in df.columns for col in [ + 'time', 't_inspiral', 'metallicity', 'S1_state', 'S2_state', + 'S1_mass', 'S2_mass', 'S1_spin', 'S2_spin', + 'S1_spin_orbit_tilt_at_merger', 'S2_spin_orbit_tilt_at_merger', + 'orbital_period', 'chirp_mass', 'mass_ratio', 'chi_eff', 'eccentricity' + ]) + assert (df.index == oneline_BBH.index).all() + assert df['t_inspiral'].iloc[1] == 0.0 + # example with formation channels + df = totest.BBH_selection_function(history_BBH, oneline_BBH, formation_channels_BBH) + assert 'channel' in df.columns + assert (df['channel'] == formation_channels_BBH['channel']).all() + + def test_DCO_detectability(self, + transient_pop_chunk, + z_events_chunk, + z_events_chunk_with_nan, + z_weights_chunk, + monkeypatch): + class FakeKNNmodel: + def __init__(self, grid_path, sensitivity_key): + pass + def predict_pdet(self, df): + # Return a fixed probability (e.g., 0.5) for each row in df + return np.full(len(df), 0.5) + + monkeypatch.setattr('posydon.popsyn.selection_effects.KNNmodel', + FakeKNNmodel) + + # missing argument + with raises(TypeError,match="missing 4 required positional arguments"): + totest.DCO_detectability() + # bad input + with raises(ValueError,match='Unknown sensitivity sens_example'): + totest.DCO_detectability("sens_example", + transient_pop_chunk, + z_events_chunk, + z_weights_chunk) + # example: basic functionality + out = totest.DCO_detectability('O3actual_H1L1V1', transient_pop_chunk, + z_events_chunk, z_weights_chunk.copy()) + assert isinstance(out, pd.DataFrame) + assert out.shape == z_weights_chunk.shape + assert (out.values <= 1.0).all() + # example: missing q + transient = transient_pop_chunk.drop(columns=['q']) + out = totest.DCO_detectability('O3actual_H1L1V1', transient, + z_events_chunk, z_weights_chunk.copy()) + assert not out.empty + # example: missing chi_eff + transient = transient_pop_chunk.drop(columns=['chi_eff']) + out = totest.DCO_detectability('O3actual_H1L1V1', transient, + z_events_chunk, z_weights_chunk.copy()) + assert not out.empty + assert (out.values <= 1.0).all() + # example: masking for nans in z_events_chunk + out = totest.DCO_detectability('O3actual_H1L1V1', + transient_pop_chunk, + z_events_chunk_with_nan, + z_weights_chunk.copy()) + # event_2 is all NaN, so mask is all False and weights are unchanged + assert (out['event_2'] == 1.0).all() diff --git a/posydon/unit_tests/utils/test_common_functions.py b/posydon/unit_tests/utils/test_common_functions.py index c66229fdd6..6f348fa481 100644 --- a/posydon/unit_tests/utils/test_common_functions.py +++ b/posydon/unit_tests/utils/test_common_functions.py @@ -89,7 +89,7 @@ def test_dir(self): 'THRESHOLD_HE_NAKED_ABUNDANCE', '__authors__',\ '__builtins__', '__cached__', '__doc__', '__file__',\ '__loader__', '__name__', '__package__', '__spec__',\ - 'beaming', 'bondi_hoyle',\ + 'beaming', 'beta_gw', 'bondi_hoyle',\ 'calculate_H2recombination_energy',\ 'calculate_Mejected_for_integrated_binding_energy',\ 'calculate_Patton20_values_at_He_depl',\ @@ -233,6 +233,9 @@ def test_instance_histogram_sampler(self): def test_instance_read_histogram_from_file(self): assert isroutine(totest.read_histogram_from_file) + def test_instance_beta_gw(self): + assert isroutine(totest.beta_gw) + def test_instance_inspiral_timescale_from_separation(self): assert isroutine(totest.inspiral_timescale_from_separation) @@ -696,10 +699,12 @@ def test_beaming(self, binary): assert totest.beaming(binary) == r def test_bondi_hoyle(self, binary, monkeypatch): - def mock_rand(shape): - return np.zeros(shape) - def mock_rand2(shape): - return np.full(shape, 0.1) + class MockRNG: + def random(self, shape): + return np.zeros(shape) + class MockRNG2: + def random(self, shape): + return np.full(shape, 0.1) # missing argument with raises(TypeError, match="missing 3 required positional "\ +"arguments: 'binary', 'accretor', and "\ @@ -725,32 +730,32 @@ def mock_rand2(shape): +"associated with a value"): # undefined scheme totest.bondi_hoyle(binary, binary.star_1, binary.star_2, scheme='') - monkeypatch.setattr(np.random, "rand", mock_rand) - assert totest.bondi_hoyle(binary, binary.star_1, binary.star_2) ==\ + rng = MockRNG() + assert totest.bondi_hoyle(binary, binary.star_1, binary.star_2, RNG=rng) ==\ approx(3.92668160462e-17, abs=6e-29) assert totest.bondi_hoyle(binary, binary.star_1, binary.star_2,\ - scheme='Kudritzki+2000') ==\ + RNG=rng, scheme='Kudritzki+2000') ==\ approx(3.92668160462e-17, abs=6e-29) binary.star_2.log_R = 1.5 #donor's radius is 10^{1.5}Rsun assert totest.bondi_hoyle(binary, binary.star_1, binary.star_2,\ - scheme='Kudritzki+2000') ==\ + RNG=rng, scheme='Kudritzki+2000') ==\ approx(3.92668160462e-17, abs=6e-29) binary.star_2.log_R = -1.5 #donor's radius is 10^{-1.5}Rsun assert totest.bondi_hoyle(binary, binary.star_1, binary.star_2,\ - scheme='Kudritzki+2000') == 1e-99 + RNG=rng, scheme='Kudritzki+2000') == 1e-99 binary.star_2.surface_h1 = 0.25 #donor's X_surf=0.25 - assert totest.bondi_hoyle(binary, binary.star_1, binary.star_2) ==\ + assert totest.bondi_hoyle(binary, binary.star_1, binary.star_2, RNG=rng) ==\ 1e-99 binary.star_2.lg_wind_mdot = -4.0 #donor's wind is 10^{-4}Msun/yr - assert totest.bondi_hoyle(binary, binary.star_1, binary.star_2) ==\ + assert totest.bondi_hoyle(binary, binary.star_1, binary.star_2, RNG=rng) ==\ 1e-99 assert totest.bondi_hoyle(binary, binary.star_1, binary.star_2,\ - wind_disk_criteria=False) ==\ + RNG=rng, wind_disk_criteria=False) ==\ approx(5.34028698228e-17, abs=6e-29) # form always a disk - monkeypatch.setattr(np.random, "rand", mock_rand2) # other angle + rng = MockRNG2() # other angle binary.star_1.state = 'BH' #accretor is BH assert totest.bondi_hoyle(binary, binary.star_1, binary.star_2,\ - wind_disk_criteria=False) ==\ + wind_disk_criteria=False, RNG=rng) ==\ approx(5.13970075150e-8, abs=6e-20) def test_rejection_sampler(self, monkeypatch): @@ -934,6 +939,19 @@ def test_read_histogram_from_file(self, csv_path_failing_3_data_lines,\ assert np.allclose(arrays[0], np.array([0.2, 1.2, 2.2])) assert np.allclose(arrays[1], np.array([2.0, 2.0])) + def test_beta_gw(self): + # missing argument + with raises(TypeError, match="missing 2 required positional "\ + +"arguments: 'star1_mass' and "\ + +"'star2_mass'"): + totest.beta_gw() + # examples + tests = [(15.0, 30.0, approx(3.18232660295e-69, abs=6e-81)),\ + (30.0, 30.0, approx(8.48620427454e-69, abs=6e-81)),\ + (30.0, 60.0, approx(2.54586128236e-68, abs=6e-80))] + for (m1, m2, r) in tests: + assert totest.beta_gw(m1, m2) == r + def test_inspiral_timescale_from_separation(self): # missing argument with raises(TypeError, match="missing 4 required positional "\ diff --git a/posydon/unit_tests/utils/test_compress_mesa_files.py b/posydon/unit_tests/utils/test_compress_mesa_files.py index 1580afef07..769abe0038 100644 --- a/posydon/unit_tests/utils/test_compress_mesa_files.py +++ b/posydon/unit_tests/utils/test_compress_mesa_files.py @@ -245,6 +245,16 @@ def test_get_size(self, tmp_path): os.listdir(MESA_run_dir)[0]) os.symlink(MESA_run_file, os.path.join(MESA_dir,\ f"link{i}.file0")) + + # add >=2 regular files directly in MESA_dir (a non-MESA directory) + # so the inner for-loop backward arc (204->198) is covered on Ubuntu, + # where file symlinks may not reliably appear in os.walk's filenames + for j in range(2): + with open(os.path.join(MESA_dir, f"extra_{j}.log"),\ + "w") as extra_file: + extra_file.write(f"test\n") + + total_size, remove_files, compress_files, n_runs, n_remove_files,\ n_compress_files = totest.get_size(start_path=MESA_dir) assert total_size > 0 @@ -253,6 +263,23 @@ def test_get_size(self, tmp_path): assert n_runs == 20 assert n_remove_files == 0 assert n_compress_files > 0 + # isolated test for islink branch (201->204): create a minimal + # directory with a real file and a symlink to it, then verify + # get_size only counts the real file's size + islink_dir = os.path.join(tmp_path, "islink_grid_index_0") + os.mkdir(islink_dir) + real = os.path.join(islink_dir, "real.data") + with open(real, "w") as f: + f.write("content") + link = os.path.join(islink_dir, "link.data") + os.symlink(real, link) + assert os.path.islink(link), \ + f"os.path.islink returned False for symlink at {link}" + real_size = os.path.getsize(real) + total_size, remove_files, compress_files, n_runs, n_remove_files,\ + n_compress_files = totest.get_size(start_path=tmp_path) + # the symlink should not contribute to total_size + assert total_size >= real_size def test_compress_dir(self, tmp_path, capsys): # missing argument diff --git a/posydon/unit_tests/utils/test_gridutils.py b/posydon/unit_tests/utils/test_gridutils.py index e7223d2346..a2fc231fd2 100644 --- a/posydon/unit_tests/utils/test_gridutils.py +++ b/posydon/unit_tests/utils/test_gridutils.py @@ -5,9 +5,11 @@ "Matthias Kruckow " ] -# import the module which will be tested import posydon.utils.gridutils as totest +# import the module which will be tested +from posydon.grids.lazy_hdf import LazyHDF5 + # aliases np = totest.np os = totest.os @@ -31,16 +33,17 @@ def test_dir(self): ## does not clear the warning registy correctly. if hasattr(totest, '__warningregistry__'): del totest.__warningregistry__ - elements = {'LG_MTRANSFER_RATE_THRESHOLD', 'Msun', 'Pwarn', 'Rsun',\ - 'T_merger_P', 'T_merger_a', '__authors__', '__builtins__',\ + elements = {'LG_MTRANSFER_RATE_THRESHOLD', 'Pwarn',\ + '__authors__', '__builtins__',\ '__cached__', '__doc__', '__file__', '__loader__',\ '__name__', '__package__', '__spec__', 'add_field',\ - 'beta_gw', 'cgrav', 'clean_inlist_file', 'clight',\ + 'clean_inlist_file', 'LazyHDF5',\ 'convert_output_to_table', 'find_index_nearest_neighbour',\ 'find_nearest', 'fix_He_core', 'get_cell_edges',\ 'get_final_proposed_points', 'get_new_grid_name', 'gzip',\ - 'join_lists', 'kepler3_a', 'np', 'os', 'pd',\ - 'read_EEP_data_file', 'read_MESA_data_file', 'secyear'} + 'inspiral_timescale_from_orbital_period',\ + 'join_lists', 'np', 'os', 'pd',\ + 'read_EEP_data_file', 'read_MESA_data_file'} totest_elements = set(dir(totest)) missing_in_test = elements - totest_elements assert len(missing_in_test) == 0, "There are missing objects in "\ @@ -83,18 +86,6 @@ def test_instance_find_index_nearest_neighbour(self): def test_instance_get_final_proposed_points(self): assert isroutine(totest.get_final_proposed_points) - def test_instance_T_merger_P(self): - assert isroutine(totest.T_merger_P) - - def test_instance_beta_gw(self): - assert isroutine(totest.beta_gw) - - def test_instance_kepler3_a(self): - assert isroutine(totest.kepler3_a) - - def test_instance_T_merger_a(self): - assert isroutine(totest.T_merger_a) - def test_instance_convert_output_to_table(self): assert isroutine(totest.convert_output_to_table) @@ -120,9 +111,9 @@ def no_path(self, tmp_path): def MESA_data(self): # mock data: 3 columns and 2 rows; it contains different # types(int, float) and different signs(positive, negative) - return np.array([(1, 2, 3.3), (1, -2, -3.3)],\ + return LazyHDF5(np.array([(1, 2, 3.3), (1, -2, -3.3)],\ dtype=[('COL1', '", "Jeffrey Andrews ", "Matthias Kruckow ", + "Seth Gossage = 76.0.0", "versioneer"] +requires = ["setuptools >= 76.0.0", "setuptools-scm >= 8.0"] build-backend = "setuptools.build_meta" [project] -dynamic = [ - "description", - "license", - "version", - "requires-python", - "classifiers", - "dependencies", - "optional-dependencies", -] name = "posydon" -#description = "POSYDON the Next Generation of Population Synthesis" +description = "POSYDON the Next Generation of Population Synthesis" authors = [ {name = "POSYDON Collaboration", email = "posydon.team@gmail.com"}, ] maintainers = [ {name = "POSYDON Collaboration", email = "posydon.team@gmail.com"}, ] -#license = 'GPLv3+' -#version = "2.0.0.dev" -#requires-python = ">=3.11, <3.12" +license = {text = "BSD 3-Clause"} +dynamic = ["version"] +requires-python = ">=3.11, <3.12" readme = "README.md" keywords = [ "POSYDON", @@ -34,60 +25,67 @@ keywords = [ "Population Synthesis", "MESA", ] -#classifiers = [ -# 'Development Status :: 4 - Beta', -# 'Intended Audience :: Science/Research', -# 'Intended Audience :: End Users/Desktop', -# 'Topic :: Scientific/Engineering', -# 'Topic :: Scientific/Engineering :: Astronomy', -# 'Topic :: Scientific/Engineering :: Physics', -# 'Programming Language :: Python', -# 'Programming Language :: Python :: 3.11', -# 'Operating System :: POSIX', -# 'Operating System :: Unix', -# 'Operating System :: MacOS', -# 'Natural Language :: English', -# 'License :: OSI Approved :: GNU General Public License v3 (GPLv3+)', -#] -#dependencies = [ -# 'numpy < 2.0.0, >= 1.24.2', -# 'scipy <= 1.14.1, >= 1.10.1', -# 'iminuit <= 2.30.1, >= 2.21.3', -# 'configparser <= 7.1.0, >= 5.3.0', -# 'astropy <= 6.1.6, >= 5.2.2', -# 'pandas <= 2.2.3, >= 2.0.0', -# 'scikit-learn == 1.2.2', -# 'matplotlib <= 3.9.2, >= 3.9.0', -# 'matplotlib-label-lines <= 0.7.0, >= 0.5.2', -# 'h5py <= 3.12.1, >= 3.8.0', -# 'psutil <= 6.1.0, >= 5.9.4', -# 'tqdm <= 4.67.0, >= 4.65.0', -# 'tables <= 3.10.1, >= 3.8.0', -# 'progressbar2 <= 4.5.0, >= 4.2.0', -# 'hurry.filesize <= 0.9, >= 0.9', -# 'python-dotenv <= 1.0.1, >= 1.0.0', -#] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "Intended Audience :: End Users/Desktop", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Astronomy", + "Topic :: Scientific/Engineering :: Physics", + "Programming Language :: Python", + "Programming Language :: Python :: 3.11", + "Operating System :: POSIX", + "Operating System :: Unix", + "Operating System :: MacOS", + "Natural Language :: English", +] +dependencies = [ + "numpy >= 1.24.2, < 2.0.0", + "scipy >= 1.10.1, <= 1.14.1", + "iminuit >= 2.21.3, <= 2.30.1", + "configparser >= 5.3.0, <= 7.1.0", + "astropy >= 5.2.2, <= 6.1.6", + "pandas >= 2.0.0, <= 2.2.3", + "scikit-learn == 1.2.2", + "matplotlib >= 3.9.0, <= 3.9.2", + "matplotlib-label-lines >= 0.5.2, <= 0.7.0", + "h5py >= 3.8.0, <= 3.12.1", + "psutil >= 5.9.4, <= 6.1.0", + "tqdm >= 4.65.0, <= 4.67.0", + "tables >= 3.8.0, <= 3.10.1", + "progressbar2 >= 4.2.0, <= 4.5.0", + "hurry.filesize >= 0.9, <= 0.9", + "python-dotenv >= 1.0.0, <= 1.0.1", +] -#[project.optional-dependencies] -#doc = [ -# 'ipython', -# 'sphinx >= 8.2.2', -# 'numpydoc', -# 'sphinx_rtd_theme', -# 'sphinxcontrib_programoutput', -# 'PSphinxTheme', -# 'nbsphinx', -# 'pandoc' -#] -#vis = [ -# 'PyQt5 <= 5.15.11, >= 5.15.9' -#] -#ml = [ -# 'tensorflow >= 2.13.0' -#] -#hpc = [ -# 'mpi4py >= 3.0.3' -#] +[project.optional-dependencies] +doc = [ + "ipython", + "sphinx >= 8.2.2", + "numpydoc", + "sphinx_rtd_theme", + "sphinxcontrib_programoutput", + "PSphinxTheme", + "nbsphinx", + "pandoc", +] +vis = [ + "PyQt5 >= 5.15.9, <= 5.15.11", +] +ml = [ + "tensorflow >= 2.13.0", +] +hpc = [ + "mpi4py >= 3.0.3", +] +dev = [ + "pre-commit >= 3.7.0", + "isort >= 5.13.2", +] +test = [ + "pytest >= 7.3.1", + "pytest-cov >= 4.0.0", +] [project.urls] Homepage = "https://posydon.org" @@ -95,3 +93,42 @@ Documentation = "https://posydon.org/POSYDON" Repository = "https://github.com/POSYDON-code/POSYDON.git" Issues = "https://github.com/POSYDON-code/POSYDON/issues" Changelog = "https://github.com/POSYDON-code/POSYDON/releases" + +[tool.setuptools.packages.find] +include = ["posydon*"] + +[tool.setuptools] +include-package-data = true +script-files = [ + "bin/compress-mesa", + "bin/get-posydon-data", + "bin/posydon-popsyn", + "bin/posydon-run-grid", + "bin/posydon-run-pipeline", + "bin/posydon-setup-grid", + "bin/posydon-setup-pipeline", +] + +[tool.setuptools_scm] +# setuptools-scm will automatically determine version from git tags + +[tool.pytest.ini_options] +addopts = "--verbose -r s --cov --cov-branch --cov-report=term-missing --cov-fail-under=100" +testpaths = ["posydon/unit_tests"] + +[tool.coverage.run] +branch = true +source = [ + "posydon.config", + "posydon.utils", + "posydon.grids", + "posydon.popsyn.IMFs", + "posydon.popsyn.norm_pop", + "posydon.popsyn.distributions", + "posydon.popsyn.star_formation_history", + "posydon.CLI", +] + +[tool.coverage.report] +fail_under = 100 +show_missing = true diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 3a53872f29..0000000000 --- a/setup.cfg +++ /dev/null @@ -1,22 +0,0 @@ -[aliases] -test = pytest - -[bdist_wheel] -universal = 1 - -[tool:pytest] -addopts = --verbose -r s - -[versioneer] -VCS = git -style = pep440 -versionfile_source = posydon/_version.py -versionfile_build = posydon/_version.py -tag_prefix = v -parentdir_prefix = - -[coverage:run] -source = posydon -omit = - posydon/tests/* - posydon/_version.py diff --git a/setup.py b/setup.py index d8e1e7e709..adc1fb2678 100644 --- a/setup.py +++ b/setup.py @@ -1,174 +1,19 @@ -"""Setup the posydon package.""" +"""Minimal setup.py for POSYDON package. -from __future__ import print_function +All configuration is in pyproject.toml. +This file only handles optional sphinx documentation builds. +""" -import glob -import os.path -import sys - -sys.path.insert(0, os.path.dirname(__file__)) - -import versioneer +from setuptools import setup +# Optional: Add sphinx documentation build command cmdclass = {} - - -# VERSIONING - -__version__ = versioneer.get_version() -cmdclass.update(versioneer.get_cmdclass()) - - -# TOGGLE WRAPPING C/C++ OR FORTRAN - -WRAP_C_CPP_OR_FORTRAN = False - -if WRAP_C_CPP_OR_FORTRAN: - from distutils.command.sdist import sdist - - try: - from numpy.distutils.core import Extension, setup - except ImportError: - raise ImportError("Building fortran extensions requires numpy.") - - cmdclass["sdist"] = sdist -else: - from setuptools import find_packages, setup - - -# DOCUMENTATION - -# import sphinx commands try: from sphinx.setup_command import BuildDoc + cmdclass["build_sphinx"] = BuildDoc except ImportError: pass -else: - cmdclass["build_sphinx"] = BuildDoc - -# read description -with open("README.md", "rb") as f: - longdesc = "f.read().decode().strip()" - - -# DEPENDENCIES -setup_requires = [ - 'setuptools >= 76.0.0', -] -if 'test' in sys.argv: - setup_requires += [ - 'pytest-runner', - ] - - -# These pretty common requirement are commented out. Various syntax types -# are all used in the example below for specifying specific version of the -# packages that are compatbile with your software. -# TODO NOTE: before the v2.0.0 code release, we should froze the versions -# the correct way to do this is to make sure that they are available on -# conda and pip for all platforms we support (see prerequisites doc page). -install_requires = [ - 'numpy >= 1.24.2, < 2.0.0', - 'scipy >= 1.10.1, <= 1.14.1', - 'iminuit >= 2.21.3, <= 2.30.1', - 'configparser >= 5.3.0, <= 7.1.0', - 'astropy >= 5.2.2, <= 6.1.6', - 'pandas >= 2.0.0, <= 2.2.3', - 'scikit-learn == 1.2.2', - 'matplotlib >= 3.9.0, <= 3.9.2', - 'matplotlib-label-lines >= 0.5.2, <= 0.7.0', - 'h5py >= 3.8.0, <= 3.12.1', - 'psutil >= 5.9.4, <= 6.1.0', - 'tqdm >= 4.65.0, <= 4.67.0', - 'tables >= 3.8.0, <= 3.10.1', - 'progressbar2 >= 4.2.0, <= 4.5.0', # for downloading data - 'hurry.filesize >= 0.9, <= 0.9', - 'python-dotenv >= 1.0.0, <= 1.0.1', -] - -tests_require = [ - "pytest >= 7.3.1", - "pytest-cov >= 4.0.0", -] - -# For documentation -extras_require = { - # to build documentation - "doc": [ - "ipython", - "sphinx >= 8.2.2", - "numpydoc", - "sphinx_rtd_theme", - "sphinxcontrib_programoutput", - "PSphinxTheme", - "nbsphinx", - "pandoc", - ], - # for experimental visualization features, e.g. VDH diagrams - "vis": ["PyQt5 >= 5.15.9, <= 5.15.11"], - # for profile machine learning features, e.g. profile interpolation - "ml": ["tensorflow >= 2.13.0"], - # for running population synthesis on HPC facilities - "hpc": ["mpi4py >= 3.0.3"], - # development tooling - 'dev': [ - 'pre-commit >= 3.7.0', - 'isort >= 5.13.2', - ], -} - -# RUN SETUP - -packagenames = find_packages() - -# Executables go in a folder called bin -scripts = glob.glob(os.path.join("bin", "*")) - -PACKAGENAME = "posydon" -DISTNAME = "posydon" -AUTHOR = "POSYDON Collaboration" -AUTHOR_EMAIL = "posydon.team@gmail.com" -LICENSE = "GPLv3+" -DESCRIPTION = "POSYDON the Next Generation of Population Synthesis" -GITHUBURL = "https://github.com/POSYDON-code/POSYDON" - -# Additional included files via include_package_data are defined in MANIFEST.in -setup( - name=DISTNAME, - provides=[PACKAGENAME], - version=__version__, - description=DESCRIPTION, - long_description=longdesc, - long_description_content_type="text/markdown", - ext_modules=[wrapper] if WRAP_C_CPP_OR_FORTRAN else [], - author=AUTHOR, - author_email=AUTHOR_EMAIL, - license=LICENSE, - packages=packagenames, - include_package_data=True, - cmdclass=cmdclass, - url=GITHUBURL, - scripts=scripts, - setup_requires=setup_requires, - install_requires=install_requires, - tests_require=tests_require, - extras_require=extras_require, - python_requires=">=3.11, <3.12", - use_2to3=False, - classifiers=[ - "Development Status :: 4 - Beta", - "Intended Audience :: Science/Research", - "Intended Audience :: End Users/Desktop", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Astronomy", - "Topic :: Scientific/Engineering :: Physics", - "Programming Language :: Python", - "Programming Language :: Python :: 3.11", - "Operating System :: POSIX", - "Operating System :: Unix", - "Operating System :: MacOS", - "Natural Language :: English", - "License :: OSI Approved :: GNU General Public License v3 (GPLv3+)", - ], -) +# Minimal setup call - all metadata including version is in pyproject.toml +# Version is automatically determined by setuptools-scm from git tags +setup(cmdclass=cmdclass) diff --git a/versioneer.py b/versioneer.py deleted file mode 100644 index ef293d9b3e..0000000000 --- a/versioneer.py +++ /dev/null @@ -1,1825 +0,0 @@ - -# Version: 0.18 - -"""The Versioneer - like a rocketeer, but for versions. - -The Versioneer -============== - -* like a rocketeer, but for versions! -* https://github.com/warner/python-versioneer -* Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based -python projects. The goal is to remove the tedious and error-prone "update -the embedded version string" step from your release process. Making a new -release should be as easy as recording a new tag in your version-control -system, and maybe making new tarballs. - - -## Quick Install - -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results - -## Version Identifiers - -Source trees come from a variety of places: - -* a version-control system checkout (mostly used by developers) -* a nightly tarball, produced by build automation -* a snapshot tarball, produced by a web-based VCS browser, like github's - "tarball from tag" feature -* a release tarball, produced by "setup.py sdist", distributed through PyPI - -Within each source tree, the version identifier (either a string or a number, -this tool is format-agnostic) can come from a variety of places: - -* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows - about recent "tags" and an absolute revision-id -* the name of the directory into which the tarball was unpacked -* an expanded VCS keyword ($Id$, etc) -* a `_version.py` created by some earlier build step - -For released software, the version identifier is closely related to a VCS -tag. Some projects use tag names that include more than just the version -string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool -needs to strip the tag prefix to extract the version identifier. For -unreleased software (between tags), the version identifier should provide -enough information to help developers recreate the same tree, while also -giving them an idea of roughly how old the tree is (after version 1.2, before -version 1.3). Many VCS systems can report a description that captures this, -for example `git describe --tags --dirty --always` reports things like -"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the -0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. - -The version identifier is used for multiple purposes: - -* to allow the module to self-identify its version: `myproject.__version__` -* to choose a name and prefix for a 'setup.py sdist' tarball - -## Theory of Operation - -Versioneer works by adding a special `_version.py` file into your source -tree, where your `__init__.py` can import it. This `_version.py` knows how to -dynamically ask the VCS tool for version information at import time. - -`_version.py` also contains `$Revision$` markers, and the installation -process marks `_version.py` to have this marker rewritten with a tag name -during the `git archive` command. As a result, generated tarballs will -contain enough information to get the proper version. - -To allow `setup.py` to compute a version too, a `versioneer.py` is added to -the top level of your source tree, next to `setup.py` and the `setup.cfg` -that configures it. This overrides several distutils/setuptools commands to -compute the version when invoked, and changes `setup.py build` and `setup.py -sdist` to replace `_version.py` with a small static file that contains just -the generated version data. - -## Installation - -See [INSTALL.md](./INSTALL.md) for detailed installation instructions. - -## Version-String Flavors - -Code which uses Versioneer can learn about its version string at runtime by -importing `_version` from your main `__init__.py` file and running the -`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can -import the top-level `versioneer.py` and run `get_versions()`. - -Both functions return a dictionary with different flavors of version -information: - -* `['version']`: A condensed version string, rendered using the selected - style. This is the most commonly used value for the project's version - string. The default "pep440" style yields strings like `0.11`, - `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section - below for alternative styles. - -* `['full-revisionid']`: detailed revision identifier. For Git, this is the - full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". - -* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the - commit date in ISO 8601 format. This will be None if the date is not - available. - -* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that - this is only accurate if run in a VCS checkout, otherwise it is likely to - be False or None - -* `['error']`: if the version string could not be computed, this will be set - to a string describing the problem, otherwise it will be None. It may be - useful to throw an exception in setup.py if this is set, to avoid e.g. - creating tarballs with a version string of "unknown". - -Some variants are more useful than others. Including `full-revisionid` in a -bug report should allow developers to reconstruct the exact code being tested -(or indicate the presence of local changes that should be shared with the -developers). `version` is suitable for display in an "about" box or a CLI -`--version` output: it can be easily compared against release notes and lists -of bugs fixed in various releases. - -The installer adds the following text to your `__init__.py` to place a basic -version in `YOURPROJECT.__version__`: - - from ._version import get_versions - __version__ = get_versions()['version'] - del get_versions - -## Styles - -The setup.cfg `style=` configuration controls how the VCS information is -rendered into a version string. - -The default style, "pep440", produces a PEP440-compliant string, equal to the -un-prefixed tag name for actual releases, and containing an additional "local -version" section with more detail for in-between builds. For Git, this is -TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags ---dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the -tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and -that this commit is two revisions ("+2") beyond the "0.11" tag. For released -software (exactly equal to a known tag), the identifier will only contain the -stripped tag, e.g. "0.11". - -Other styles are available. See [details.md](details.md) in the Versioneer -source tree for descriptions. - -## Debugging - -Versioneer tries to avoid fatal errors: if something goes wrong, it will tend -to return a version of "0+unknown". To investigate the problem, run `setup.py -version`, which will run the version-lookup code in a verbose mode, and will -display the full contents of `get_versions()` (including the `error` string, -which may help identify what went wrong). - -## Known Limitations - -Some situations are known to cause problems for Versioneer. This details the -most significant ones. More can be found on Github -[issues page](https://github.com/warner/python-versioneer/issues). - -### Subprojects - -Versioneer has limited support for source trees in which `setup.py` is not in -the root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are -two common reasons why `setup.py` might not be in the root: - -* Source trees which contain multiple subprojects, such as - [Buildbot](https://github.com/buildbot/buildbot), which contains both - "master" and "slave" subprojects, each with their own `setup.py`, - `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI - distributions (and upload multiple independently-installable tarballs). -* Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other langauges) in subdirectories. - -Versioneer will look for `.git` in parent directories, and most operations -should get the right version string. However `pip` and `setuptools` have bugs -and implementation details which frequently cause `pip install .` from a -subproject directory to fail to find a correct version string (so it usually -defaults to `0+unknown`). - -`pip install --editable .` should work correctly. `setup.py install` might -work too. - -Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in -some later version. - -[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking -this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the -issue from the Versioneer side in more detail. -[pip PR#3176](https://github.com/pypa/pip/pull/3176) and -[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve -pip to let Versioneer work correctly. - -Versioneer-0.16 and earlier only looked for a `.git` directory next to the -`setup.cfg`, so subprojects were completely unsupported with those releases. - -### Editable installs with setuptools <= 18.5 - -`setup.py develop` and `pip install --editable .` allow you to install a -project into a virtualenv once, then continue editing the source code (and -test) without re-installing after every change. - -"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a -convenient way to specify executable scripts that should be installed along -with the python package. - -These both work as expected when using modern setuptools. When using -setuptools-18.5 or earlier, however, certain operations will cause -`pkg_resources.DistributionNotFound` errors when running the entrypoint -script, which must be resolved by re-installing the package. This happens -when the install happens with one version, then the egg_info data is -regenerated while a different version is checked out. Many setup.py commands -cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into -a different virtualenv), so this can be surprising. - -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes -this one, but upgrading to a newer version of setuptools should probably -resolve it. - -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. - - -## Updating Versioneer - -To upgrade your project to a new release of Versioneer, do the following: - -* install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace - `SRC/_version.py` -* commit any changed files - -## Future Directions - -This tool is designed to make it easily extended to other version-control -systems: all VCS-specific components are in separate directories like -src/git/ . The top-level `versioneer.py` script is assembled from these -components by running make-versioneer.py . In the future, make-versioneer.py -will take a VCS name as an argument, and will construct a version of -`versioneer.py` that is specific to the given VCS. It might also take the -configuration arguments that are currently provided manually during -installation by editing setup.py . Alternatively, it might go the other -direction and include code from all supported VCS systems, reducing the -number of intermediate scripts. - - -## License - -To make Versioneer easier to embed, all its code is dedicated to the public -domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . - -""" - -from __future__ import print_function - -try: - import configparser -except ImportError: - import ConfigParser as configparser - -import errno -import json -import os -import re -import subprocess -import sys - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_root(): - """Get the project root directory. - - We require that all commands are run from the project root, i.e. the - directory that contains setup.py, setup.cfg, and versioneer.py . - """ - root = os.path.realpath(os.path.abspath(os.getcwd())) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - # allow 'python path/to/setup.py COMMAND' - root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") - raise VersioneerBadRootError(err) - try: - # Certain runtime workflows (setup.py install/develop in a setuptools - # tree) execute all dependencies in a single python process, so - # "versioneer" may be imported multiple times, and python's shared - # module-import table will cache the first one. So we can't use - # os.path.dirname(__file__), as that will find whichever - # versioneer.py was first imported, even in later projects. - me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) - vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py)) - except NameError: - pass - return root - - -def get_config_from_root(root): - """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or - # configparser.NoSectionError (if it lacks a [versioneer] section), or - # configparser.NoOptionError (if it lacks "VCS="). See the docstring at - # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.SafeConfigParser() - with open(setup_cfg, "r") as f: - parser.readfp(f) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None - cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): - cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -# these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Get decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode - - -LONG_VERSION_PY['git'] = ''' -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" - git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" - git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "%(STYLE)s" - cfg.tag_prefix = "%(TAG_PREFIX)s" - cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" - cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %%s" %% dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %%s" %% (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %%s (error)" %% dispcmd) - print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %%s but none started with prefix %%s" %% - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %%d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%%s', no digits" %% ",".join(refs - tags)) - if verbose: - print("likely tags: %%s" %% ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %%s" %% r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %%s not under git control" %% root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%%s'" - %% describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%%s' doesn't start with prefix '%%s'" - print(fmt %% (full_tag, tag_prefix)) - pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" - %% (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%%s'" %% style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} -''' - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def do_vcs_install(manifest_in, versionfile_source, ipy): - """Git-specific installation logic for Versioneer. - - For Git, this means creating/changing .gitattributes to mark _version.py - for export-subst keyword substitution. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] - if ipy: - files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) - present = False - try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except EnvironmentError: - pass - if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() - files.append(".gitattributes") - run_command(GITS, ["add", "--"] + files) - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from -# revision-control system data, or from the parent directory name of an -# unpacked source archive. Distribution tarballs contain a pre-generated copy -# of this file. - -import json - -version_json = ''' -%s -''' # END VERSION_JSON - - -def get_versions(): - return json.loads(version_json) -""" - - -def versions_from_file(filename): - """Try to determine the version from _version.py if present.""" - try: - with open(filename) as f: - contents = f.read() - except EnvironmentError: - raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - raise NotThisMethod("no version_json in _version.py") - return json.loads(mo.group(1)) - - -def write_to_version_file(filename, versions): - """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) - with open(filename, "w") as f: - f.write(SHORT_VERSION_PY % contents) - - print("set %s to '%s'" % (filename, versions["version"])) - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -class VersioneerBadRootError(Exception): - """The project root directory is unknown or missing key files.""" - - -def get_versions(verbose=False): - """Get the project version from whatever source is available. - - Returns dict with two keys: 'version' and 'full'. - """ - if "versioneer" in sys.modules: - # see the discussion in cmdclass.py:get_cmdclass() - del sys.modules["versioneer"] - - root = get_root() - cfg = get_config_from_root(root) - - assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" - handlers = HANDLERS.get(cfg.VCS) - assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" - assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" - - versionfile_abs = os.path.join(root, cfg.versionfile_source) - - # extract version from first of: _version.py, VCS command (e.g. 'git - # describe'), parentdir. This is meant to work for developers using a - # source checkout, for users of a tarball created by 'setup.py sdist', - # and for users of a tarball/zipball created by 'git archive' or github's - # download-from-tag feature or the equivalent in other VCSes. - - get_keywords_f = handlers.get("get_keywords") - from_keywords_f = handlers.get("keywords") - if get_keywords_f and from_keywords_f: - try: - keywords = get_keywords_f(versionfile_abs) - ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) - if verbose: - print("got version from expanded keyword %s" % ver) - return ver - except NotThisMethod: - pass - - try: - ver = versions_from_file(versionfile_abs) - if verbose: - print("got version from file %s %s" % (versionfile_abs, ver)) - return ver - except NotThisMethod: - pass - - from_vcs_f = handlers.get("pieces_from_vcs") - if from_vcs_f: - try: - pieces = from_vcs_f(cfg.tag_prefix, root, verbose) - ver = render(pieces, cfg.style) - if verbose: - print("got version from VCS %s" % ver) - return ver - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - if verbose: - print("got version from parentdir %s" % ver) - return ver - except NotThisMethod: - pass - - if verbose: - print("unable to compute version") - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} - - -def get_version(): - """Get the short version string for this project.""" - return get_versions()["version"] - - -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" - if "versioneer" in sys.modules: - del sys.modules["versioneer"] - # this fixes the "python setup.py develop" case (also 'install' and - # 'easy_install .'), in which subdependencies of the main project are - # built (using setup.py bdist_egg) in the same python process. Assume - # a main project A and a dependency B, which use different versions - # of Versioneer. A's setup.py imports A's Versioneer, leaving it in - # sys.modules by the time B's setup.py is executed, causing B to run - # with the wrong versioneer. Setuptools wraps the sub-dep builds in a - # sandbox that restores sys.modules to it's pre-build state, so the - # parent is protected against the child's "import versioneer". By - # removing ourselves from sys.modules here, before the child build - # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/warner/python-versioneer/issues/52 - - cmds = {} - - # we add "version" to both distutils and setuptools - from distutils.core import Command - - class cmd_version(Command): - description = "report generated version string" - user_options = [] - boolean_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - vers = get_versions(verbose=True) - print("Version: %s" % vers["version"]) - print(" full-revisionid: %s" % vers.get("full-revisionid")) - print(" dirty: %s" % vers.get("dirty")) - print(" date: %s" % vers.get("date")) - if vers["error"]: - print(" error: %s" % vers["error"]) - cmds["version"] = cmd_version - - # we override "build_py" in both distutils and setuptools - # - # most invocation pathways end up running build_py: - # distutils/build -> build_py - # distutils/install -> distutils/build ->.. - # setuptools/bdist_wheel -> distutils/install ->.. - # setuptools/bdist_egg -> distutils/install_lib -> build_py - # setuptools/install -> bdist_egg ->.. - # setuptools/develop -> ? - # pip install: - # copies source tree to a tempdir before running egg_info/etc - # if .git isn't copied too, 'git describe' will fail - # then does setup.py bdist_wheel, or sometimes setup.py install - # setup.py egg_info -> ? - - # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py - else: - from distutils.command.build_py import build_py as _build_py - - class cmd_build_py(_build_py): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_py.run(self) - # now locate _version.py in the new build/ directory and replace - # it with an updated value - if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - cmds["build_py"] = cmd_build_py - - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe - - # nczeczulin reports that py2exe won't like the pep440-style string - # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. - # setup(console=[{ - # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION - # "product_version": versioneer.get_version(), - # ... - - class cmd_build_exe(_build_exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _build_exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["build_exe"] = cmd_build_exe - del cmds["build_py"] - - if 'py2exe' in sys.modules: # py2exe enabled? - try: - from py2exe.distutils_buildexe import py2exe as _py2exe # py3 - except ImportError: - from py2exe.build_exe import py2exe as _py2exe # py2 - - class cmd_py2exe(_py2exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _py2exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["py2exe"] = cmd_py2exe - - # we override different "sdist" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.sdist import sdist as _sdist - else: - from distutils.command.sdist import sdist as _sdist - - class cmd_sdist(_sdist): - def run(self): - versions = get_versions() - self._versioneer_generated_versions = versions - # unless we update this, the command will keep using the old - # version - self.distribution.metadata.version = versions["version"] - return _sdist.run(self) - - def make_release_tree(self, base_dir, files): - root = get_root() - cfg = get_config_from_root(root) - _sdist.make_release_tree(self, base_dir, files) - # now locate _version.py in the new base_dir directory - # (remembering that it may be a hardlink) and replace it with an - # updated value - target_versionfile = os.path.join(base_dir, cfg.versionfile_source) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) - cmds["sdist"] = cmd_sdist - - return cmds - - -CONFIG_ERROR = """ -setup.cfg is missing the necessary Versioneer configuration. You need -a section like: - - [versioneer] - VCS = git - style = pep440 - versionfile_source = src/myproject/_version.py - versionfile_build = myproject/_version.py - tag_prefix = - parentdir_prefix = myproject- - -You will also need to edit your setup.py to use the results: - - import versioneer - setup(version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), ...) - -Please read the docstring in ./versioneer.py for configuration instructions, -edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. -""" - -SAMPLE_CONFIG = """ -# See the docstring in versioneer.py for instructions. Note that you must -# re-run 'versioneer.py setup' after changing this section, and commit the -# resulting files. - -[versioneer] -#VCS = git -#style = pep440 -#versionfile_source = -#versionfile_build = -#tag_prefix = -#parentdir_prefix = - -""" - -INIT_PY_SNIPPET = """ -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions -""" - - -def do_setup(): - """Perform VCS-independent setup function for installing Versioneer.""" - root = get_root() - try: - cfg = get_config_from_root(root) - except (EnvironmentError, configparser.NoSectionError, - configparser.NoOptionError) as e: - if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) - with open(os.path.join(root, "setup.cfg"), "a") as f: - f.write(SAMPLE_CONFIG) - print(CONFIG_ERROR, file=sys.stderr) - return 1 - - print(" creating %s" % cfg.versionfile_source) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") - if os.path.exists(ipy): - try: - with open(ipy, "r") as f: - old = f.read() - except EnvironmentError: - old = "" - if INIT_PY_SNIPPET not in old: - print(" appending to %s" % ipy) - with open(ipy, "a") as f: - f.write(INIT_PY_SNIPPET) - else: - print(" %s unmodified" % ipy) - else: - print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make sure both the top-level "versioneer.py" and versionfile_source - # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so - # they'll be copied into source distributions. Pip won't be able to - # install the package without this. - manifest_in = os.path.join(root, "MANIFEST.in") - simple_includes = set() - try: - with open(manifest_in, "r") as f: - for line in f: - if line.startswith("include "): - for include in line.split()[1:]: - simple_includes.add(include) - except EnvironmentError: - pass - # That doesn't cover everything MANIFEST.in can do - # (http://docs.python.org/2/distutils/sourcedist.html#commands), so - # it might give some false negatives. Appending redundant 'include' - # lines is safe, though. - if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") - - # Make VCS-specific changes. For git, this means creating/changing - # .gitattributes to mark _version.py for export-subst keyword - # substitution. - do_vcs_install(manifest_in, cfg.versionfile_source, ipy) - return 0 - - -def scan_setup_py(): - """Validate the contents of setup.py against Versioneer's expectations.""" - found = set() - setters = False - errors = 0 - with open("setup.py", "r") as f: - for line in f.readlines(): - if "import versioneer" in line: - found.add("import") - if "versioneer.get_cmdclass()" in line: - found.add("cmdclass") - if "versioneer.get_version()" in line: - found.add("get_version") - if "versioneer.VCS" in line: - setters = True - if "versioneer.versionfile_source" in line: - setters = True - if len(found) != 3: - print("") - print("Your setup.py appears to be missing some important items") - print("(but I might be wrong). Please make sure it has something") - print("roughly like the following:") - print("") - print(" import versioneer") - print(" setup( version=versioneer.get_version(),") - print(" cmdclass=versioneer.get_cmdclass(), ...)") - print("") - errors += 1 - if setters: - print("You should remove lines like 'versioneer.VCS = ' and") - print("'versioneer.versionfile_source = ' . This configuration") - print("now lives in setup.cfg, and should be removed from setup.py") - print("") - errors += 1 - return errors - - -if __name__ == "__main__": - cmd = sys.argv[1] - if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1)