From c60abb6c62dd99013e6201fbd68373d01052c797 Mon Sep 17 00:00:00 2001 From: Ben Koziol Date: Fri, 17 Apr 2026 11:22:09 -0600 Subject: [PATCH 1/4] branched from ci-split [ci skip] From b3fc30a4c00a7408f6684dfbb280db2a7628d425 Mon Sep 17 00:00:00 2001 From: Ben Koziol Date: Thu, 23 Apr 2026 10:03:58 -0600 Subject: [PATCH 2/4] refactor processor and datasets [ci skip] --- .gitignore | 2 + .pre-commit-config.yaml | 3 +- README.md | 94 +- .../app/chem_regrid/__init__.py | 3 + .../app/chem_regrid/chem_regrid_cli.py | 2 +- .../app/chem_regrid/chem_regrid_context.py | 50 + .../app/chem_regrid/chem_regrid_impl.py | 1572 +++-------------- .../app/chem_regrid/chem_regrid_rrfs.py | 6 +- src/regrid_wrapper/app/chem_regrid/context.py | 76 - .../app/chem_regrid/dataset/__init__.py | 0 .../chem_regrid/dataset/config/__init__.py | 0 .../chem_regrid/dataset/config/datasets.yml | 232 +++ .../app/chem_regrid/dataset/config/model.py | 34 + .../chem_regrid/dataset/context/__init__.py | 47 + .../app/chem_regrid/dataset/context/base.py | 243 +++ .../chem_regrid/dataset/context/ecoregion.py | 16 + .../chem_regrid/dataset/context/fengsha_2d.py | 18 + .../dataset/context/fengsha_2d_time.py | 16 + .../app/chem_regrid/dataset/context/fmc.py | 33 + .../app/chem_regrid/dataset/context/goes.py | 72 + .../chem_regrid/dataset/context/gra2pes.py | 49 + .../app/chem_regrid/dataset/context/narr.py | 16 + .../dataset/context/nemo_anthro.py | 42 + .../chem_regrid/dataset/context/nemo_rwc.py | 33 + .../app/chem_regrid/dataset/context/ngfs.py | 38 + .../app/chem_regrid/dataset/context/pecm.py | 67 + .../app/chem_regrid/dataset/context/rave.py | 170 ++ .../app/chem_regrid/dataset/src_field.py | 85 + src/regrid_wrapper/common.py | 29 +- src/regrid_wrapper/esmpy/field_wrapper.py | 72 +- .../test_app/test_chem_regrid/conftest.py | 5 +- .../test_chem_regrid/test_chem_regrid_cli.py | 15 +- .../test_chem_regrid/test_chem_regrid_impl.py | 15 +- .../test_chem_regrid/test_chem_regrid_rrfs.py | 2 +- 
.../test_app/test_chem_regrid/test_dataset.py | 12 + 35 files changed, 1734 insertions(+), 1435 deletions(-) create mode 100644 src/regrid_wrapper/app/chem_regrid/chem_regrid_context.py delete mode 100644 src/regrid_wrapper/app/chem_regrid/context.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/__init__.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/config/__init__.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/config/datasets.yml create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/config/model.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/__init__.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/base.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/ecoregion.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/fengsha_2d.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/fengsha_2d_time.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/fmc.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/goes.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/gra2pes.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/narr.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/nemo_anthro.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/nemo_rwc.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/ngfs.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/pecm.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/context/rave.py create mode 100644 src/regrid_wrapper/app/chem_regrid/dataset/src_field.py create mode 100644 src/test/test_app/test_chem_regrid/test_dataset.py diff --git a/.gitignore b/.gitignore index db38738..38a5057 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ __pycache__/ *.DS_Store 
src/regrid_wrapper.egg-info/ + +/build/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 84eb12f..5c79f5b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,6 +43,5 @@ repos: exclude: | (?x)^( - \.venv/| - src/regrid_wrapper/app/chem_regrid/chem_regrid_impl\.py + \.venv ) diff --git a/README.md b/README.md index 2c45eaa..a366c90 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,39 @@ The `rw` command provides a set of command-line tools for regridding. +## "Chem Regrid" Application + +Install via `pip install .` or access via `src/regrid_wrapper/app/rw_cli.py`. + +``` +usage: rw chem-regrid [-h] [--yaml-path YAML_PATH] [--root-key ROOT_KEY] [--overrides OVERRIDES [OVERRIDES ...]] + +options: + -h, --help show this help message and exit + --yaml-path YAML_PATH + If provided, path to YAML file containing the configuration's root key + --root-key ROOT_KEY If provided, use this key when extracting the root configuration + --overrides OVERRIDES [OVERRIDES ...] + If provided, override arbitrary key+values (e.g. --override key1:nest=val1 key2=val2) +``` + +Example: + +```shell +python ${rw_dir}/src/regrid_wrapper/app/rw_cli.py chem-regrid \ + --overrides workdir=${cr_workdir} \ + input_dir=${cr_input_dir} \ + output_dir=${cr_output_dir} \ + weight_dir=${cr_weight_dir} \ + input_mesh_path=${cr_input_mesh_path} \ + dst_path=${cr_dst_path} \ + cycle=${cr_cycle} \ + mesh_name=${cr_mesh_name} \ + ebb_dcycle=1 \ + dataset_name=RAVE \ + fcst_length=6 +``` + ## MPAS to UGRID Conversion 1. `conda env create -f environment-uxarray.yaml` @@ -44,6 +77,65 @@ cd /opt/project && \ ... or mpi tests: -``` +```bash mpirun -n 8 pytest -m mpi src/test ``` + +# Adding a New Dataset + +To add a new dataset to the regridding pipeline, follow these steps: + +1. **Update `DatasetName` Enum**: Add the new dataset key to the `DatasetName` enum in `src/regrid_wrapper/app/chem_regrid/dataset/config/model.py`. +2. 
**Add Configuration**: Add a new entry to `src/regrid_wrapper/app/chem_regrid/dataset/config/datasets.yml` following the schema described in the "Dataset Configuration" section below. +3. **Create Regrid Context Subclass**: In `src/regrid_wrapper/app/chem_regrid/dataset/context/`, create a new module (e.g., `my_dataset.py`) and a subclass of `AbstractDatasetRegridContext` (e.g., `MY_DATASET_DatasetRegridContext`). + * Implement `iter_file_pairs` to define how source and destination files are paired. + * Override methods as needed for dataset-specific logic. +4. **Register the Subclass**: Import the new context class and return it from `get_regrid_context_class` in `src/regrid_wrapper/app/chem_regrid/dataset/context/__init__.py`. +5. **Add Test**: Add a new test case for the dataset in `src/test/test_app/test_chem_regrid/conftest.py`. + +## Dataset Configuration + +Datasets are configured in `src/regrid_wrapper/app/chem_regrid/dataset/config/datasets.yml`. Each entry defines how a specific dataset should be read and regridded. + +### Dataset Schema + +| Field | Description | +|---|---| +| `field_names` | List of variable names to be regridded from the source file. | +| `x_center` | Variable name for longitude centers. | +| `y_center` | Variable name for latitude centers. | +| `x_dim` | Dimension name for the X (longitude) axis. | +| `y_dim` | Dimension name for the Y (latitude) axis. | +| `x_corner` | (Optional) Variable name for longitude corners. Set to `null` if not available. | +| `y_corner` | (Optional) Variable name for latitude corners. Set to `null` if not available. | +| `x_corner_dim` | (Optional) Dimension name for longitude corners. | +| `y_corner_dim` | (Optional) Dimension name for latitude corners. | +| `level_in_name` | (Optional) Name of the vertical level dimension in the source file. | +| `level_out_name` | Name of the vertical level dimension in the output file. | +| `level_out_size` | Number of vertical levels in the output. Set to `0` for 2D data. 
| +| `time_name` | (Optional) Name of the time dimension in the source file. | +| `time_size` | Number of time steps. Set to `0` if time dimension is not used. | +| `InterpMethod` | ESMF interpolation method (e.g., `CONSERVE`, `BILINEAR`, `NEAREST_STOD`). | + +### Example Entry + +```yaml +MY_DATASET: + field_names: + - PM25 + - SO2 + x_center: lon + y_center: lat + x_dim: x + y_dim: y + x_corner: null + y_corner: null + x_corner_dim: null + y_corner_dim: null + level_in_name: null + level_out_name: nkanthro + level_out_size: 1 + time_name: time + time_size: 1 + InterpMethod: CONSERVE +``` diff --git a/src/regrid_wrapper/app/chem_regrid/__init__.py b/src/regrid_wrapper/app/chem_regrid/__init__.py index e69de29..f810d1a 100644 --- a/src/regrid_wrapper/app/chem_regrid/__init__.py +++ b/src/regrid_wrapper/app/chem_regrid/__init__.py @@ -0,0 +1,3 @@ +from regrid_wrapper.context.logging import LOGGER + +CR_LOGGER = LOGGER.getChild("mpas-regrid") diff --git a/src/regrid_wrapper/app/chem_regrid/chem_regrid_cli.py b/src/regrid_wrapper/app/chem_regrid/chem_regrid_cli.py index 8751317..49d8ab7 100644 --- a/src/regrid_wrapper/app/chem_regrid/chem_regrid_cli.py +++ b/src/regrid_wrapper/app/chem_regrid/chem_regrid_cli.py @@ -3,8 +3,8 @@ import yaml +from regrid_wrapper.app.chem_regrid.chem_regrid_context import ChemRegridContext from regrid_wrapper.app.chem_regrid.chem_regrid_impl import main -from regrid_wrapper.app.chem_regrid.context import ChemRegridContext from regrid_wrapper.app.override import apply_overrides diff --git a/src/regrid_wrapper/app/chem_regrid/chem_regrid_context.py b/src/regrid_wrapper/app/chem_regrid/chem_regrid_context.py new file mode 100644 index 0000000..6c29977 --- /dev/null +++ b/src/regrid_wrapper/app/chem_regrid/chem_regrid_context.py @@ -0,0 +1,50 @@ +from functools import cached_property +from pathlib import Path + +from pydantic import Field + +from regrid_wrapper.app.chem_regrid.dataset.config.model import ChemRegridDataset +from 
regrid_wrapper.app.chem_regrid.dataset.context import DatasetName +from regrid_wrapper.common import RwBaseModel + + +class ChemRegridContext(RwBaseModel): + dataset_name: DatasetName + workdir: Path + input_dir: Path + output_dir: Path + weight_dir: Path + cycle: str = Field(pattern=r"^\d{10}$") # Validates YYYYMMDDHH format + mesh_name: str + input_mesh_path: Path | None + dst_path: Path | None + ebb_dcycle: int + fcst_length: int + datasets_yml_path: Path = Path(__file__).parent / "dataset" / "config" / "datasets.yml" + + @cached_property + def rw_input_mesh_path(self) -> Path: + if self.input_mesh_path is None: + return self.workdir / f"mpas_{self.dataset_name.value}-{self.mesh_name}_scrip.nc" + return self.input_mesh_path + + @cached_property + def rw_dst_path(self) -> Path: + if self.dst_path is None: + return self.workdir / "init.nc" + return self.dst_path + + @cached_property + def rw_desc_stats_out(self) -> Path: + return self.workdir / f"desc_stats-{self.cycle}.csv" + + @cached_property + def rw_dataset(self) -> ChemRegridDataset: + return ChemRegridDataset.from_key(self.datasets_yml_path, self.dataset_name) + + @cached_property + def rw_weight_path(self) -> Path: + weight_path = self.weight_dir / ( + "weights_" + self.dataset_name.value + "-to-" + "mpas_" + self.mesh_name + "_" + self.rw_dataset.InterpMethod + ".nc" + ) + return weight_path diff --git a/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py b/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py index fd099f1..14ce7b1 100755 --- a/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py +++ b/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py @@ -1,71 +1,36 @@ # mypy: ignore-errors -import sys import glob -from abc import abstractmethod, ABC -from datetime import datetime, timezone, timedelta -from functools import cached_property +from datetime import datetime, timezone from pathlib import Path -from typing import Literal, Iterable, Any, Union - -import os +from typing import 
Iterable, Literal import esmpy import numpy as np -import xarray as xr import pandas as pd from pydantic import BaseModel -from regrid_wrapper.app.chem_regrid.context import ChemRegridContext +from regrid_wrapper.app.chem_regrid import CR_LOGGER +from regrid_wrapper.app.chem_regrid.chem_regrid_context import ChemRegridContext +from regrid_wrapper.app.chem_regrid.dataset.context import get_regrid_context_class +from regrid_wrapper.app.chem_regrid.dataset.context.base import ( + AbstractDatasetRegridContext, + InterpMethod, +) +from regrid_wrapper.app.chem_regrid.dataset.src_field import SrcField from regrid_wrapper.context.comm import COMM, reconcile_bounds -from regrid_wrapper.context.logging import LOGGER from regrid_wrapper.esmpy.field_wrapper import ( - GridSpec, - NcToGrid, - NcToField, FieldWrapper, + GridSpec, GridWrapper, + NcToField, + NcToGrid, + copy_nc_variable, open_nc, - Dimension, - DimensionCollection, set_variable_data, - HasNcAttrsType, - copy_nc_variable, ) -_LOGGER = LOGGER.getChild("mpas-regrid") - -# Try to find the latest RAVE file available up to max_lookback_hours before target_time_str -# to avoid setting zeroes when a particular hour file is missing. 
-def find_latest_rave_file(input_dir, target_time_str, ebb_dcycle, dataset_name, max_lookback_hours=24): - """Return list of files for the latest time <= target_time_str.""" - fmt = "%Y%m%d%H" #RAVE - fmt2= "%Y%j%H" # GOES - target_time = datetime.strptime(target_time_str, fmt) - - input_dir_str = str(input_dir) - for h in range(max_lookback_hours + 1): - if ebb_dcycle == -1 or ebb_dcycle == 2: - this_time = target_time - timedelta(hours=h) - elif ebb_dcycle == 1: - this_time = target_time + timedelta(hours=h) - else: - _LOGGER.warning("unrecognized ebb_dcycle, reverting to same-day, ebb_dcycle = 1") - this_time = target_time + timedelta(hours=h) - - if dataset_name == "RAVE": - this_str = this_time.strftime(fmt) - paths = glob.glob(input_dir_str + "/RAVE-HrlyEmiss-3km_v2r0_blend_s"+this_str+"*") - elif dataset_name == "GOES": - this_str = this_time.strftime(fmt2) - paths = glob.glob(input_dir_str + "/OR_ABI-L2-AODC-M6_G18_s"+this_str+"*") - if paths: - if h > 0: - print(f"Missing {dataset_name} file for {target_time_str}, using {this_str} instead") - return paths - # nothing found within lookback window - return [] # def create_ngfs_sparse_mesh(lat_1d, lon_1d, resolution=0.01): """ @@ -84,13 +49,9 @@ def create_ngfs_sparse_mesh(lat_1d, lon_1d, resolution=0.01): num_nodes = num_cells * 4 d = resolution / 2.0 - node_lons = np.column_stack([ - lon_1d - d, lon_1d + d, lon_1d + d, lon_1d - d - ]).flatten() + node_lons = np.column_stack([lon_1d - d, lon_1d + d, lon_1d + d, lon_1d - d]).flatten() - node_lats = np.column_stack([ - lat_1d - d, lat_1d - d, lat_1d + d, lat_1d + d - ]).flatten() + node_lats = np.column_stack([lat_1d - d, lat_1d - d, lat_1d + d, lat_1d + d]).flatten() node_coords = np.empty(num_nodes * 2, dtype=np.float64) node_coords[0::2] = node_lons @@ -108,208 +69,11 @@ def create_ngfs_sparse_mesh(lat_1d, lon_1d, resolution=0.01): # Explicitly set spherical coordinates mesh = esmpy.Mesh(parametric_dim=2, spatial_dim=2, coord_sys=esmpy.CoordSys.SPH_DEG) - 
mesh.add_nodes( - node_count=num_nodes, - node_ids=node_ids, - node_coords=node_coords, - node_owners=node_owners - ) + mesh.add_nodes(node_count=num_nodes, node_ids=node_ids, node_coords=node_coords, node_owners=node_owners) - mesh.add_elements( - element_count=num_cells, - element_ids=element_ids, - element_types=element_types, - element_conn=element_conn - ) + mesh.add_elements(element_count=num_cells, element_ids=element_ids, element_types=element_types, element_conn=element_conn) return mesh -# -class AbstractRaveField(ABC, BaseModel): - name: str - attrs: dict[str, Any] - fill_value: float - dtype: Any - num_cells: int - level_out_name: str - level_out_size: int - time_size: int - - @cached_property - def time_dimension(self) -> Dimension: - return Dimension( - name=("Time",), - size=self.time_size, - lower=0, - upper=self.time_size, - staggerloc=esmpy.StaggerLoc.CENTER, - coordinate_type="time", - ) - - @cached_property - def nklevel_dimension(self) -> Dimension: - return Dimension( - name=(self.level_out_name,), - size=self.level_out_size, - lower=0, - upper=self.level_out_size, - staggerloc=esmpy.StaggerLoc.CENTER, - coordinate_type="level", - ) - - def create_ncells_dimension(self, bounds: tuple[int, int]) -> Dimension: - return Dimension( - name=("nCells",), - size=self.num_cells, # 225636, #130333, # tdk: pull from origin, - lower=bounds[0], - upper=bounds[1], - staggerloc=esmpy.MeshLoc.ELEMENT, - coordinate_type="cell", - ) - - @abstractmethod - def create_dimension_collection( - self, ncells_bounds: tuple[int, int] - ) -> DimensionCollection: - ... - - @abstractmethod - def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - ... 
- - -class RaveField2d(AbstractRaveField): - - def create_dimension_collection( - self, ncells_bounds: tuple[int, int] - ) -> DimensionCollection: - return DimensionCollection( - value=(self.create_ncells_dimension(ncells_bounds),) - ) - - def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - return target.reshape(-1) - - -class RaveField2d_plusTime(AbstractRaveField): - - def create_dimension_collection( - self, ncells_bounds: tuple[int, int] - ) -> DimensionCollection: - return DimensionCollection( - value=(self.time_dimension, self.create_ncells_dimension(ncells_bounds)) - ) - - def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - return target.reshape(self.time_size, -1) - -class RaveField3d(AbstractRaveField): - - def create_dimension_collection( - self, ncells_bounds: tuple[int, int] - ) -> DimensionCollection: - return DimensionCollection( - value=( - self.create_ncells_dimension(ncells_bounds), - self.nklevel_dimension, - ) - ) - - def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - return target.reshape(-1, self.level_out_size) - -class RaveField3d_plusTime(AbstractRaveField): - - def create_dimension_collection( - self, ncells_bounds: tuple[int, int] - ) -> DimensionCollection: - return DimensionCollection( - value=( - self.create_ncells_dimension(ncells_bounds), - self.nklevel_dimension, - self.time_dimension, - ) - ) - def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - return target.reshape(-1, self.level_out_size, self.time_size) - -class RaveToMpasRegridContext(BaseModel): - dataset_name: str - workdir: Path - src_path: Path - dst_path: Path - new_dst_path: Path - desc_stats_out: Path - weight_path: Path - InterpMethod: str - scrip_path: Path - num_cells: int - mesh_name: str - field_names: tuple - x_center: str - y_center: str - x_dim: str - y_dim: str - x_corner: Union[str, None] - y_corner: Union[str, None] - x_corner_dim: Union[str, None] - y_corner_dim: Union[str, None] - level_in_name: str - 
# level_in_size: int - level_out_name: str - level_out_size: int - time_name: str - time_size: int - # InterpMask: float - - rank: int = COMM.rank - - @cached_property - def rave_fields(self) -> tuple[AbstractRaveField, ...]: - rave_fields = [] - with open_nc(self.src_path, mode="r") as ds: - for field_name in self.field_names: - read_name = field_name - if self.dataset_name == "NGFS" and field_name == "PM25": - read_name = "EMIS_PM25" - - if read_name not in ds.variables: - raise KeyError( - f"Source variable '{read_name}' not found for field '{field_name}' in {self.src_path}" - ) - var = ds.variables[read_name] - init_data = { - "name": field_name, - "attrs": self._get_nc_attrs_(var), - "fill_value": -1.0, - "dtype": var.dtype, - "level_out_name": self.level_out_name, - "level_out_size": self.level_out_size, - "time_size": self.time_size, - "num_cells": self.num_cells, - } - if self.level_out_size == 0: - if self.time_size == 0: - app = RaveField2d.model_validate(init_data) - else: - app = RaveField2d_plusTime.model_validate(init_data) - else: - if self.time_size == 0: - app = RaveField3d.model_validate(init_data) - else: - app = RaveField3d_plusTime.model_validate(init_data) - rave_fields.append(app) - _LOGGER.debug(f"{rave_fields=}") - return tuple(rave_fields) - - @staticmethod - def _get_nc_attrs_(src: HasNcAttrsType) -> dict[str, Any]: - # tdk: does valid_range matter? - exclude = ("coordinates", "valid_range") - return { - ii: getattr(src, ii) - for ii in src.ncattrs() - if not ii.startswith("_") and ii not in exclude - } class FileDesc(BaseModel): @@ -318,10 +82,10 @@ class FileDesc(BaseModel): field_names: tuple[str, ...] 
-class RaveToMpasRegridProcessor: +class ChemRegridProcessor: _dst_mesh: esmpy.Mesh | None = None - def __init__(self, context: RaveToMpasRegridContext) -> None: + def __init__(self, context: AbstractDatasetRegridContext) -> None: self.context = context self._regridder: esmpy.Regrid | None = None @@ -329,156 +93,152 @@ def __init__(self, context: RaveToMpasRegridContext) -> None: self._src_gwrap: GridWrapper | None = None def initialize(self) -> None: - _LOGGER.info(f"initialize: {self.context=}") + CR_LOGGER.info(f"initialize: {self.context=}") esmpy.Manager(debug=True) - # if not self.context.scrip_path.exists() and self.context.rank == 0: - # _LOGGER.info("writing mpas scrip grid") - # from pyremap import MpasCellMeshDescriptor - # - # mpas_desc = MpasCellMeshDescriptor( - # str(self.context.dst_path), self.context.mesh_name + ".init" - # ) - # mpas_desc.to_scrip(str(self.context.scrip_path)) + pathsrc = self.context.get_src_grid_path() -# JLS - temporary fix for coords not in file - if self.context.dataset_name == "GOES": - pathsrc=self.context.workdir / "goes19_abi_conus_interpolated_lat_lon.nc" - else: - pathsrc=self.context.src_path - _LOGGER.info("create source grid") - if self.context.x_corner_dim is None: - self._src_gwrap = NcToGrid( - path=pathsrc, - spec=GridSpec( - x_center=self.context.x_center, - y_center=self.context.y_center, - x_dim=(self.context.x_dim,), - y_dim=(self.context.y_dim,), - x_corner=self.context.x_corner, - y_corner=self.context.y_corner, - x_corner_dim=self.context.x_corner_dim, - y_corner_dim=self.context.y_corner_dim, - ), - ).create_grid_wrapper() - else: - self._src_gwrap = NcToGrid( - path=pathsrc, - spec=GridSpec( - x_center=self.context.x_center, - y_center=self.context.y_center, - x_dim=(self.context.x_dim,), - y_dim=(self.context.y_dim,), - x_corner=self.context.x_corner, - y_corner=self.context.y_corner, - x_corner_dim=(self.context.x_corner_dim,), - y_corner_dim=(self.context.y_corner_dim,), - ), - 
).create_grid_wrapper() + CR_LOGGER.info("create source grid") + self._src_gwrap = NcToGrid( + path=pathsrc, + spec=GridSpec.model_validate(self.context.model_dump()), + ).create_grid_wrapper() - _LOGGER.info("create source field") - src_fwrap = self.create_src_field_wrapper(self.context.rave_fields[0].name) + CR_LOGGER.info("create source field") + src_fwrap = self.create_src_field_wrapper(self.context.src_fields[0].name) if self._dst_mesh is None: - _LOGGER.info("create destination mesh") + CR_LOGGER.info("create destination mesh") # dst_mesh = esmpy.Mesh( - # filename=str(self.context.scrip_path), filetype=esmpy.FileFormat.SCRIP + # filename=str(self.context.input_mesh_path), filetype=esmpy.FileFormat.SCRIP # ) self._dst_mesh = esmpy.Mesh( - filename=str(self.context.scrip_path), filetype=esmpy.FileFormat.UGRID, meshname="grid_topology" + filename=str(self.context.input_mesh_path), filetype=esmpy.FileFormat.UGRID, meshname="grid_topology" ) dst_mesh = self._dst_mesh -# Check for extra dims beyond lat/lon - _LOGGER.info("create destination field") - if self.context.level_out_size == 0: - #2D - if self.context.time_size == 0: - # 2D, static in Time - self._dst_field = esmpy.Field( - dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, - ) - else: - # 2D + Time - self._dst_field = esmpy.Field( - dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.time_size,) - ) - else: - #3D - if self.context.time_size == 0: - # 3D, static in Time - self._dst_field = esmpy.Field( - dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.level_out_size,) - ) - else: - # 3D + Time - self._dst_field = esmpy.Field( - dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.level_out_size, self.context.time_size) - ) -# Check for weights - _LOGGER.info("create regridder") + self._dst_field = self._create_dst_field_(dst_mesh) + self._regridder = self._create_regridder_(src_fwrap) + + def _create_dst_field_(self, dst_mesh: 
esmpy.Mesh) -> esmpy.Field: + CR_LOGGER.info("create destination field") + + # Check for extra dims beyond lat/lon + ndbounds = [] + if self.context.level_out_size > 0: + ndbounds.append(self.context.level_out_size) + if self.context.time_size > 0: + ndbounds.append(self.context.time_size) + + kwargs = {} + if ndbounds: + kwargs["ndbounds"] = tuple(ndbounds) + + return esmpy.Field(dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, **kwargs) + + def _create_regridder_(self, src_fwrap: FieldWrapper) -> esmpy.RegridFromFile | esmpy.Regrid: + CR_LOGGER.info("create regridder") if self.context.weight_path.exists(): - _LOGGER.info("create regridder from file") - self._regridder = esmpy.RegridFromFile( + CR_LOGGER.info("create regridder from file") + regridder = esmpy.RegridFromFile( srcfield=src_fwrap.value, dstfield=self._dst_field, filename=str(self.context.weight_path), ) else: - _LOGGER.info("create regridder in-memory") - if self.context.InterpMethod == "CONSERVE": - _LOGGER.info("using 1st order conservative interp") - self._regridder = esmpy.Regrid( - srcfield=src_fwrap.value, - dstfield=self._dst_field, - regrid_method=esmpy.RegridMethod.CONSERVE, - unmapped_action=esmpy.UnmappedAction.IGNORE, - ignore_degenerate=True, - large_file=True, - filename=str(self.context.weight_path), - ) - elif self.context.InterpMethod == "CONSERVE_2ND": - _LOGGER.info("using 2nd order conservative interp") - self._regridder = esmpy.Regrid( - srcfield=src_fwrap.value, - dstfield=self._dst_field, - regrid_method=esmpy.RegridMethod.CONSERVE_2ND, - unmapped_action=esmpy.UnmappedAction.IGNORE, - ignore_degenerate=True, - large_file=True, - filename=str(self.context.weight_path), - ) - elif self.context.InterpMethod == "BILINEAR": - _LOGGER.info("using bilinear interp") - self._regridder = esmpy.Regrid( - srcfield=src_fwrap.value, - dstfield=self._dst_field, - regrid_method=esmpy.RegridMethod.BILINEAR, - unmapped_action=esmpy.UnmappedAction.IGNORE, - ignore_degenerate=True, - 
large_file=True, - filename=str(self.context.weight_path), - ) - else: - _LOGGER.info("using nearest_STOD interp") - self._regridder = esmpy.Regrid( - srcfield=src_fwrap.value, - dstfield=self._dst_field, - regrid_method=esmpy.RegridMethod.NEAREST_STOD, - unmapped_action=esmpy.UnmappedAction.IGNORE, - ignore_degenerate=True, - large_file=True, - filename=str(self.context.weight_path), - ) + CR_LOGGER.info("create regridder in-memory") + method_map = { + InterpMethod.CONSERVE: esmpy.RegridMethod.CONSERVE, + InterpMethod.CONSERVE_2ND: esmpy.RegridMethod.CONSERVE_2ND, + InterpMethod.BILINEAR: esmpy.RegridMethod.BILINEAR, + InterpMethod.NEAREST_STOD: esmpy.RegridMethod.NEAREST_STOD, + } + # NOTE: an InterpMethod missing from method_map raises KeyError here; unlike the previous if/elif chain there is no NEAREST_STOD fallback + regrid_method = method_map[self.context.InterpMethod] + + CR_LOGGER.info(f"using {regrid_method} interp") + regridder = esmpy.Regrid( + srcfield=src_fwrap.value, + dstfield=self._dst_field, + regrid_method=regrid_method, + unmapped_action=esmpy.UnmappedAction.IGNORE, + ignore_degenerate=True, + large_file=True, + filename=str(self.context.weight_path), + ) + return regridder def run(self) -> None: - _LOGGER.info("apply regridding") + CR_LOGGER.info("apply regridding") + + CR_LOGGER.info("create output file") + self.create_output_file() + + for src_field in self.context.src_fields: + self._regrid_src_field(src_field) + + if self.context.write_desc_stats and self.context.rank == 0: + field_names = tuple(ii.name for ii in self.context.src_fields) + targets = [ + FileDesc( + path=self.context.new_dst_path, + origin="dst", + field_names=field_names, + ), + FileDesc( + path=self.context.src_path, + origin="src", + field_names=field_names, + ), + ] + data_frame = self.create_desc_stuff(targets) + data_frame.to_csv(self.context.desc_stats_out, index=False) - _LOGGER.info("create output file") + def _regrid_src_field(self, src_field: SrcField) -> None: + CR_LOGGER.info(f"regridding {src_field.name=}") + regridder = 
self.get_regridder() + src_fwrap = self.create_src_field_wrapper(field_name=src_field.name) + + dst_field = self.get_dst_field() + dst_field.data.fill(0.0) + regridder(src_fwrap.value, dst_field) + + local_bounds = (dst_field.lower_bounds[0], dst_field.upper_bounds[0]) + reconciled_bounds = reconcile_bounds(local_bounds) + dims = src_field.create_dimension_collection(reconciled_bounds) + CR_LOGGER.debug(f"{dims=}") + CR_LOGGER.info("writing field to netcdf") + with open_nc(self.context.new_dst_path, mode="a") as ds: + transformed_data = self.context.transform_regridded_data(src_field, dst_field.data, ds, reconciled_bounds, dims) + + CR_LOGGER.info(f"creating variable {src_field.name=}") + var = ds.createVariable( + src_field.name, + src_field.dtype, + [dim.name[0] for dim in dims.value], + fill_value=src_field.fill_value, + ) + for k, v in src_field.attrs.items(): + setattr(var, k, v) + + CR_LOGGER.info(f"setting variable data {src_field.name=}") + set_variable_data( + var, + dims, + src_field.reshape_field_data(transformed_data), + collective=True, + ) + CR_LOGGER.info(f"finished writing field to netcdf {src_field.name=}") + src_fwrap.value.destroy() + del src_fwrap + + self.context.post_regrid_processing(src_field, regridder, self, dims) + + def create_output_file(self): if self.context.rank == 0: with open_nc(self.context.new_dst_path, mode="w", clobber=True, parallel=False) as dst_nc: dst_nc.createDimension("nCells", self.context.num_cells) - if self.context.level_out_name != "None": + if self.context.level_out_name is not None: dst_nc.createDimension(self.context.level_out_name, self.context.level_out_size) dst_nc.createDimension("StrLen", 64) if self.context.time_size > 1: @@ -486,174 +246,27 @@ def run(self) -> None: elif self.context.time_size == 1: if "Time" not in dst_nc.dimensions: dst_nc.createDimension("Time") - _LOGGER.info("Not creating a time dimension") + else: + CR_LOGGER.debug("Not creating a time dimension") dst_nc.setncattr("created_at", 
str(datetime.now(timezone.utc))) dst_nc.setncattr("src_path", str(self.context.src_path)) dst_nc.setncattr("dst_path", str(self.context.dst_path)) with open_nc(self.context.dst_path, mode="r", parallel=False) as src_nc: - if self.context.dataset_name in ("RAVE"): - for varname in ("latCell", "lonCell", "areaCell", "xtime"): - copy_nc_variable(src_nc, dst_nc, varname, copy_data=True) - elif self.context.dataset_name in ("FENGSHA_2D"): - for varname in ("latCell", "lonCell"): - copy_nc_variable(src_nc, dst_nc, varname, copy_data=True) - else: - for varname in ("latCell", "lonCell", "xtime"): - copy_nc_variable(src_nc, dst_nc, varname, copy_data=True) - - regridder = self.get_regridder() - for rave_field in self.context.rave_fields: - _LOGGER.info(f"regridding {rave_field.name=}") - src_fwrap = self.create_src_field_wrapper(field_name=rave_field.name) - - dst_field = self.get_dst_field() - # tdk: any more qa stuff? minimum threshold? - dst_field.data.fill(0.0) - regridder(src_fwrap.value, dst_field) - # tdk: support NcToMesh - local_bounds = (dst_field.lower_bounds[0], dst_field.upper_bounds[0]) - reconciled_bounds = reconcile_bounds(local_bounds) - dims = rave_field.create_dimension_collection(reconciled_bounds) - _LOGGER.info(f"{dims=}") - _LOGGER.info(f"writing field to netcdf") - with open_nc(self.context.new_dst_path, mode="a") as ds: - if self.context.dataset_name == "RAVE" and rave_field.name in ("FRP_MEAN", "FRE"): - area = np.asarray(ds.variables['areaCell']) - area_subset = area[reconciled_bounds[0]:reconciled_bounds[1]].reshape(dims.shape_local) - _LOGGER.info(f"creating variable {rave_field.name=}") - var = ds.createVariable( - rave_field.name, - rave_field.dtype, - [dim.name[0] for dim in dims.value], - fill_value=rave_field.fill_value, - ) - # Don't carry over fill value and datatype - if self.context.dataset_name != 'GOES': - type_to_use = rave_field.dtype - for k, v in rave_field.attrs.items(): - setattr(var, k, v) - else: - type_to_use = np.float32 - 
- _LOGGER.info(f"setting variable data {rave_field.name=}") - # Multiply FRE/FRP by output area so it is back to W or J*s - if self.context.dataset_name == "RAVE" and rave_field.name in ("FRP_MEAN", "FRE"): - set_variable_data( - var, - dims, - rave_field.reshape_field_data(dst_field.data * area_subset), - collective=True, - ) - else: - set_variable_data( - var, - dims, - rave_field.reshape_field_data(dst_field.data), - collective=True, - ) - _LOGGER.info(f"finished writing field to netcdf {rave_field.name=}") - src_fwrap.value.destroy() - del src_fwrap - - if rave_field.name == "ENL_POLL": - with open_nc(self.context.new_dst_path, mode="a") as ds: - _LOGGER.info(f"renaming and combining tree fields") - - src_fwrap_enl = self.create_src_field_wrapper(field_name='ENL_POLL') - dst_field_enl = self.get_dst_field() - dst_field_enl.data.fill(0.0) - regridder(src_fwrap_enl.value, dst_field_enl) - - src_fwrap_dbl = self.create_src_field_wrapper(field_name='DBL_POLL') - dst_field_dbl = self.get_dst_field() - dst_field_dbl.data.fill(0.0) - regridder(src_fwrap_dbl.value, dst_field_dbl) - - rave_field = self.context.rave_fields[0] - - var = ds.createVariable( - 'TREE_POLL', - rave_field.dtype, - [dim.name[0] for dim in dims.value], - fill_value=rave_field.fill_value, - ) - for k, v in self.context.rave_fields[0].attrs.items(): - setattr(var, k, v) - set_variable_data( - var, - dims, - rave_field.reshape_field_data(dst_field_enl.data + dst_field_dbl.data), - collective=True, - ) - src_fwrap_enl.value.destroy() - del src_fwrap_enl - src_fwrap_dbl.value.destroy() - del src_fwrap_dbl - if rave_field.name == "TPM": - with open_nc(self.context.new_dst_path, mode="a") as ds: - _LOGGER.info(f"calculating PM10 as TPM - PM25") - src_fwrap_ttl = self.create_src_field_wrapper(field_name='TPM') - src_fwrap_p25 = self.create_src_field_wrapper(field_name='PM25') - - dst_field_ttl = self.get_dst_field() - dst_field_ttl.data.fill(0.0) - regridder(src_fwrap_ttl.value, dst_field_ttl) - - 
dst_field_p25 = self.get_dst_field() - dst_field_p25.data.fill(0.0) - regridder(src_fwrap_p25.value, dst_field_p25) - - rave_field = self.context.rave_fields[0] - - var = ds.createVariable( - 'PM10', - rave_field.dtype, - [dim.name[0] for dim in dims.value], - fill_value=rave_field.fill_value, - ) - for k, v in self.context.rave_fields[0].attrs.items(): - setattr(var, k, v) - data1 = rave_field.reshape_field_data(dst_field_ttl.data) - data2 = rave_field.reshape_field_data(dst_field_p25.data) - data3 = data1 - data2 - set_variable_data( - var, - dims, - data3, - collective=True, - ) - src_fwrap_ttl.value.destroy() - del src_fwrap_ttl - src_fwrap_p25.value.destroy() - del src_fwrap_p25 - - # if self.context.rank == 0: - # field_names = tuple(ii.name for ii in self.context.rave_fields) - # targets = [ - # FileDesc( - # path=self.context.new_dst_path, - # origin="dst", - # field_names=field_names, - # ), - # FileDesc( - # path=self.context.src_path, - # origin="src", - # field_names=field_names, - # ), - # ] - # data_frame = self.create_desc_stuff(targets) - # data_frame.to_csv(self.context.desc_stats_out, index=False) + for varname in self.context.var_names_to_copy_to_output_file: + copy_nc_variable(src_nc, dst_nc, varname, copy_data=True) def finalize(self) -> None: - _LOGGER.info("finalizing") + CR_LOGGER.info("finalizing") self._regridder.destroy() self._dst_field.destroy() self._src_gwrap.value.destroy() + # TODO: There could be an option to destroy the destination mesh when finalizing. However, + # it is more efficient to leave it since the destination is not variable at this point. 
# self._dst_mesh.destroy() def create_desc_stuff(self, targets: Iterable[FileDesc]) -> pd.DataFrame: - _LOGGER.info("entering create_desc_stuff") + CR_LOGGER.info("entering create_desc_stuff") if self.context.rank > 0: raise ValueError @@ -675,116 +288,33 @@ def create_desc_stuff(self, targets: Iterable[FileDesc]) -> pd.DataFrame: desc = pd.concat( [ desc, - pd.DataFrame( - data=adds, index=["sum", "count_null", "origin", "path"] - ), + pd.DataFrame(data=adds, index=["sum", "count_null", "origin", "path"]), ] ) to_concat.append(desc) ret = pd.concat([ii.transpose() for ii in to_concat]) ret.index.name = "field_name" ret.reset_index(inplace=True) - _LOGGER.info("exiting create_desc_stuff") + CR_LOGGER.info("exiting create_desc_stuff") return ret def create_src_field_wrapper(self, field_name: str) -> FieldWrapper: - _LOGGER.info("create source field") - if self.context.dataset_name == "GRA2PES" and field_name in ("h_agl",): # Special case for staggered grid - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=(self.context.time_name,), - dim_level=('bottom_top_stag',), - ).create_field_wrapper() - elif self.context.level_in_name == "None": - if self.context.time_name == "None": - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=None, - dim_level=None, - ).create_field_wrapper() - else: - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=(self.context.time_name,), - dim_level=None, - ).create_field_wrapper() - else: - if self.context.time_name == "None": - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=None, - dim_level=(self.context.level_in_name,), - ).create_field_wrapper() - else: - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - 
dim_time=(self.context.time_name,), - dim_level=(self.context.level_in_name,), - ).create_field_wrapper() - - # Get the area from the RAVE file, need to convert from /grid to /m2 - if (self.context.dataset_name == "RAVE" and field_name in ("PM25", "NH3", "SO2", "FRE", "FRP_MEAN", "TPM", "CH4", "CO", "NOx")): - area_fwrap = NcToField( - path=self.context.src_path, - name='area', - gwrap=self.get_src_gwrap(), - dim_time=None, - ).create_field_wrapper() - area_data = area_fwrap.value.data - - # GRA2PES PM, convert from metric tons/km2/hr to ug/m2/s - if self.context.dataset_name == "GRA2PES" and field_name in ("PM25-PRI", "PM10-PRI"): - conv_aer = 1.e6 / 3600. - # GRA2PES methane, convert from moles/km2/hr to ug/m2/s - elif self.context.dataset_name == "GRA2PES" and field_name in ("HC01", "SO2", "CO", "NH3", "NOX"): - conv_aer = 1.e-6 / 3600. - # RAVE methane, convert from kg/hr to mol/m2/s - elif self.context.dataset_name == "RAVE": - if field_name == "CH4": - conv_aer = (1.0 / 16.0) * 1000. - elif field_name == "CO": - conv_aer = (1.0 / 28.0) * 1000. - elif field_name == "NH3": - conv_aer = (1.0 / 17.0) * 1000. - elif field_name == "NOx": - conv_aer = ( (1.0 / 30.0) + (1.0 / 46.0) ) / 2. * 1000. - else: - conv_aer = 1.0 - elif self.context.dataset_name == "NEMO_RWC" and field_name in ("PEC","POC","PMOTHR","PMC"): - # Convert g/s/km2 (on 1km grid) to ug/m2/s --> - conv_aer = 1.0 - elif self.context.dataset_name == "NEMO_ANTHRO" and field_name in ("PEC","POC","PMOTHR","PMC"): - # Convert g/s/km2 to ug/m2/s --> - conv_aer = 1.0 - else: - conv_aer = 1.0 - - src_data = src_fwrap.value.data - if self.context.dataset_name == "RAVE" and field_name in ("PM25", "TPM"): - # If RAVE aerosol emissions, convert from kg/hr to ug/m2/s - src_data[:] = np.where(src_data < 0.0, 0.0, src_data * 1.e3 / area_data[:, :, np.newaxis] / 3600.) 
- elif self.context.dataset_name == "RAVE" and field_name in ("CH4", "NH3", "SO2", "CO", "NOx"): - # If RAVE gas emissions, convert from kg/hr to mol/m2/s - src_data[:] = np.where(src_data < 0.0, 0.0, conv_aer * src_data / area_data[:, :, np.newaxis] / 3600.) - elif self.context.dataset_name == "RAVE" and field_name in ("FRE", "FRP_MEAN"): - # For FRE, FRP, don't multiply area by 1.e6, cancelled out by MW to W conversion - src_data[:] = np.where(src_data < 0.0, 0.0, src_data / (area_data[:, :, np.newaxis])) - else: - src_data[:] = np.where(src_data < 0.0, 0.0, conv_aer * src_data) - - src_data[:] = np.where(np.isnan(src_data), 0.0, src_data) + CR_LOGGER.info("create source field") + src_fwrap = self._create_raw_src_field_wrapper_(field_name) + self.context.update_src_field_wrapper(src_fwrap) return src_fwrap + def _create_raw_src_field_wrapper_(self, field_name: str) -> FieldWrapper: + dim_level, dim_time = self.context.get_src_field_dims(field_name) + + return NcToField( + path=self.context.src_path, + name=field_name, + gwrap=self.get_src_gwrap(), + dim_time=dim_time, + dim_level=dim_level, + ).create_field_wrapper() + def get_src_gwrap(self) -> GridWrapper: if self._src_gwrap is None: raise ValueError @@ -802,46 +332,44 @@ def get_regridder(self) -> esmpy.Regrid: def init_destination_only(self) -> None: """Loads the heavy MPAS destination mesh once for dynamic NGFS processing.""" - _LOGGER.info("Initializing MPAS Destination Mesh (Once)") + CR_LOGGER.info("Initializing MPAS Destination Mesh (Once)") esmpy.Manager(debug=True) - # if not self.context.scrip_path.exists() and self.context.rank == 0: - # _LOGGER.info("writing mpas scrip grid") + # if not self.context.input_mesh_path.exists() and self.context.rank == 0: + # CR_LOGGER.info("writing mpas scrip grid") # mpas_desc = MpasCellMeshDescriptor( # str(self.context.dst_path), self.context.mesh_name + ".init" # ) - # mpas_desc.to_scrip(str(self.context.scrip_path)) + # 
mpas_desc.to_scrip(str(self.context.input_mesh_path)) - _LOGGER.info("create destination mesh") - dst_mesh = esmpy.Mesh( - filename=str(self.context.scrip_path), filetype=esmpy.FileFormat.UGRID, - meshname="grid_topology" - ) + CR_LOGGER.info("create destination mesh") + dst_mesh = esmpy.Mesh(filename=str(self.context.input_mesh_path), filetype=esmpy.FileFormat.UGRID, meshname="grid_topology") # Create destination field (using logic from your original initialize method) + ndbounds = None if self.context.level_out_size > 1 and self.context.time_size > 1: - self._dst_field = esmpy.Field(dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.level_out_size, self.context.time_size)) + ndbounds = (self.context.level_out_size, self.context.time_size) elif self.context.level_out_size > 1 and self.context.time_size == 1: - self._dst_field = esmpy.Field(dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.level_out_size,)) + ndbounds = (self.context.level_out_size,) elif self.context.level_out_size == 1 and self.context.time_size > 1: - self._dst_field = esmpy.Field(dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.time_size,)) - else: - self._dst_field = esmpy.Field(dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT) + ndbounds = (self.context.time_size,) + + self._dst_field = esmpy.Field(dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=ndbounds) def process_ngfs_file(self, file_path: Path, resolution: float = 0.01) -> None: """Dynamically builds a mesh for NGFS points, regrids, and writes the output.""" - _LOGGER.info(f"Processing NGFS file: {file_path}") + CR_LOGGER.info(f"Processing NGFS file: {file_path}") # 1. 
Read NGFS Coordinates AND Area with open_nc(file_path, mode="r") as ds: - lats = ds.variables['lat'][:].filled(np.nan) - lons = ds.variables['lon'][:].filled(np.nan) + lats = ds.variables["lat"][:].filled(np.nan) + lons = ds.variables["lon"][:].filled(np.nan) # Read the NGFS area (in km2) - if 'GRID_AREA' in ds.variables: - grid_area = ds.variables['GRID_AREA'][:].filled(np.nan) + if "GRID_AREA" in ds.variables: + grid_area = ds.variables["GRID_AREA"][:].filled(np.nan) else: - _LOGGER.warning("GRID_AREA not found! Defaulting to 1.0 km2.") + CR_LOGGER.warning("GRID_AREA not found! Defaulting to 1.0 km2.") grid_area = np.ones_like(lats) # Filter out NaNs @@ -854,7 +382,7 @@ def process_ngfs_file(self, file_path: Path, resolution: float = 0.01) -> None: lons = lons % 360.0 if len(lats) == 0: - _LOGGER.warning("No valid fires in file.") + CR_LOGGER.warning("No valid fires in file.") return # 2. Build Sparse Source Mesh @@ -881,35 +409,34 @@ def process_ngfs_file(self, file_path: Path, resolution: float = 0.01) -> None: for varname in ("latCell", "lonCell", "areaCell", "xland", "xtime"): copy_nc_variable(src_nc, dst_nc, varname, copy_data=True) - # 4. 
Process Each Variable - for rave_field in self.context.rave_fields: - _LOGGER.info(f"regridding NGFS {rave_field.name=}") + for src_field in self.context.src_fields: + CR_LOGGER.info(f"regridding NGFS {src_field.name=}") # Create Source Field dynamically - src_field = esmpy.Field(src_mesh, name=rave_field.name, meshloc=esmpy.MeshLoc.ELEMENT) + src_field = esmpy.Field(src_mesh, name=src_field.name, meshloc=esmpy.MeshLoc.ELEMENT) # Map MPAS expected name to NGFS actual name - if rave_field.name == "PM25": + if src_field.name == "PM25": ngfs_var_name = "EMIS_PM25" else: - ngfs_var_name = rave_field.name + ngfs_var_name = src_field.name # Load the raw data with open_nc(file_path, mode="r") as ds: if ngfs_var_name in ds.variables: raw_data = ds.variables[ngfs_var_name][:].filled(0.0)[valid] else: - _LOGGER.warning(f"Variable {ngfs_var_name} not found! Skipping.") + CR_LOGGER.warning(f"Variable {ngfs_var_name} not found! Skipping.") continue # --------------------------------------------------------- # UNIT CONVERSIONS (Identical to RAVE logic) # --------------------------------------------------------- - if rave_field.name in ("PM25", "TPM"): + if src_field.name in ("PM25", "TPM"): # Convert from kg/hr to ug/m2/s (1e3 handles the km2 to m2 and kg to ug ratio) - src_data = np.where(raw_data < 0.0, 0.0, raw_data * 1.e3 / grid_area / 3600.0) - elif rave_field.name in ("FRE", "FRP_MEAN"): + src_data = np.where(raw_data < 0.0, 0.0, raw_data * 1.0e3 / grid_area / 3600.0) + elif src_field.name in ("FRE", "FRP_MEAN"): # For FRE, FRP: MW to W (1e6) cancels out with km2 to m2 (1e6) src_data = np.where(raw_data < 0.0, 0.0, raw_data / grid_area) else: @@ -922,7 +449,7 @@ def process_ngfs_file(self, file_path: Path, resolution: float = 0.01) -> None: srcfield=src_field, dstfield=self._dst_field, regrid_method=esmpy.RegridMethod.CONSERVE, - unmapped_action=esmpy.UnmappedAction.IGNORE + unmapped_action=esmpy.UnmappedAction.IGNORE, ) # Apply Regridding @@ -932,25 +459,25 @@ def 
process_ngfs_file(self, file_path: Path, resolution: float = 0.01) -> None: # Write to Output NetCDF local_bounds = (self._dst_field.lower_bounds[0], self._dst_field.upper_bounds[0]) reconciled_bounds = reconcile_bounds(local_bounds) - dims = rave_field.create_dimension_collection(reconciled_bounds) + dims = src_field.create_dimension_collection(reconciled_bounds) with open_nc(self.context.new_dst_path, mode="a") as ds: var = ds.createVariable( - rave_field.name, # Keep it as standard name in output! - rave_field.dtype, + src_field.name, # Keep it as standard name in output! + src_field.dtype, [dim.name[0] for dim in dims.value], - fill_value=rave_field.fill_value, + fill_value=src_field.fill_value, ) - for k, v in rave_field.attrs.items(): + for k, v in src_field.attrs.items(): setattr(var, k, v) # Multiply by areaCell for Power/Energy variables (back to total W in cell) - if rave_field.name in ("FRP_MEAN", "FRE"): - area = np.asarray(ds.variables['areaCell']) - area_subset = area[reconciled_bounds[0]:reconciled_bounds[1]] - set_variable_data(var, dims, rave_field.reshape_field_data(self._dst_field.data * area_subset), collective=True) + if src_field.name in ("FRP_MEAN", "FRE"): + area = np.asarray(ds.variables["areaCell"]) + area_subset = area[reconciled_bounds[0] : reconciled_bounds[1]] + set_variable_data(var, dims, src_field.reshape_field_data(self._dst_field.data * area_subset), collective=True) else: - set_variable_data(var, dims, rave_field.reshape_field_data(self._dst_field.data), collective=True) + set_variable_data(var, dims, src_field.reshape_field_data(self._dst_field.data), collective=True) # Clean up memory regridder.destroy() @@ -960,379 +487,74 @@ def process_ngfs_file(self, file_path: Path, resolution: float = 0.01) -> None: src_mesh.destroy() +def run_regridding(ctx: AbstractDatasetRegridContext) -> None: + processor = None + for file_pair in ctx.iter_file_pairs(): + # --- OPTIMIZATION START --- + if processor is None: + CR_LOGGER.info("FIRST 
PASS: Full Initialization") + # This pays the "expensive" cost of loading weights/grids, but only once. + ctx.src_path = file_pair.src_path + ctx.new_dst_path = file_pair.dst_path + + processor = ChemRegridProcessor(context=ctx) + processor.initialize() + else: + CR_LOGGER.info("SUBSEQUENT PASSES: Hot Swap") + # Just update the paths in the existing context. + # The grids and regridder (weights) remain loaded in memory. + processor.context.src_path = file_pair.src_path + processor.context.new_dst_path = file_pair.dst_path + # Run the regridding (Fast) + processor.run() + # --- OPTIMIZATION END --- + # Only finalize after ALL files are done + if processor: + processor.finalize() + CR_LOGGER.info("success") + + def main(ctx: ChemRegridContext) -> None: - dataset_name = ctx.dataset_name.value # Which dataset are we interpolating? - workdir = ctx.workdir # Directory where operations will be processed - input_dir = ctx.input_dir # Top directory of input data - output_dir = ctx.output_dir # Top directory of output data - cycle = ctx.cycle - scrip_path = ctx.rw_scrip_path # Cycle Time, YYYYMMDDHH - dst_path = ctx.rw_dst_path - mesh_name = ctx.mesh_name - ebb_dcycle = ctx.ebb_dcycle - - desc_stats_out = ctx.rw_desc_stats_out - # - YYYY = cycle[0:4] - MM = cycle[4:6] - DD = cycle[6:8] - HH = cycle[8:10] - x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) - JJJ = x.strftime("%j") - DOWh = int(x.strftime("%u")) - if DOWh <= 5: - DOWs = "weekdy" - elif DOWh == 6: - DOWs = "satdy" - else: - DOWs = "sundy" - - # Calculate the number of cells in the - with open_nc(dst_path, mode="r", parallel=False) as src_nc: - foo = src_nc.variables['latCell'] - num_cells = len(foo) - # xland = src_nc.variables['xland'] - # lmask[:] = np.where(xland > 0,1,0) - - if dataset_name == "RAVE": - field_names = ("TPM", "FRE", "FRP_MEAN", "PM25", "NH3", "SO2", "CH4","CO","NOx") - # JLS, TODO - NEED TO ACCOUNT FOR EBB1, MORE THAN 24, ETC. 
- # Determine the cycle dates to process +%Y%m%d%H - dates_needed = [] - for i in range(25): - if ebb_dcycle == 1: # Same-day emissions - x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) + timedelta(hours=i) - elif ebb_dcycle == -1 or ebb_dcycle == 2: # Persistence (-1) or forecasted (2) needs prev 24 hours - x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) - timedelta(hours=i) - else: - _LOGGER.info("EBB_DCYLE selection not recognized, reverting to same day, ebb_dcycle = 1") - x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) + timedelta(hours=i) - - y = x.strftime("%Y%m%d%H") - dates_needed.append(y) - # - x_center = "grid_lont" - y_center = "grid_latt" - x_dim = "grid_xt" - y_dim = "grid_yt" - x_corner = "grid_lon" - y_corner = "grid_lat" - x_corner_dim = "grid_x" - y_corner_dim = "grid_y" - level_in_name = "None" - # level_in_size = None - level_out_name = "nkwildfire" - level_out_size = 1 - time_name = "time" - time_size = 1 - InterpMethod = "CONSERVE" - elif dataset_name == "NGFS": - field_names = ("FRE", "FRP_MEAN", "PM25") - - # Determine the cycle dates to process +%Y%m%d%H - # This is for RETROS (using current datetime, not day before) - dates_needed = [] - for i in range(25): # GAF retro current day emissions - if ebb_dcycle == 1: # Same-day emissions - x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) + timedelta(hours=i) - elif ebb_dcycle == -1 or ebb_dcycle == 2: # Persistence (-1) or forecasted (2) needs prev 24 hours - x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) - timedelta(hours=i) - else: - _LOGGER.info("EBB_DCYLE selection not recognized, reverting to same day, ebb_dcycle = 1") - x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) + timedelta(hours=i) - y = x.strftime("%Y%m%d%H") - dates_needed.append(y) - # - x_center = "lon" - y_center = "lat" - x_dim = "point" # Dummy dimension name for context - y_dim = "point" # Dummy dimension name for context - # We set corners to None because the helper 
calculates them in memory - x_corner = None - y_corner = None - x_corner_dim = None - y_corner_dim = None - level_in_name = "None" - level_out_name = "nkwildfire" - level_out_size = 1 - time_name = "time" - time_size = 1 - InterpMethod = "CONSERVE" - elif dataset_name == "GRA2PES": - field_names = ("PM25-PRI", "PM10-PRI","SO2","CO","NOX","NH3","h_agl") # ,"HC01"=methane BAQMS, summer, 2025 - x_center = "XLONG" # "XLONG_M" - y_center = "XLAT" # "XLAT_M" - x_dim = "west_east" - y_dim = "south_north" - x_corner = "XLONG_C" - y_corner = "XLAT_C" - x_corner_dim = "west_east_stag" - y_corner_dim = "south_north_stag" - level_in_name = "bottom_top" - level_out_name = "nkanthro" - level_out_size = 20 - time_name = "Time" - time_size = 12 - InterpMethod = "CONSERVE" - # InterpMethod = "BILINEAR" - elif dataset_name == "NEMO_ANTHRO": - field_names = ("POC", "PEC", "PMOTHR", "PMC") - x_center = "lon" - y_center = "lat" - x_dim = "COL" - y_dim = "ROW" - x_corner = "lonc" - y_corner = "latc" - x_corner_dim = "COLC" - y_corner_dim = "ROWC" - level_in_name = "LAY" - level_out_name = "nkanthro" - level_out_size = 1 - time_name = "TSTEP" - time_size = 1 - InterpMethod = "CONSERVE" -# InterpMethod = "BILINEAR" - elif dataset_name == "NEMO_RWC": - field_names = ("POC", "PEC", "PMOTHR", "PMC") - x_center = "lon" - y_center = "lat" - x_dim = "COL" - y_dim = "ROW" - x_corner = "lonc" - y_corner = "latc" - x_corner_dim = "COLC" - y_corner_dim = "ROWC" - level_in_name = "None" - level_out_name = "None" - level_out_size = 0 - time_name = "Time" - time_size = 1 - InterpMethod = "CONSERVE" -# InterpMethod = "BILINEAR" - elif dataset_name == "PECM": - field_names = ("DBL_POLL", "ENL_POLL", "GRA_POLL", "RAG_POLL") - x_center = "lon" - y_center = "lat" - x_dim = "COL" - y_dim = "ROW" - x_corner = "lonc" - y_corner = "latc" - x_corner_dim = "COLC" - y_corner_dim = "ROWC" - level_in_name = "None" - level_out_name = "nkbiogenic" - level_out_size = 1 - time_name = "time" - time_size = 1 - 
InterpMethod = "CONSERVE" - elif dataset_name == "ECOREGION": - field_names = ("ecoregion_ID",) - x_center = "geolon" - y_center = "geolat" - x_dim = "lon" - y_dim = "lat" - x_corner = None - y_corner = None - x_corner_dim = None - y_corner_dim = None - level_in_name = "None" - level_out_name = "nkwildfire" - level_out_size = 1 - time_name = "time" - time_size = 1 - InterpMethod = "NEAREST_STOD" - elif dataset_name == "NARR": - field_names = ("RWC_denominator",) - x_center = "lon" - y_center = "lat" - x_dim = "x" - y_dim = "y" - x_corner = None - y_corner = None - x_corner_dim = None - y_corner_dim = None - level_in_name = "None" - level_out_name = "None" - level_out_size = 0 - time_name = "Time" - time_size = 1 - InterpMethod = "BILINEAR" - elif dataset_name == "FENGSHA_2D": - field_names = ("clayfrac", "sandfrac", "uthres", "ssm") - x_center = "longitude" - y_center = "latitude" - x_dim = "lon" - y_dim = "lat" - x_corner = None - y_corner = None - x_corner_dim = None - y_corner_dim = None - level_in_name = "None" - level_out_name = "None" - level_out_size = 0 - time_name = "None" - time_size = 0 - InterpMethod = "BILINEAR" - elif dataset_name == "FENGSHA_2D_Time": - field_names = ("rdrag",) - x_center = "longitude" - y_center = "latitude" - x_dim = "lon" - y_dim = "lat" - x_corner = None - y_corner = None - x_corner_dim = None - y_corner_dim = None - level_in_name = "None" - level_out_name = "None" - level_out_size = 0 - time_name = "time" - time_size = 12 - InterpMethod = "BILINEAR" - elif dataset_name == "FMC": # fuel moisture content - field_names = ("10h_dead_fuel_moisture_content",) - dates_needed = [] - for i in range(25): - x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) - timedelta(hours=i) - y = x.strftime("%Y%m%d%H") - dates_needed.append(y) - x_center = "longitude" - y_center = "latitude" - x_dim = "nx" - y_dim = "ny" - x_corner = None - y_corner = None - x_corner_dim = None - y_corner_dim = None - level_in_name = "None" - level_out_name = 
"nkwildfire" - level_out_size = 1 - time_name = "time" - time_size = 1 - InterpMethod = "BILINEAR" - elif dataset_name == "GOES": - field_names = ("AOD",) - x_center = "longitude" - y_center = "latitude" - x_dim = "x" - y_dim = "y" - x_corner = None - y_corner = None - x_corner_dim = None - y_corner_dim = None - level_in_name = "None" - level_out_name = "None" - level_out_size = 0 - time_name = "None" - time_size = 0 - InterpMethod = "BILINEAR" - dates_needed = [] - for i in range(25): - if ebb_dcycle == 1: # Same-day emissions - x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) + timedelta(hours=i) - elif ebb_dcycle == -1 or ebb_dcycle == 2: # Persistence (-1) or forecasted (2) needs prev 24 hours - x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) - timedelta(hours=i) - else: - _LOGGER.info("EBB_DCYLE selection not recognized, reverting to same day, ebb_dcycle = 1") - x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) - timedelta(hours=i) - y = x.strftime("%Y%m%d%H") - dates_needed.append(y) - - weight_path = ctx.get_weight_path(InterpMethod) - - if dataset_name == "RAVE": - processor = None - for date_to_process in dates_needed: - _LOGGER.info(f"RAVE processing {date_to_process=}") - rave_paths = find_latest_rave_file(input_dir, date_to_process, ebb_dcycle, dataset_name, max_lookback_hours=24) - if not rave_paths: - _LOGGER.warn( - f"No matching files found for {date_to_process} (even after lookback).") - continue - _LOGGER.info(f'Reading RAVE file: {rave_paths=}') - rave_path = rave_paths[0] - new_dst_path = output_dir / (mesh_name + "-RAVE-" + date_to_process + ".nc") - - # --- OPTIMIZATION START --- - if processor is None: - _LOGGER.info("FIRST PASS: Full Initialization") - # This pays the "expensive" cost of loading weights/grids, but only once. 
- - context = RaveToMpasRegridContext( - dataset_name=dataset_name, - workdir=workdir, - src_path=rave_path, - dst_path=dst_path, - new_dst_path=new_dst_path, - desc_stats_out=desc_stats_out, - weight_path=weight_path, - InterpMethod=InterpMethod, - scrip_path=scrip_path, - num_cells=num_cells, - mesh_name=mesh_name, - field_names=field_names, - x_center=x_center, - y_center=y_center, - x_dim=x_dim, - y_dim=y_dim, - x_corner=x_corner, - y_corner=y_corner, - x_corner_dim=x_corner_dim, - y_corner_dim=y_corner_dim, - level_in_name=level_in_name, - # level_in_size=level_in_size, - level_out_name=level_out_name, - level_out_size=level_out_size, - time_name=time_name, - time_size=time_size - - ) - processor = RaveToMpasRegridProcessor(context=context) - processor.initialize() - else: - _LOGGER.info("SUBSEQUENT PASSES: Hot Swap") - # Just update the paths in the existing context. - # The grids and regridder (weights) remain loaded in memory. - processor.context.src_path = rave_path - processor.context.new_dst_path = new_dst_path - # Run the regridding (Fast) - processor.run() - # --- OPTIMIZATION END --- - # Only finalize after ALL files are done - if processor: - processor.finalize() - - _LOGGER.info("success") - - elif dataset_name == "NGFS": - # Initialize context with dummy paths (they get overwritten in the loop) - context = RaveToMpasRegridContext( - dataset_name=dataset_name, - workdir=workdir, - src_path=Path("dummy"), - dst_path=dst_path, - new_dst_path=Path("dummy"), - desc_stats_out=desc_stats_out, - weight_path=weight_path, - InterpMethod=InterpMethod, - scrip_path=scrip_path, - num_cells=num_cells, - mesh_name=mesh_name, - field_names=field_names, - x_center=x_center, y_center=y_center, x_dim=x_dim, y_dim=y_dim, - x_corner=x_corner, y_corner=y_corner, - x_corner_dim=x_corner_dim, y_corner_dim=y_corner_dim, - level_in_name=level_in_name, level_out_name=level_out_name, level_out_size=level_out_size, - time_name=time_name, time_size=time_size - ) - - processor = 
RaveToMpasRegridProcessor(context=context) - - for date_to_process in dates_needed: + klass = get_regrid_context_class(ctx.dataset_name) + regrid_context = klass( + dataset_name=ctx.dataset_name, + workdir=ctx.workdir, + src_path=Path("dummy"), + dst_path=ctx.rw_dst_path, + new_dst_path=Path("dummy"), + desc_stats_out=ctx.rw_desc_stats_out, + weight_path=ctx.rw_weight_path, + InterpMethod=ctx.rw_dataset.InterpMethod, + input_mesh_path=ctx.rw_input_mesh_path, + mesh_name=ctx.mesh_name, + field_names=ctx.rw_dataset.field_names, + x_center=ctx.rw_dataset.x_center, + y_center=ctx.rw_dataset.y_center, + x_dim=ctx.rw_dataset.x_dim, + y_dim=ctx.rw_dataset.y_dim, + x_corner=ctx.rw_dataset.x_corner, + y_corner=ctx.rw_dataset.y_corner, + x_corner_dim=ctx.rw_dataset.x_corner_dim, + y_corner_dim=ctx.rw_dataset.y_corner_dim, + level_in_name=ctx.rw_dataset.level_in_name, + level_out_name=ctx.rw_dataset.level_out_name, + level_out_size=ctx.rw_dataset.level_out_size, + time_name=ctx.rw_dataset.time_name, + time_size=ctx.rw_dataset.time_size, + cycle=ctx.cycle, + ebb_dcycle=ctx.ebb_dcycle, + input_dir=ctx.input_dir, + output_dir=ctx.output_dir, + ) + + if ctx.dataset_name == "NGFS": + processor = ChemRegridProcessor(context=regrid_context) + + for date_to_process in regrid_context.dates_needed: # Construct the filename (Adjust the prefix 'ngfs_' if your files are named differently) # print("GAF debug: attempting to read: " + input_dir + "/NGFS_v0.31_" + date_to_process + "_0p01.nc") - ngfs_paths = glob.glob(str(input_dir) + "/NGFS_v0.31_0p01_" + date_to_process + "0000.nc") + ngfs_paths = glob.glob(str(ctx.input_dir) + "/NGFS_v0.31_0p01_" + date_to_process + "0000.nc") if not ngfs_paths: print(f"ERROR: Missing NGFS file for {date_to_process}. 
Skipping.") @@ -1342,7 +564,7 @@ def main(ctx: ChemRegridContext) -> None: continue ngfs_path = Path(ngfs_paths[0]) - new_dst_path = Path(str(output_dir) + "/" + mesh_name + "-NGFS-" + date_to_process + ".nc") + new_dst_path = Path(str(ctx.output_dir) + "/" + ctx.mesh_name + "-NGFS-" + date_to_process + ".nc") print(f"GAF reading NGFS file: {ngfs_path}") # Update context paths for the current hour @@ -1353,256 +575,6 @@ def main(ctx: ChemRegridContext) -> None: # Note that resolution is hard coded... processor.process_ngfs_file(ngfs_path, resolution=0.01) - _LOGGER.info("NGFS success") - - elif dataset_name == "GOES": - processor = None - date_to_process = dates_needed[0] - rave_paths = find_latest_rave_file(input_dir, date_to_process, -1, dataset_name, max_lookback_hours=2) - files_to_cat = rave_paths - _LOGGER.info(f"will cat files: {files_to_cat=}") - if COMM.rank == 0: - with xr.open_mfdataset(files_to_cat, combine='nested', concat_dim='file') as ds: - # 2. Calculate the nanmean across the new 'file' dimension - # skipna=True (default) ensures it behaves like np.nanmean - ds_averaged = ds['AOD'].mean(dim='file', skipna=True) - # _LOGGER.debug(ds_averaged) - ds_averaged.encoding.update({ - 'dtype': 'float32', - '_FillValue': -999 - }) - ds_averaged.to_netcdf(output_dir / 'test_goes_aod_merged.nc') - - if not rave_paths: - msg = f"No matching GOES files found for {date_to_process} (even after lookback)." - _LOGGER.error(msg) - raise ValueError(msg) - - _LOGGER.info('Reading merged GOES file: test_goes_aod_merged.nc') - #rave_path = rave_paths[0] - rave_path = output_dir / "test_goes_aod_merged.nc" - new_dst_path = output_dir / (mesh_name + "-GOES-" + date_to_process + ".nc") - # --- OPTIMIZATION START --- - if processor is None: - # FIRST PASS: Full Initialization - # This pays the "expensive" cost of loading weights/grids, but only once. 
- - context = RaveToMpasRegridContext( - dataset_name=dataset_name, - workdir=workdir, - src_path=rave_path, - dst_path=dst_path, - new_dst_path=new_dst_path, - desc_stats_out=desc_stats_out, - weight_path=weight_path, - InterpMethod=InterpMethod, - scrip_path=scrip_path, - num_cells=num_cells, - mesh_name=mesh_name, - field_names=field_names, - x_center=x_center, - y_center=y_center, - x_dim=x_dim, - y_dim=y_dim, - x_corner=x_corner, - y_corner=y_corner, - x_corner_dim=x_corner_dim, - y_corner_dim=y_corner_dim, - level_in_name=level_in_name, - level_out_name=level_out_name, - level_out_size=level_out_size, - time_name=time_name, - time_size=time_size - - ) - processor = RaveToMpasRegridProcessor(context=context) - processor.initialize() - else: - # SUBSEQUENT PASSES: Hot Swap - # Just update the paths in the existing context. - # The grids and regridder (weights) remain loaded in memory. - processor.context.src_path = rave_path - processor.context.new_dst_path = new_dst_path - # Run the regridding (Fast) - processor.run() - # --- OPTIMIZATION END --- - # Only finalize after ALL files are done - if processor: - processor.finalize() - - _LOGGER.info("success") - - elif dataset_name == "FMC": - for date_to_process in dates_needed: - rave_paths = glob.glob(str(input_dir / ("fmc_" + date_to_process + ".nc"))) - rave_path = Path(rave_paths[0]) - new_dst_path = output_dir / ("fmc_" + date_to_process + "_" + mesh_name + ".nc") - - context = RaveToMpasRegridContext( - dataset_name=dataset_name, - workdir=workdir, - src_path=rave_path, - dst_path=dst_path, - new_dst_path=new_dst_path, - desc_stats_out=desc_stats_out, - weight_path=weight_path, - InterpMethod=InterpMethod, - scrip_path=scrip_path, - num_cells=num_cells, - mesh_name=mesh_name, - field_names=field_names, - x_center=x_center, - y_center=y_center, - x_dim=x_dim, - y_dim=y_dim, - x_corner=x_corner, - y_corner=y_corner, - x_corner_dim=x_corner_dim, - y_corner_dim=y_corner_dim, - level_in_name=level_in_name, - 
level_out_name=level_out_name, - level_out_size=level_out_size, - time_name=time_name, - time_size=time_size - - ) - processor = RaveToMpasRegridProcessor(context=context) - processor.initialize() - processor.run() - processor.finalize() - - _LOGGER.info("success") -# - elif dataset_name == "GRA2PES": - rave_path = input_dir / ("GRA2PESv1.0_total_2021" + MM + "_" + DOWs + "_00to11Z.nc") - new_dst_path = output_dir / (dataset_name + "v1.0_total_" + mesh_name + "_00to11Z.nc") - context = RaveToMpasRegridContext( - dataset_name=dataset_name, - workdir=workdir, - src_path=rave_path, - dst_path=dst_path, - new_dst_path=new_dst_path, - desc_stats_out=desc_stats_out, - weight_path=weight_path, - InterpMethod=InterpMethod, - scrip_path=scrip_path, - num_cells=num_cells, - mesh_name=mesh_name, - field_names=field_names, - x_center=x_center, - y_center=y_center, - x_dim=x_dim, - y_dim=y_dim, - x_corner=x_corner, - y_corner=y_corner, - x_corner_dim=x_corner_dim, - y_corner_dim=y_corner_dim, - level_in_name=level_in_name, - level_out_name=level_out_name, - level_out_size=level_out_size, - time_name=time_name, - time_size=time_size - - ) - processor = RaveToMpasRegridProcessor(context=context) - processor.initialize() - processor.run() - processor.finalize() - - _LOGGER.info("success") - - rave_path = input_dir / ("GRA2PESv1.0_total_2021" + MM + "_" + DOWs + "_12to23Z.nc") - new_dst_path = output_dir / (dataset_name + "v1.0_total_" + mesh_name + "_12to23Z.nc") - context = RaveToMpasRegridContext( - dataset_name=dataset_name, - workdir=workdir, - src_path=rave_path, - dst_path=dst_path, - new_dst_path=new_dst_path, - desc_stats_out=desc_stats_out, - weight_path=weight_path, - InterpMethod=InterpMethod, - scrip_path=scrip_path, - num_cells=num_cells, - mesh_name=mesh_name, - field_names=field_names, - x_center=x_center, - y_center=y_center, - x_dim=x_dim, - y_dim=y_dim, - x_corner=x_corner, - y_corner=y_corner, - x_corner_dim=x_corner_dim, - y_corner_dim=y_corner_dim, - 
level_in_name=level_in_name, - level_out_name=level_out_name, - level_out_size=level_out_size, - time_name=time_name, - time_size=time_size - - ) - processor = RaveToMpasRegridProcessor(context=context) - processor.initialize() - processor.run() - processor.finalize() - - _LOGGER.info("success") - + CR_LOGGER.info("NGFS success") else: - if dataset_name == "PECM": - rave_path = input_dir / ("pollen_obs_" + YYYY + "_BELD6_ef_T_" + JJJ + ".nc") - new_dst_path = output_dir / ("pollen_ef_" + mesh_name + "_" + YYYY + "_" + JJJ + ".nc") - elif dataset_name == "NEMO_RWC": - rave_path = input_dir / "NEMO_RWC_POC_PEC_PMOTHR.annual.2017.nc" - new_dst_path = output_dir / ("NEMO_RWC_ANNUAL_TOTAL_" + mesh_name + ".nc") - elif dataset_name == "NEMO_ANTHRO": - rave_path = input_dir / ("NEMO_ANTHRO_" + mesh_name + "_" + YYYY + MM + DD + HH + "_SECTORSUM.nc") - new_dst_path = output_dir / ("NEMO_ANTHRO_" + mesh_name + ".nc") - elif dataset_name == "NARR": - rave_path = input_dir / "rwc_emission_denominator.2017.nc" - new_dst_path = output_dir / ("NEMO_RWC_DENOMINATOR_2017_" + mesh_name + ".nc") - elif dataset_name == "ECOREGION": - rave_path = input_dir / "veg_map.nc" - new_dst_path = output_dir / ("ecoregions_" + mesh_name + "_mpas.nc") - elif dataset_name == "FENGSHA_2D": - rave_path = input_dir / "FENGSHA_RRFS_NA_3km_2026_2D.nc" - new_dst_path = output_dir / ("fengsha_dust_inputs.2D."+ mesh_name + ".nc") - elif dataset_name == "FENGSHA_2D_Time": - rave_path = input_dir / "FENGSHA_RRFS_NA_3km_2026_2D_Time.nc" - new_dst_path = output_dir / ("fengsha_dust_inputs.2D_Time."+ mesh_name + ".nc") - - context = RaveToMpasRegridContext( - dataset_name=dataset_name, - workdir=workdir, - src_path=rave_path, - dst_path=dst_path, - new_dst_path=new_dst_path, - desc_stats_out=desc_stats_out, - weight_path=weight_path, - InterpMethod=InterpMethod, - scrip_path=scrip_path, - num_cells=num_cells, - mesh_name=mesh_name, - field_names=field_names, - x_center=x_center, - y_center=y_center, - 
x_dim=x_dim, - y_dim=y_dim, - x_corner=x_corner, - y_corner=y_corner, - x_corner_dim=x_corner_dim, - y_corner_dim=y_corner_dim, - level_in_name=level_in_name, - # level_in_size=level_in_size, - level_out_name=level_out_name, - level_out_size=level_out_size, - time_name=time_name, - time_size=time_size - - ) - processor = RaveToMpasRegridProcessor(context=context) - processor.initialize() - processor.run() - processor.finalize() - - _LOGGER.info("success") + run_regridding(regrid_context) diff --git a/src/regrid_wrapper/app/chem_regrid/chem_regrid_rrfs.py b/src/regrid_wrapper/app/chem_regrid/chem_regrid_rrfs.py index 233c37e..d09ffc9 100644 --- a/src/regrid_wrapper/app/chem_regrid/chem_regrid_rrfs.py +++ b/src/regrid_wrapper/app/chem_regrid/chem_regrid_rrfs.py @@ -3,8 +3,8 @@ from pydantic_settings import BaseSettings +from regrid_wrapper.app.chem_regrid.chem_regrid_context import ChemRegridContext from regrid_wrapper.app.chem_regrid.chem_regrid_impl import main as chem_regrid_impl_main -from regrid_wrapper.app.chem_regrid.context import ChemRegridContext from regrid_wrapper.context.logging import LOGGER @@ -26,12 +26,12 @@ def main() -> None: "ebb_dcycle": env.ebb_dcycle, "fcst_length": env.fcst_length, "mesh_name": env.mesh_name, - "scrip_path": None, + "input_mesh_path": None, "dst_path": None, } try: - data["scrip_path"] = sys.argv[7] # Path to the input SCRIP/UGRID domain grid file + data["input_mesh_path"] = sys.argv[7] # Path to the input domain grid file (UGRID) try: data["dst_path"] = sys.argv[8] # Path to the destination grid (e.g., init.nc) except IndexError: diff --git a/src/regrid_wrapper/app/chem_regrid/context.py b/src/regrid_wrapper/app/chem_regrid/context.py deleted file mode 100644 index 538068b..0000000 --- a/src/regrid_wrapper/app/chem_regrid/context.py +++ /dev/null @@ -1,76 +0,0 @@ -from abc import ABC -from enum import StrEnum, unique -from functools import cached_property -from pathlib import Path -from typing import TypeVar - -import yaml 
-from pydantic import BaseModel, Field - -T = TypeVar("T", bound="RwBaseModel") - - -class RwBaseModel(ABC, BaseModel): - model_config = {"frozen": True} - - @classmethod - def from_yaml(cls: type[T], data: dict) -> T: - return cls.model_validate(data) - - @classmethod - def from_yaml_file(cls: type[T], path: Path) -> T: - string_data = path.read_text() - yaml_data = yaml.safe_load(string_data) - return cls.from_yaml(yaml_data) - - -@unique -class DatasetName(StrEnum): - RAVE = "RAVE" - GRA2PES = "GRA2PES" - NEMO_RWC = "NEMO_RWC" - NEMO_ANTHRO = "NEMO_ANTHRO" - FMC = "FMC" - PECM = "PECM" - NARR = "NARR" - ECOREGION = "ECOREGION" - FENGSHA_2D = "FENGSHA_2D" - FENGSHA_2D_Time = "FENGSHA_2D_Time" - NGFS = "NGFS" - GOES = "GOES" - - -class ChemRegridContext(RwBaseModel): - dataset_name: DatasetName - workdir: Path - input_dir: Path - output_dir: Path - weight_dir: Path - cycle: str = Field(pattern=r"^\d{10}$") # Validates YYYYMMDDHH format - mesh_name: str - scrip_path: Path | None - dst_path: Path | None - ebb_dcycle: int - fcst_length: int - - @cached_property - def rw_scrip_path(self) -> Path: - if self.scrip_path is None: - return self.workdir / f"mpas_{self.dataset_name.value}-{self.mesh_name}_scrip.nc" - return self.scrip_path - - @cached_property - def rw_dst_path(self) -> Path: - if self.dst_path is None: - return self.workdir / "init.nc" - return self.dst_path - - @cached_property - def rw_desc_stats_out(self) -> Path: - return self.workdir / f"desc_stats-{self.cycle}.csv" - - def get_weight_path(self, interp_method: str) -> Path: - weight_path = self.weight_dir / ( - "weights_" + self.dataset_name.value + "-to-" + "mpas_" + self.mesh_name + "_" + interp_method + ".nc" - ) - return weight_path diff --git a/src/regrid_wrapper/app/chem_regrid/dataset/__init__.py b/src/regrid_wrapper/app/chem_regrid/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/regrid_wrapper/app/chem_regrid/dataset/config/__init__.py 
b/src/regrid_wrapper/app/chem_regrid/dataset/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/regrid_wrapper/app/chem_regrid/dataset/config/datasets.yml b/src/regrid_wrapper/app/chem_regrid/dataset/config/datasets.yml new file mode 100644 index 0000000..0c00d72 --- /dev/null +++ b/src/regrid_wrapper/app/chem_regrid/dataset/config/datasets.yml @@ -0,0 +1,232 @@ +RAVE: + field_names: + - TPM + - FRE + - FRP_MEAN + - PM25 + - NH3 + - SO2 + - CH4 + - CO + - NOx + x_center: grid_lont + y_center: grid_latt + x_dim: grid_xt + y_dim: grid_yt + x_corner: grid_lon + y_corner: grid_lat + x_corner_dim: grid_x + y_corner_dim: grid_y + level_in_name: null + level_out_name: nkwildfire + level_out_size: 1 + time_name: time + time_size: 1 + InterpMethod: CONSERVE +NGFS: + field_names: + - FRE + - FRP_MEAN + - PM25 + x_center: lon + y_center: lat + x_dim: point # Dummy dimension name for context + y_dim: point # Dummy dimension name for context + x_corner: null # We set corners to None because the helper calculates them in memory + y_corner: null # We set corners to None because the helper calculates them in memory + x_corner_dim: null + y_corner_dim: null + level_in_name: null + level_out_name: nkwildfire + level_out_size: 1 + time_name: time + time_size: 1 + InterpMethod: CONSERVE +GRA2PES: + field_names: + - PM25-PRI + - PM10-PRI + - SO2 + - CO + - NOX + - NH3 + - h_agl + x_center: XLONG + y_center: XLAT + x_dim: west_east + y_dim: south_north + x_corner: XLONG_C + y_corner: XLAT_C + x_corner_dim: west_east_stag + y_corner_dim: south_north_stag + level_in_name: bottom_top + level_out_name: nkanthro + level_out_size: 20 + time_name: Time + time_size: 12 + InterpMethod: CONSERVE +NEMO_ANTHRO: + field_names: + - POC + - PEC + - PMOTHR + - PMC + x_center: lon + y_center: lat + x_dim: COL + y_dim: ROW + x_corner: lonc + y_corner: latc + x_corner_dim: COLC + y_corner_dim: ROWC + level_in_name: LAY + level_out_name: nkanthro + level_out_size: 1 + 
time_name: TSTEP + time_size: 1 + InterpMethod: CONSERVE +NEMO_RWC: + field_names: + - POC + - PEC + - PMOTHR + - PMC + x_center: lon + y_center: lat + x_dim: COL + y_dim: ROW + x_corner: lonc + y_corner: latc + x_corner_dim: COLC + y_corner_dim: ROWC + level_in_name: null + level_out_name: null + level_out_size: 0 + time_name: Time + time_size: 1 + InterpMethod: CONSERVE +PECM: + field_names: + - DBL_POLL + - ENL_POLL + - GRA_POLL + - RAG_POLL + x_center: lon + y_center: lat + x_dim: COL + y_dim: ROW + x_corner: lonc + y_corner: latc + x_corner_dim: COLC + y_corner_dim: ROWC + level_in_name: null + level_out_name: nkbiogenic + level_out_size: 1 + time_name: time + time_size: 1 + InterpMethod: CONSERVE +ECOREGION: + field_names: + - ecoregion_ID + x_center: geolon + y_center: geolat + x_dim: lon + y_dim: lat + x_corner: null + y_corner: null + x_corner_dim: null + y_corner_dim: null + level_in_name: null + level_out_name: nkwildfire + level_out_size: 1 + time_name: time + time_size: 1 + InterpMethod: NEAREST_STOD +NARR: + field_names: + - RWC_denominator + x_center: lon + y_center: lat + x_dim: x + y_dim: y + x_corner: null + y_corner: null + x_corner_dim: null + y_corner_dim: null + level_in_name: null + level_out_name: null + level_out_size: 0 + time_name: Time + time_size: 1 + InterpMethod: BILINEAR +FENGSHA_2D: + field_names: + - clayfrac + - sandfrac + - uthres + - ssm + x_center: longitude + y_center: latitude + x_dim: lon + y_dim: lat + x_corner: null + y_corner: null + x_corner_dim: null + y_corner_dim: null + level_in_name: null + level_out_name: null + level_out_size: 0 + time_name: null + time_size: 0 + InterpMethod: BILINEAR +FENGSHA_2D_Time: + field_names: + - rdrag + x_center: longitude + y_center: latitude + x_dim: lon + y_dim: lat + x_corner: null + y_corner: null + x_corner_dim: null + y_corner_dim: null + level_in_name: null + level_out_name: null + level_out_size: 0 + time_name: time + time_size: 12 + InterpMethod: BILINEAR +FMC: + field_names: + 
class ChemRegridDataset(RwBaseModel):
    """Static per-dataset regridding settings, one entry of the datasets YAML mapping.

    The field names mirror the keys used for each dataset in ``datasets.yml``
    (source coordinate/dimension names, optional level and time names, and the
    interpolation method).
    """

    # Dataset identifier; injected by ``from_key`` rather than stored in the YAML entry.
    key: DatasetName
    field_names: tuple[str, ...]
    x_center: str
    y_center: str
    x_dim: str
    y_dim: str
    x_corner: str | None
    y_corner: str | None
    x_corner_dim: str | None
    y_corner_dim: str | None
    level_in_name: str | None
    level_out_name: str | None
    level_out_size: int | None
    time_name: str | None
    time_size: int | None
    InterpMethod: InterpMethod

    @classmethod
    def from_key(cls, yaml_path: Path, key: DatasetName) -> "ChemRegridDataset":
        """Build the settings stored under ``key`` in the YAML file at ``yaml_path``.

        Args:
            yaml_path: Path to the YAML file holding one mapping per dataset.
            key: Dataset whose entry should be loaded.

        Returns:
            The validated settings with ``key`` filled in.

        Raises:
            KeyError: If the YAML mapping has no entry for ``key``.
        """

        def retriever(data: dict) -> dict:
            # NOTE(review): presumably ``from_yaml_file`` calls this with the parsed
            # top-level YAML mapping -- confirm against RwBaseModel.
            try:
                entry = data[key.value]
            except KeyError:
                raise KeyError(f"Dataset '{key.value}' is not defined in {yaml_path}") from None
            # Build a new dict so the parsed YAML mapping is not mutated as a side effect.
            return {**entry, "key": key}

        return cls.from_yaml_file(yaml_path, retriever=retriever)
def get_regrid_context_class(name: DatasetName) -> type[AbstractDatasetRegridContext]:
    """Return the regrid-context class registered for the dataset ``name``.

    A ``KeyError`` propagates when ``name`` has no registered class.
    """
    registry: dict[DatasetName, type[AbstractDatasetRegridContext]] = dict(
        (
            (DatasetName.RAVE, RAVE_DatasetRegridContext),
            (DatasetName.GRA2PES, GRA2PES_DatasetRegridContext),
            (DatasetName.FMC, FMC_DatasetRegridContext),
            (DatasetName.NEMO_RWC, NEMO_RWC_DatasetRegridContext),
            (DatasetName.NEMO_ANTHRO, NEMO_ANTHRO_DatasetRegridContext),
            (DatasetName.PECM, PECM_DatasetRegridContext),
            (DatasetName.NARR, NARR_DatasetRegridContext),
            (DatasetName.ECOREGION, ECOREGION_DatasetRegridContext),
            (DatasetName.FENGSHA_2D, FENGSHA_2D_DatasetRegridContext),
            (DatasetName.FENGSHA_2D_Time, FENGSHA_2D_Time_DatasetRegridContext),
            (DatasetName.GOES, GOES_DatasetRegridContext),
            (DatasetName.NGFS, NGFS_DatasetRegridContext),
        )
    )
    return registry[name]
class DateTimeSpec(BaseModel):
    """Container for datetime components used in file path formatting."""

    yyyy: str  # characters 0-3 of the cycle string
    mm: str  # characters 4-5 of the cycle string
    dd: str  # characters 6-7 of the cycle string
    hh: str  # characters 8-9 of the cycle string
    jjj: str  # day of year, formatted with "%j"
    dowh: int  # weekday number, from "%u" (1 = Monday ... 7 = Sunday)
    dows: str  # "weekdy", "satdy" or "sundy", derived from ``dowh``
    datetime: datetime


class RegridFilePair(BaseModel):
    """Pair of source and destination paths for a regridding operation."""

    src_path: Path
    dst_path: Path


@unique
class DatasetName(StrEnum):
    """Identifiers of the datasets that can be regridded."""

    RAVE = "RAVE"
    GRA2PES = "GRA2PES"
    NEMO_RWC = "NEMO_RWC"
    NEMO_ANTHRO = "NEMO_ANTHRO"
    FMC = "FMC"
    PECM = "PECM"
    NARR = "NARR"
    ECOREGION = "ECOREGION"
    FENGSHA_2D = "FENGSHA_2D"
    FENGSHA_2D_Time = "FENGSHA_2D_Time"
    NGFS = "NGFS"
    GOES = "GOES"


@unique
class InterpMethod(StrEnum):
    """Names of the supported interpolation methods."""

    CONSERVE = "CONSERVE"
    CONSERVE_2ND = "CONSERVE_2ND"
    BILINEAR = "BILINEAR"
    NEAREST_STOD = "NEAREST_STOD"


class AbstractDatasetRegridContext(ABC, BaseModel):
    """Abstract base class for dataset-specific regridding configurations and logic.

    Subclasses must implement ``iter_file_pairs`` and may override the other
    methods (read names, field dimensions, source clean-up, search paths and
    the post-regrid hooks) to customise behaviour per dataset.
    """

    dataset_name: DatasetName
    workdir: Path
    src_path: Path
    dst_path: Path
    new_dst_path: Path
    desc_stats_out: Path
    weight_path: Path
    InterpMethod: InterpMethod
    input_mesh_path: Path
    mesh_name: str
    field_names: tuple
    x_center: str
    y_center: str
    x_dim: str
    y_dim: str
    x_corner: Union[str, None]
    y_corner: Union[str, None]
    x_corner_dim: Union[str, None]
    y_corner_dim: Union[str, None]
    level_in_name: str | None
    # level_in_size: int
    level_out_name: str | None
    level_out_size: int
    time_name: str | None
    time_size: int
    cycle: str
    ebb_dcycle: int
    input_dir: Path
    output_dir: Path
    # InterpMask: float
    write_desc_stats: bool = False
    var_names_to_copy_to_output_file: tuple[str, ...] = ("latCell", "lonCell", "xtime")
    # Format of the date strings handed to ``find_latest_src_file``.
    time_format: str = "%Y%m%d%H"
    # Format used when building search paths and lookback log messages.
    search_time_format: str = "%Y%m%d%H"

    rank: int = COMM.rank

    def get_src_grid_path(self) -> Path:
        """Returns the path to the source grid file."""
        return self.src_path

    def get_src_field_dims(self, field_name: str) -> tuple[tuple[str, ...] | None, tuple[str, ...] | None]:
        """Returns the level and time dimension names for a given field.

        Each element is a one-item tuple, or ``None`` when the corresponding
        name (``level_in_name`` / ``time_name``) is unset or empty.
        """
        dim_level = (self.level_in_name,) if self.level_in_name else None
        dim_time = (self.time_name,) if self.time_name else None
        return dim_level, dim_time

    @abstractmethod
    def iter_file_pairs(self) -> Iterator[RegridFilePair]:
        """Yields pairs of source and destination paths to be processed."""
        ...

    def update_src_field_wrapper(self, raw_src_fwrap: FieldWrapper) -> None:
        """Applies dataset-specific data cleaning or transformations to the source field.

        The default implementation replaces negative and NaN values with 0.0,
        modifying the wrapped data in place.
        """
        src_data = raw_src_fwrap.data
        src_data[:] = np.where(src_data < 0.0, 0.0, src_data)
        src_data[:] = np.where(np.isnan(src_data), 0.0, src_data)

    @cached_property
    def dates_needed(self) -> list[str]:
        """Returns a list of dates required for the current cycle.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this property.
        """
        raise NotImplementedError(self.__class__.__name__ + " does not support dates_needed")

    @cached_property
    def num_cells(self) -> int:
        """Returns the number of cells in the destination mesh (length of ``latCell``)."""
        with open_nc(self.dst_path, mode="r", parallel=False) as ds:
            return len(ds.variables["latCell"])

    @cached_property
    def dt_spec(self) -> DateTimeSpec:
        """Returns a DateTimeSpec based on the current cycle (sliced as YYYYMMDDHH)."""
        yyyy = self.cycle[0:4]
        mm = self.cycle[4:6]
        dd = self.cycle[6:8]
        hh = self.cycle[8:10]
        x = datetime(int(yyyy), int(mm), int(dd), int(hh), 0, 0)
        jjj = x.strftime("%j")
        dowh = int(x.strftime("%u"))
        if dowh <= 5:
            dows = "weekdy"
        elif dowh == 6:
            dows = "satdy"
        else:
            dows = "sundy"
        return DateTimeSpec(yyyy=yyyy, mm=mm, dd=dd, hh=hh, jjj=jjj, dowh=dowh, dows=dows, datetime=x)

    def get_read_name(self, field_name: str) -> str:
        """Returns the variable name to read from the source file for a given field."""
        return field_name

    @cached_property
    def src_fields(self) -> tuple[SrcField, ...]:
        """Initializes and returns the collection of source fields for the dataset.

        Raises:
            KeyError: If a field's read name is not a variable of the source file.
        """
        src_fields = []
        with open_nc(self.src_path, mode="r") as ds:
            for field_name in self.field_names:
                read_name = self.get_read_name(field_name)

                if read_name not in ds.variables:
                    raise KeyError(f"Source variable '{read_name}' not found for field '{field_name}' in {self.src_path}")
                var = ds.variables[read_name]
                init_data = {
                    "name": field_name,
                    "attrs": self._get_nc_attrs_(var),
                    "fill_value": -1.0,
                    "dtype": var.dtype,
                    "level_out_name": self.level_out_name,
                    "level_out_size": self.level_out_size,
                    "time_size": self.time_size,
                    "num_cells": self.num_cells,
                }
                app = SrcField.model_validate(init_data)
                src_fields.append(app)
        CR_LOGGER.debug(f"{src_fields=}")
        return tuple(src_fields)

    @staticmethod
    def _get_nc_attrs_(src: HasNcAttrsType) -> dict[str, Any]:
        """Extracts netCDF attributes, skipping private names, ``coordinates`` and ``valid_range``."""
        exclude = ("coordinates", "valid_range")
        return {ii: getattr(src, ii) for ii in src.ncattrs() if not ii.startswith("_") and ii not in exclude}

    def find_latest_src_file(self, target_time_str: str, max_lookback_hours: int = 24) -> list[str]:
        """Finds the closest available source file(s) within a search window.

        The search starts at the target time and moves one hour at a time:
        backwards when ``ebb_dcycle`` is -1 or 2, forwards otherwise.

        Args:
            target_time_str: Target time formatted with ``time_format``.
            max_lookback_hours: Maximum number of hourly steps away from the target.

        Returns:
            The paths matching the first hour that has any match, or an empty
            list when nothing is found inside the window.
        """
        target_time = datetime.strptime(target_time_str, self.time_format)

        # The direction does not depend on the hour offset, so decide (and warn) once
        # instead of on every iteration of the search loop.
        if self.ebb_dcycle in (-1, 2):
            step = timedelta(hours=-1)
        else:
            if self.ebb_dcycle != 1:
                CR_LOGGER.warning("unrecognized ebb_dcycle, reverting to same-day, ebb_dcycle = 1")
            step = timedelta(hours=1)

        for h in range(max_lookback_hours + 1):
            this_time = target_time + h * step
            paths = glob.glob(self._get_src_search_path(this_time))
            if not paths:
                continue
            if h > 0:
                msg = (
                    f"Missing {self.dataset_name} file for {target_time_str}, using "
                    f"{this_time.strftime(self.search_time_format)} instead"
                )
                CR_LOGGER.warning(msg)
            return paths
        # nothing found within lookback window
        return []

    def _get_src_search_path(self, this_time: datetime) -> str:
        """Returns the glob pattern for searching source files.

        The default ignores ``this_time`` and returns ``src_path`` as a string.
        """
        return str(self.src_path)

    def transform_regridded_data(
        self,
        src_field: SrcField,
        dst_field_data: np.ndarray,
        ds: Any,
        reconciled_bounds: tuple[int, int],
        dims: DimensionCollection,
    ) -> np.ndarray:
        """Hook for dataset-specific data transformations after regridding but before writing.

        The default returns ``dst_field_data`` unchanged.
        """
        return dst_field_data

    def post_regrid_processing(
        self,
        src_field: SrcField,
        regridder: Union[esmpy.Regrid, esmpy.RegridFromFile],
        processor: Any,
        dims: DimensionCollection,
    ) -> None:
        """Hook for dataset-specific operations after a field has been regridded and written.

        The default does nothing.
        """
        pass
class ECOREGION_DatasetRegridContext(AbstractDatasetRegridContext):
    """Regrid context for Ecoregion mapping data."""

    def iter_file_pairs(self) -> Iterator[RegridFilePair]:
        """Yield the single source/destination pair for the ecoregion map."""
        yield RegridFilePair(
            src_path=self.input_dir / "veg_map.nc",
            dst_path=self.output_dir / f"ecoregions_{self.mesh_name}_mpas.nc",
        )


class FENGSHA_2D_DatasetRegridContext(AbstractDatasetRegridContext):
    """Regrid context for FENGSHA 2D dust emission data."""

    # Overrides the base default, which also lists "xtime".
    var_names_to_copy_to_output_file: tuple[str, ...] = ("latCell", "lonCell")

    def iter_file_pairs(self) -> Iterator[RegridFilePair]:
        """Yield the single source/destination pair for the 2D FENGSHA inputs."""
        yield RegridFilePair(
            src_path=self.input_dir / "FENGSHA_RRFS_NA_3km_2026_2D.nc",
            dst_path=self.output_dir / f"fengsha_dust_inputs.2D.{self.mesh_name}.nc",
        )


class FENGSHA_2D_Time_DatasetRegridContext(AbstractDatasetRegridContext):
    """Regrid context for FENGSHA 2D dust emission data with time dimension."""

    def iter_file_pairs(self) -> Iterator[RegridFilePair]:
        """Yield the single source/destination pair for the time-varying FENGSHA inputs."""
        yield RegridFilePair(
            src_path=self.input_dir / "FENGSHA_RRFS_NA_3km_2026_2D_Time.nc",
            dst_path=self.output_dir / f"fengsha_dust_inputs.2D_Time.{self.mesh_name}.nc",
        )
class FMC_DatasetRegridContext(AbstractDatasetRegridContext):
    """Regrid context for FMC (Fuel Moisture Content) data."""

    def _get_src_search_path(self, this_time: datetime) -> str:
        """Return the FMC source path under ``input_dir`` for ``this_time``."""
        this_str = this_time.strftime(self.search_time_format)
        return str(self.input_dir / ("fmc_" + this_str + ".nc"))

    def iter_file_pairs(self) -> Iterator[RegridFilePair]:
        """Yield one source/destination pair per entry of ``dates_needed``.

        The destination name always carries the requested date, even when the
        lookback search substituted a source file from another hour.

        Raises:
            ValueError: If no source file is found for a date within the lookback window.
        """
        for date_to_process in self.dates_needed:
            src_paths = self.find_latest_src_file(date_to_process)
            if not src_paths:
                # Fail with an explicit message instead of an IndexError on ``src_paths[0]``.
                msg = f"No matching FMC files found for {date_to_process} (even after lookback)."
                raise ValueError(msg)
            src_path = Path(src_paths[0])
            new_dst_path = self.output_dir / ("fmc_" + date_to_process + "_" + self.mesh_name + ".nc")
            yield RegridFilePair(src_path=src_path, dst_path=new_dst_path)

    @cached_property
    def dates_needed(self) -> list[str]:
        """Return the cycle hour and the 24 hours before it, newest first.

        Strings are formatted with ``time_format`` because that is the format
        ``find_latest_src_file`` uses to parse them.
        """
        cycle_time = self.dt_spec.datetime
        return [(cycle_time - timedelta(hours=i)).strftime(self.time_format) for i in range(25)]
class GOES_DatasetRegridContext(AbstractDatasetRegridContext):
    """Regrid context for GOES (Geostationary Operational Environmental Satellite) AOD data."""

    # Source file names are searched with a year + day-of-year + hour stamp.
    search_time_format: str = "%Y%j%H"

    def get_src_grid_path(self) -> Path:
        """Return the path of the separate lat/lon grid file under ``workdir``."""
        # NOTE(review): the grid file is named "goes19" while the search pattern below
        # matches "G18" files -- confirm that this pairing is intended.
        return self.workdir / "goes19_abi_conus_interpolated_lat_lon.nc"

    @staticmethod
    def _get_nc_attrs_(src: HasNcAttrsType) -> dict[str, Any]:
        """Return no attributes; source attributes are not carried over for GOES."""
        return {}

    def _get_src_search_path(self, this_time: datetime) -> str:
        """Return the glob pattern matching GOES AOD files for the hour of ``this_time``."""
        this_str = this_time.strftime(self.search_time_format)
        return str(self.input_dir / ("OR_ABI-L2-AODC-M6_G18_s" + this_str + "*"))

    def iter_file_pairs(self) -> Iterator[RegridFilePair]:
        """Merge the matching GOES files into one mean-AOD file and yield a single pair.

        Only the first entry of ``dates_needed`` is processed. Rank 0 averages
        the ``AOD`` variable over all matching files and writes the result to
        ``output_dir``; every rank then yields that merged file as the source.

        Raises:
            ValueError: If no GOES file is found within the 2-hour lookback window.
        """
        date_to_process = self.dates_needed[0]
        src_paths = self.find_latest_src_file(date_to_process, max_lookback_hours=2)

        # Check before opening anything: previously rank 0 called open_mfdataset with
        # an empty list first and never reached this error.
        if not src_paths:
            msg = f"No matching GOES files found for {date_to_process} (even after lookback)."
            CR_LOGGER.error(msg)
            raise ValueError(msg)

        files_to_cat = src_paths
        CR_LOGGER.info(f"will cat files: {files_to_cat=}")
        merged_path = self.output_dir / "test_goes_aod_merged.nc"
        if self.rank == 0:
            with xr.open_mfdataset(files_to_cat, combine="nested", concat_dim="file") as ds:
                # Mean across the new 'file' dimension; skipna=True makes it behave
                # like np.nanmean.
                ds_averaged = ds["AOD"].mean(dim="file", skipna=True)
                ds_averaged.encoding.update({"dtype": "float32", "_FillValue": -999})
                ds_averaged.to_netcdf(merged_path)
        # NOTE(review): other ranks use the merged file written by rank 0 and no
        # synchronisation is visible in this method -- confirm the caller provides one.

        CR_LOGGER.info("Reading merged GOES file: test_goes_aod_merged.nc")
        new_dst_path = self.output_dir / (self.mesh_name + "-GOES-" + date_to_process + ".nc")
        yield RegridFilePair(src_path=merged_path, dst_path=new_dst_path)

    @cached_property
    def dates_needed(self) -> list[str]:
        """Return 25 hourly date strings starting at the cycle time.

        Hours run forward for ``ebb_dcycle == 1`` (same-day emissions) and
        backward for -1 (persistence) or 2 (forecasted). Any other value is
        logged once and treated as 1, which is what the log message states;
        the previous code subtracted hours in that case.
        """
        if self.ebb_dcycle in (-1, 2):
            direction = -1
        else:
            if self.ebb_dcycle != 1:
                CR_LOGGER.info("EBB_DCYLE selection not recognized, reverting to same day, ebb_dcycle = 1")
            direction = 1
        cycle_time = self.dt_spec.datetime
        return [(cycle_time + timedelta(hours=direction * i)).strftime(self.time_format) for i in range(25)]
class GRA2PES_DatasetRegridContext(AbstractDatasetRegridContext):
    """Regrid context for GRA2PES (Greenhouse gas And Air Pollutants Emissions System) data."""

    def get_src_field_dims(self, field_name: str) -> tuple[tuple[str, ...] | None, tuple[str, ...] | None]:
        """Return the level and time dimension names, using the staggered level for ``h_agl``."""
        dim_level, dim_time = super().get_src_field_dims(field_name)
        if field_name == "h_agl":
            dim_level = ("bottom_top_stag",)
        return dim_level, dim_time

    def update_src_field_wrapper(self, raw_src_fwrap: FieldWrapper) -> None:
        """Rescale GRA2PES emissions to per-m2, per-second rates and clean the data in place.

        Particulate fields go from metric tons/km2/hr to ug/m2/s. The listed
        gas-phase fields go from moles/km2/hr to moles/m2/s (the factor only
        changes the area and time units). Other fields are not rescaled.
        Negative and NaN values are replaced with 0.0.
        """
        field_name = raw_src_fwrap.value.name
        src_data = raw_src_fwrap.value.data

        if field_name in ("PM25-PRI", "PM10-PRI"):
            # 1 metric ton = 1e12 ug, 1 km2 = 1e6 m2, 1 hr = 3600 s
            conv_aer = 1.0e6 / 3600.0
        elif field_name in ("HC01", "SO2", "CO", "NH3", "NOX"):
            # 1 km2 = 1e6 m2, 1 hr = 3600 s; the amount stays in moles
            conv_aer = 1.0e-6 / 3600.0
        else:
            conv_aer = 1.0

        src_data[:] = np.where(src_data < 0.0, 0.0, conv_aer * src_data)
        src_data[:] = np.where(np.isnan(src_data), 0.0, src_data)

    def iter_file_pairs(self) -> Iterator[RegridFilePair]:
        """Yield one pair per half-day file (00to11Z, then 12to23Z).

        The source name is built from the cycle month and day-of-week class.
        """
        src_prefix = f"GRA2PESv1.0_total_2021{self.dt_spec.mm}_{self.dt_spec.dows}_"
        dst_prefix = f"{self.dataset_name}v1.0_total_{self.mesh_name}_"

        for suffix in ("00to11Z", "12to23Z"):
            yield RegridFilePair(
                src_path=self.input_dir / f"{src_prefix}{suffix}.nc",
                dst_path=self.output_dir / f"{dst_prefix}{suffix}.nc",
            )


class NARR_DatasetRegridContext(AbstractDatasetRegridContext):
    """Regrid context for NARR (North American Regional Reanalysis) data."""

    def iter_file_pairs(self) -> Iterator[RegridFilePair]:
        """Yield the single source/destination pair for the RWC emission denominator."""
        yield RegridFilePair(
            src_path=self.input_dir / "rwc_emission_denominator.2017.nc",
            dst_path=self.output_dir / ("NEMO_RWC_DENOMINATOR_2017_" + self.mesh_name + ".nc"),
        )


class NEMO_ANTHRO_DatasetRegridContext(AbstractDatasetRegridContext):
    """Regrid context for NEMO Anthropogenic emissions data."""

    def update_src_field_wrapper(self, raw_src_fwrap: FieldWrapper) -> None:
        """Replace negative and NaN source values with 0.0 in place.

        No rescaling is applied: the g/s/km2 to ug/m2/s factor is exactly 1
        (1e6 ug per g divided by 1e6 m2 per km2), and the previous code used
        1.0 for every field.
        """
        src_data = raw_src_fwrap.value.data
        src_data[:] = np.where(src_data < 0.0, 0.0, src_data)
        src_data[:] = np.where(np.isnan(src_data), 0.0, src_data)

    def iter_file_pairs(self) -> Iterator[RegridFilePair]:
        """Yield the single source/destination pair for the current cycle's sector-sum file."""
        spec = self.dt_spec
        stamp = spec.yyyy + spec.mm + spec.dd + spec.hh
        yield RegridFilePair(
            src_path=self.input_dir / ("NEMO_ANTHRO_" + self.mesh_name + "_" + stamp + "_SECTORSUM.nc"),
            dst_path=self.output_dir / ("NEMO_ANTHRO_" + self.mesh_name + ".nc"),
        )
b/src/regrid_wrapper/app/chem_regrid/dataset/context/nemo_rwc.py @@ -0,0 +1,33 @@ +from typing import Iterator + +import numpy as np + +from regrid_wrapper.app.chem_regrid.dataset.context.base import ( + AbstractDatasetRegridContext, + RegridFilePair, +) +from regrid_wrapper.esmpy.field_wrapper import FieldWrapper + + +class NEMO_RWC_DatasetRegridContext(AbstractDatasetRegridContext): + """Regrid context for NEMO RWC (Residential Wood Combustion) data.""" + + def update_src_field_wrapper(self, raw_src_fwrap: FieldWrapper) -> None: + """Converts NEMO RWC emissions from g/s/km2 to ug/m2/s.""" + field_name = raw_src_fwrap.value.name + src_data = raw_src_fwrap.value.data + + if field_name in ("PEC", "POC", "PMOTHR", "PMC"): + # Convert g/s/km2 (on 1km grid) to ug/m2/s --> + conv_aer = 1.0 + else: + conv_aer = 1.0 + + src_data[:] = np.where(src_data < 0.0, 0.0, conv_aer * src_data) + src_data[:] = np.where(np.isnan(src_data), 0.0, src_data) + + def iter_file_pairs(self) -> Iterator[RegridFilePair]: + for _ in range(1): + src_path = self.input_dir / "NEMO_RWC_POC_PEC_PMOTHR.annual.2017.nc" + new_dst_path = self.output_dir / ("NEMO_RWC_ANNUAL_TOTAL_" + self.mesh_name + ".nc") + yield RegridFilePair(src_path=src_path, dst_path=new_dst_path) diff --git a/src/regrid_wrapper/app/chem_regrid/dataset/context/ngfs.py b/src/regrid_wrapper/app/chem_regrid/dataset/context/ngfs.py new file mode 100644 index 0000000..6ecc5b9 --- /dev/null +++ b/src/regrid_wrapper/app/chem_regrid/dataset/context/ngfs.py @@ -0,0 +1,38 @@ +from datetime import timedelta +from functools import cached_property +from typing import Iterator + +from regrid_wrapper.app.chem_regrid import CR_LOGGER +from regrid_wrapper.app.chem_regrid.dataset.context.base import ( + AbstractDatasetRegridContext, + RegridFilePair, +) + + +class NGFS_DatasetRegridContext(AbstractDatasetRegridContext): + """Regrid context for NGFS (Next Generation Fire System) data.""" + + def get_read_name(self, field_name: str) -> str: + if 
field_name == "PM25": + return "EMIS_PM25" + return super().get_read_name(field_name) + + def iter_file_pairs(self) -> Iterator[RegridFilePair]: + raise NotImplementedError("NGFS not yet supported") + + @cached_property + def dates_needed(self) -> list[str]: + dates_needed = [] + # Determine the cycle dates to process +%Y%m%d%H + # This is for RETROS (using current datetime, not day before) + for i in range(25): # GAF retro current day emissions + if self.ebb_dcycle == 1: # Same-day emissions + x = self.dt_spec.datetime + timedelta(hours=i) + elif self.ebb_dcycle == -1 or self.ebb_dcycle == 2: # Persistence (-1) or forecasted (2) needs prev 24 hours + x = self.dt_spec.datetime - timedelta(hours=i) + else: + CR_LOGGER.info("EBB_DCYLE selection not recognized, reverting to same day, ebb_dcycle = 1") + x = self.dt_spec.datetime + timedelta(hours=i) + y = x.strftime("%Y%m%d%H") + dates_needed.append(y) + return dates_needed diff --git a/src/regrid_wrapper/app/chem_regrid/dataset/context/pecm.py b/src/regrid_wrapper/app/chem_regrid/dataset/context/pecm.py new file mode 100644 index 0000000..83c6796 --- /dev/null +++ b/src/regrid_wrapper/app/chem_regrid/dataset/context/pecm.py @@ -0,0 +1,67 @@ +from typing import Any, Iterator, Union + +import esmpy + +from regrid_wrapper.app.chem_regrid import CR_LOGGER +from regrid_wrapper.app.chem_regrid.dataset.context.base import ( + AbstractDatasetRegridContext, + RegridFilePair, +) +from regrid_wrapper.app.chem_regrid.dataset.src_field import SrcField +from regrid_wrapper.esmpy.field_wrapper import ( + DimensionCollection, + open_nc, + set_variable_data, +) + + +class PECM_DatasetRegridContext(AbstractDatasetRegridContext): + """Regrid context for PECM (Pollen Emissions for Climate Models) data.""" + + def iter_file_pairs(self) -> Iterator[RegridFilePair]: + for _ in range(1): + src_path = self.input_dir / ("pollen_obs_" + self.dt_spec.yyyy + "_BELD6_ef_T_" + self.dt_spec.jjj + ".nc") + new_dst_path = self.output_dir / ( + 
"pollen_ef_" + self.mesh_name + "_" + self.dt_spec.yyyy + "_" + self.dt_spec.jjj + ".nc" + ) + yield RegridFilePair(src_path=src_path, dst_path=new_dst_path) + + def post_regrid_processing( + self, + src_field: SrcField, + regridder: Union[esmpy.Regrid, esmpy.RegridFromFile], + processor: Any, + dims: DimensionCollection, + ) -> None: + if src_field.name == "ENL_POLL": + with open_nc(self.new_dst_path, mode="a") as ds: + CR_LOGGER.info("renaming and combining tree fields") + + src_fwrap_enl = processor.create_src_field_wrapper(field_name="ENL_POLL") + dst_field_enl = processor.get_dst_field() + dst_field_enl.data.fill(0.0) + regridder(src_fwrap_enl.value, dst_field_enl) + data_enl = src_field.reshape_field_data(dst_field_enl.data).copy() + + src_fwrap_dbl = processor.create_src_field_wrapper(field_name="DBL_POLL") + dst_field_dbl = processor.get_dst_field() + dst_field_dbl.data.fill(0.0) + regridder(src_fwrap_dbl.value, dst_field_dbl) + data_dbl = src_field.reshape_field_data(dst_field_dbl.data) + + var = ds.createVariable( + "TREE_POLL", + src_field.dtype, + [dim.name[0] for dim in dims.value], + fill_value=src_field.fill_value, + ) + for k, v in src_field.attrs.items(): + setattr(var, k, v) + set_variable_data( + var, + dims, + data_enl + data_dbl, + collective=True, + ) + src_fwrap_enl.value.destroy() + src_fwrap_dbl.value.destroy() diff --git a/src/regrid_wrapper/app/chem_regrid/dataset/context/rave.py b/src/regrid_wrapper/app/chem_regrid/dataset/context/rave.py new file mode 100644 index 0000000..5bcd830 --- /dev/null +++ b/src/regrid_wrapper/app/chem_regrid/dataset/context/rave.py @@ -0,0 +1,170 @@ +from datetime import datetime, timedelta +from functools import cached_property +from pathlib import Path +from typing import Any, Iterator, Union + +import esmpy +import numpy as np +from pydantic import PrivateAttr + +from regrid_wrapper.app.chem_regrid import CR_LOGGER +from regrid_wrapper.app.chem_regrid.dataset.context.base import ( + 
AbstractDatasetRegridContext, + RegridFilePair, +) +from regrid_wrapper.app.chem_regrid.dataset.src_field import SrcField +from regrid_wrapper.esmpy.field_wrapper import ( + DimensionCollection, + FieldWrapper, + NcToField, + open_nc, + set_variable_data, +) + + +class RAVE_DatasetRegridContext(AbstractDatasetRegridContext): + """Regrid context for RAVE (Regional Real-time Biomass Burning Emissions) data.""" + + var_names_to_copy_to_output_file: tuple[str, ...] = ("latCell", "lonCell", "areaCell", "xtime") + _area_data: np.ndarray | None = PrivateAttr(default=None) + + def get_area_data(self, raw_src_fwrap: FieldWrapper) -> np.ndarray: + """Loads and returns area information from the RAVE source file.""" + if self._area_data is None: + area_fwrap = NcToField( + path=self.src_path, + name="area", + gwrap=raw_src_fwrap.gwrap, + dim_time=None, + ).create_field_wrapper() + self._area_data = area_fwrap.data + return self._area_data + + def update_src_field_wrapper(self, raw_src_fwrap: FieldWrapper) -> None: + field_name = raw_src_fwrap.value.name + src_data = raw_src_fwrap.data + + # RAVE methane, convert from kg/hr to mol/m2/s + if field_name == "CH4": + conv_aer = (1.0 / 16.0) * 1000.0 + elif field_name == "CO": + conv_aer = (1.0 / 28.0) * 1000.0 + elif field_name == "NH3": + conv_aer = (1.0 / 17.0) * 1000.0 + elif field_name == "NOx": + conv_aer = ((1.0 / 30.0) + (1.0 / 46.0)) / 2.0 * 1000.0 + else: + conv_aer = 1.0 + + if field_name in ("PM25", "TPM"): + # If RAVE aerosol emissions, convert from kg/hr to ug/m2/s + src_data[:] = np.where( + src_data < 0.0, 0.0, src_data * 1.0e3 / self.get_area_data(raw_src_fwrap)[:, :, np.newaxis] / 3600.0 + ) + elif field_name in ("CH4", "NH3", "SO2", "CO", "NOx"): + # If RAVE gas emissions, convert from kg/hr to mol/m2/s + src_data[:] = np.where( + src_data < 0.0, 0.0, conv_aer * src_data / self.get_area_data(raw_src_fwrap)[:, :, np.newaxis] / 3600.0 + ) + elif field_name in ("FRE", "FRP_MEAN"): + # For FRE, FRP, don't multiply 
area by 1.e6, cancelled out by MW to W conversion + src_data[:] = np.where(src_data < 0.0, 0.0, src_data / (self.get_area_data(raw_src_fwrap)[:, :, np.newaxis])) + else: + src_data[:] = np.where(src_data < 0.0, 0.0, conv_aer * src_data) + + src_data[:] = np.where(np.isnan(src_data), 0.0, src_data) + + def _get_src_search_path(self, this_time: datetime) -> str: + this_str = this_time.strftime(self.search_time_format) + return str(self.input_dir / ("RAVE-HrlyEmiss-3km_v2r0_blend_s" + this_str + "*")) + + def iter_file_pairs(self) -> Iterator[RegridFilePair]: + for date_to_process in self.dates_needed: + CR_LOGGER.info(f"RAVE processing {date_to_process=}") + src_paths = self.find_latest_src_file(date_to_process, max_lookback_hours=24) + if not src_paths: + CR_LOGGER.warn(f"No matching files found for {date_to_process} (even after lookback).") + continue + + CR_LOGGER.info(f"Reading RAVE file: {src_paths=}") + src_path = src_paths[0] + new_dst_path = self.output_dir / (self.mesh_name + "-RAVE-" + date_to_process + ".nc") + + yield RegridFilePair( + src_path=Path(src_path), + dst_path=new_dst_path, + ) + + @cached_property + def dates_needed(self) -> list[str]: + dates_needed = [] + for i in range(25): + if self.ebb_dcycle == 1: # Same-day emissions + x = self.dt_spec.datetime + timedelta(hours=i) + elif self.ebb_dcycle == -1 or self.ebb_dcycle == 2: # Persistence (-1) or forecasted (2) needs prev 24 hours + x = self.dt_spec.datetime - timedelta(hours=i) + else: + CR_LOGGER.info("EBB_DCYLE selection not recognized, reverting to same day, ebb_dcycle = 1") + x = self.dt_spec.datetime + timedelta(hours=i) + + y = x.strftime("%Y%m%d%H") + dates_needed.append(y) + return dates_needed + + def transform_regridded_data( + self, + src_field: SrcField, + dst_field_data: np.ndarray, + ds: Any, + reconciled_bounds: tuple[int, int], + dims: DimensionCollection, + ) -> np.ndarray: + if src_field.name in ("FRP_MEAN", "FRE"): + # Multiply FRE/FRP by output area so it is back to W or 
J*s + area = np.asarray(ds.variables["areaCell"]) + area_subset = area[reconciled_bounds[0] : reconciled_bounds[1]].reshape(dims.shape_local) + return dst_field_data * area_subset + return dst_field_data + + def post_regrid_processing( + self, + src_field: SrcField, + regridder: Union[esmpy.Regrid, esmpy.RegridFromFile], + processor: Any, + dims: DimensionCollection, + ) -> None: + if src_field.name == "TPM": + with open_nc(self.new_dst_path, mode="a") as ds: + CR_LOGGER.info("calculating PM10 as TPM - PM25") + src_fwrap_ttl = processor.create_src_field_wrapper(field_name="TPM") + src_fwrap_p25 = processor.create_src_field_wrapper(field_name="PM25") + + dst_field_ttl = processor.get_dst_field() + dst_field_ttl.data.fill(0.0) + regridder(src_fwrap_ttl.value, dst_field_ttl) + data1 = src_field.reshape_field_data(dst_field_ttl.data).copy() + + dst_field_p25 = processor.get_dst_field() + dst_field_p25.data.fill(0.0) + regridder(src_fwrap_p25.value, dst_field_p25) + data2 = src_field.reshape_field_data(dst_field_p25.data) + + # use the same src_field metadata for PM10 + var = ds.createVariable( + "PM10", + src_field.dtype, + [dim.name[0] for dim in dims.value], + fill_value=src_field.fill_value, + ) + for k, v in src_field.attrs.items(): + setattr(var, k, v) + + data3 = data1 - data2 + set_variable_data( + var, + dims, + data3, + collective=True, + ) + src_fwrap_ttl.value.destroy() + src_fwrap_p25.value.destroy() diff --git a/src/regrid_wrapper/app/chem_regrid/dataset/src_field.py b/src/regrid_wrapper/app/chem_regrid/dataset/src_field.py new file mode 100644 index 0000000..2fd520a --- /dev/null +++ b/src/regrid_wrapper/app/chem_regrid/dataset/src_field.py @@ -0,0 +1,85 @@ +from functools import cached_property +from typing import Any + +import esmpy +import numpy as np +from pydantic import BaseModel + +from regrid_wrapper.esmpy.field_wrapper import Dimension, DimensionCollection + + +class SrcField(BaseModel): + """Represents a source field with its metadata and 
dimensions for regridding.""" + + name: str + attrs: dict[str, Any] + fill_value: float + dtype: Any + num_cells: int + level_out_name: str | None + level_out_size: int + time_size: int + + @cached_property + def time_dimension(self) -> Dimension: + """Returns the time dimension for the field.""" + return Dimension( + name=("Time",), + size=self.time_size, + lower=0, + upper=self.time_size, + staggerloc=esmpy.StaggerLoc.CENTER, + coordinate_type="time", + ) + + @cached_property + def nklevel_dimension(self) -> Dimension: + """Returns the vertical level dimension for the field.""" + if self.level_out_name is None: + raise ValueError("Level out name must be set for 3D fields") + return Dimension( + name=self.level_out_name, + size=self.level_out_size, + lower=0, + upper=self.level_out_size, + staggerloc=esmpy.StaggerLoc.CENTER, + coordinate_type="level", + ) + + def create_ncells_dimension(self, bounds: tuple[int, int]) -> Dimension: + """Creates the cells dimension with specified bounds.""" + return Dimension( + name=("nCells",), + size=self.num_cells, + lower=bounds[0], + upper=bounds[1], + staggerloc=esmpy.MeshLoc.ELEMENT, + coordinate_type="cell", + ) + + def create_dimension_collection(self, ncells_bounds: tuple[int, int]) -> DimensionCollection: + """Creates a collection of dimensions based on the field's shape.""" + dims = [] + if self.level_out_size == 0: + if self.time_size > 0: + dims.append(self.time_dimension) + dims.append(self.create_ncells_dimension(ncells_bounds)) + else: + dims.append(self.create_ncells_dimension(ncells_bounds)) + dims.append(self.nklevel_dimension) + if self.time_size > 0: + dims.append(self.time_dimension) + return DimensionCollection(value=tuple(dims)) + + def reshape_field_data(self, target: np.ndarray) -> np.ndarray: + """Reshapes the field data to match the expected output dimensions.""" + if self.level_out_size == 0: + if self.time_size == 0: + return target.reshape(-1) + else: + return target.reshape(self.time_size, -1) + 
else: + if self.time_size == 0: + return target.reshape(-1, self.level_out_size) + else: + return target.reshape(-1, self.level_out_size, self.time_size) diff --git a/src/regrid_wrapper/common.py b/src/regrid_wrapper/common.py index 0769d39..e3c50fa 100644 --- a/src/regrid_wrapper/common.py +++ b/src/regrid_wrapper/common.py @@ -1,6 +1,10 @@ import subprocess +from abc import ABC from pathlib import Path -from typing import Any +from typing import Any, Callable, TypeVar + +import yaml +from pydantic import BaseModel def ncdump(path: Path, header_only: bool = True) -> Any: @@ -11,3 +15,26 @@ def ncdump(path: Path, header_only: bool = True) -> Any: ret = subprocess.check_output(args) print(ret.decode(), flush=True) return ret + + +T = TypeVar("T", bound="RwBaseModel") + + +class RwBaseModel(ABC, BaseModel): + model_config = {"frozen": True} + + @classmethod + def from_yaml(cls: type[T], data: dict) -> T: + return cls.model_validate(data) + + @classmethod + def from_yaml_file(cls: type[T], path: Path, retriever: Callable[[dict], dict] | None = None) -> T: + yaml_data = cls.read_raw_yaml(path) + if retriever is not None: + yaml_data = retriever(yaml_data) + return cls.from_yaml(yaml_data) + + @staticmethod + def read_raw_yaml(path: Path) -> dict: + string_data = path.read_text() + return yaml.safe_load(string_data) diff --git a/src/regrid_wrapper/esmpy/field_wrapper.py b/src/regrid_wrapper/esmpy/field_wrapper.py index 2f39022..2f9517e 100644 --- a/src/regrid_wrapper/esmpy/field_wrapper.py +++ b/src/regrid_wrapper/esmpy/field_wrapper.py @@ -1,15 +1,16 @@ import abc import time from contextlib import contextmanager -from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Iterator, Literal, Sequence, Tuple, Union +from typing import Annotated, Any, Dict, Iterator, List, Literal, Sequence, Tuple, Union import esmpy import netCDF4 as nc import numpy as np from mpi4py import MPI +from pydantic import BeforeValidator, ConfigDict, 
model_validator +from regrid_wrapper.common import RwBaseModel from regrid_wrapper.context.comm import COMM, Tag, reconcile_bounds from regrid_wrapper.context.logging import LOGGER @@ -71,7 +72,15 @@ def copy_nc_variable(src: nc.Dataset, dst: nc.Dataset, varname: str, copy_data: new_var[:] = var[:] -NameListType = Tuple[str, ...] +def _coerce_to_tuple_(v: str | List[str] | Tuple[str, ...]) -> Tuple[str, ...]: + if isinstance(v, str): + return (v,) + if isinstance(v, list): + return tuple(v) + return v + + +NameListType = Annotated[Tuple[str, ...] | str, BeforeValidator(_coerce_to_tuple_)] def get_aliased_key(source: Dict, keys: NameListType | str) -> Any: @@ -91,18 +100,16 @@ def get_nc_dimension(ds: nc.Dataset, names: NameListType) -> nc.Dimension: return get_aliased_key(ds.dimensions, names) -@dataclass -class Dimension: +class Dimension(RwBaseModel): name: NameListType size: int lower: int upper: int staggerloc: int - coordinate_type: Literal["y", "x", "time", "element", "level"] + coordinate_type: Literal["y", "x", "time", "element", "level", "cell"] -@dataclass -class DimensionCollection: +class DimensionCollection(RwBaseModel): value: Tuple[Dimension, ...] 
@property @@ -181,13 +188,11 @@ def set_variable_data_serial(path: Path, varname: str, target_dims: DimensionCol COMM.barrier() -@dataclass -class AbstractWrapper(abc.ABC): +class AbstractWrapper(RwBaseModel, abc.ABC): dims: DimensionCollection -@dataclass -class GridSpec: +class GridSpec(RwBaseModel): x_center: str y_center: str x_dim: NameListType @@ -199,7 +204,8 @@ class GridSpec: x_index: int = 0 y_index: int = 1 - def __post_init__(self) -> None: + @model_validator(mode="after") + def _validate_corners_(self) -> "GridSpec": corner_meta = [ self.x_corner, self.y_corner, @@ -209,6 +215,7 @@ def __post_init__(self) -> None: is_given_sum = sum([ii is not None for ii in corner_meta]) if is_given_sum > 0 and is_given_sum != len(corner_meta): raise ValueError("if one corner name is supplied, then all must be supplied") + return self @property def has_corners(self) -> bool: @@ -264,8 +271,8 @@ def create_grid_dims(self, ds: nc.Dataset, grid: esmpy.Grid, staggerloc: esmpy.S return DimensionCollection(value=tuple(value)) -@dataclass class GridWrapper(AbstractWrapper): + model_config = ConfigDict(arbitrary_types_allowed=True) value: esmpy.Grid spec: GridSpec corner_dims: DimensionCollection | None = None @@ -281,17 +288,22 @@ def fill_nc_variables(self, path: Path) -> None: set_variable_data(ds.variables[self.spec.y_center], self.dims, y_center_data) -@dataclass class MeshWrapper(AbstractWrapper): + model_config = ConfigDict(arbitrary_types_allowed=True) value: esmpy.Mesh -@dataclass -class NcToMesh: +class NcToMesh(RwBaseModel): path: Path filetype: int = esmpy.FileFormat.UGRID meshname: str = "grid_topology" + @model_validator(mode="after") + def _validate_path_(self) -> "NcToMesh": + if not self.path.exists(): + raise FileNotFoundError(self.path) + return self + def create_mesh_wrapper(self) -> MeshWrapper: t1 = time.perf_counter() mesh = esmpy.Mesh( @@ -320,13 +332,8 @@ def create_mesh_wrapper(self) -> MeshWrapper: mwrap = MeshWrapper(value=mesh, dims=dims) return 
mwrap - def __post_init__(self) -> None: - if not self.path.exists(): - raise FileNotFoundError(self.path) - -@dataclass -class NcToGrid: +class NcToGrid(RwBaseModel): path: Path spec: GridSpec @@ -382,11 +389,15 @@ def _add_corner_coords_(self, ds: nc.Dataset, grid: esmpy.Grid) -> DimensionColl GeomType = GridWrapper | MeshWrapper -@dataclass class FieldWrapper(AbstractWrapper): + model_config = ConfigDict(arbitrary_types_allowed=True) value: esmpy.Field gwrap: GeomType + @property + def data(self) -> np.ndarray: + return self.value.data + def fill_nc_variable(self, path: Path) -> None: _LOGGER.debug(r"filling variable: {self.value.name}") with open_nc(path, "a") as ds: @@ -394,8 +405,7 @@ def fill_nc_variable(self, path: Path) -> None: set_variable_data(var, self.dims, self.value.data) -@dataclass -class MetaToField: +class MetaToField(RwBaseModel): name: str gwrap: GeomType staggerloc: int = esmpy.StaggerLoc.CENTER @@ -424,8 +434,7 @@ def create_field_wrapper(self) -> FieldWrapper: return FieldWrapper(value=field, dims=target_dims, gwrap=self.gwrap) -@dataclass -class NcToField: +class NcToField(RwBaseModel): path: Path name: str gwrap: GeomType @@ -497,14 +506,15 @@ def create_field_wrapper(self) -> FieldWrapper: return fwrap -@dataclass -class FieldWrapperCollection: +class FieldWrapperCollection(RwBaseModel): value: Tuple[FieldWrapper, ...] 
def fill_nc_variables(self, path: Path) -> None: for fwrap in self.value: fwrap.fill_nc_variable(path) - def __post_init__(self) -> None: + @model_validator(mode="after") + def _validate_fields_(self) -> "FieldWrapperCollection": if len(set([id(ii.value.grid) for ii in self.value])) != 1: raise ValueError("all fields must share the same grid") + return self diff --git a/src/test/test_app/test_chem_regrid/conftest.py b/src/test/test_app/test_chem_regrid/conftest.py index 2105633..aba01dd 100644 --- a/src/test/test_app/test_chem_regrid/conftest.py +++ b/src/test/test_app/test_chem_regrid/conftest.py @@ -8,7 +8,8 @@ import xarray as xr from pydantic import BaseModel -from regrid_wrapper.app.chem_regrid.context import ChemRegridContext, DatasetName +from regrid_wrapper.app.chem_regrid.chem_regrid_context import ChemRegridContext +from regrid_wrapper.app.chem_regrid.dataset.context import DatasetName from regrid_wrapper.context.comm import COMM from test.conftest import create_analytic_data_array, create_rrfs_grid_file @@ -600,7 +601,7 @@ def chem_regrid_context(tmp_path_shared: Path, dataset_test_ctx: DatasetTestCont weight_dir=dataset_test_ctx.weight_dir, cycle=cycle, mesh_name="test_mesh", - scrip_path=dataset_test_ctx.ugrid_path, + input_mesh_path=dataset_test_ctx.ugrid_path, dst_path=dst_path, ebb_dcycle=ebb_dcycle, fcst_length=24, diff --git a/src/test/test_app/test_chem_regrid/test_chem_regrid_cli.py b/src/test/test_app/test_chem_regrid/test_chem_regrid_cli.py index 69642b7..6b90f12 100644 --- a/src/test/test_app/test_chem_regrid/test_chem_regrid_cli.py +++ b/src/test/test_app/test_chem_regrid/test_chem_regrid_cli.py @@ -4,7 +4,8 @@ from _pytest.fixtures import FixtureRequest from pydantic import BaseModel -from regrid_wrapper.app.chem_regrid.context import ChemRegridContext, DatasetName +from regrid_wrapper.app.chem_regrid.chem_regrid_context import ChemRegridContext +from regrid_wrapper.app.chem_regrid.dataset.context import DatasetName from test.conftest 
import TEST_LOGGER @@ -35,14 +36,14 @@ def create_chem_regrid_context(test_context: ContextForTest) -> ChemRegridContex # Fields that can be None params = base_params.copy() - params["scrip_path"] = test_context.root_path / "scrip.nc" if test_context.use_scrip else None + params["input_mesh_path"] = test_context.root_path / "scrip.nc" if test_context.use_scrip else None params["dst_path"] = test_context.root_path / "dst.nc" if test_context.use_dst else None return ChemRegridContext.model_validate(params) @pytest.fixture(params=[True, False]) -def use_scrip_path(request: FixtureRequest) -> bool: +def use_input_mesh_path(request: FixtureRequest) -> bool: return request.param @@ -52,8 +53,8 @@ def use_dst_path(request: FixtureRequest) -> bool: @pytest.fixture() -def context_for_test(use_scrip_path: bool, use_dst_path: bool, tmp_path_shared: Path) -> ContextForTest: - return ContextForTest(root_path=tmp_path_shared, use_scrip=use_scrip_path, use_dst=use_dst_path) +def context_for_test(use_input_mesh_path: bool, use_dst_path: bool, tmp_path_shared: Path) -> ContextForTest: + return ContextForTest(root_path=tmp_path_shared, use_scrip=use_input_mesh_path, use_dst=use_dst_path) @pytest.fixture @@ -68,9 +69,9 @@ def test_generate_chem_regrid_context(chem_regrid_context: ChemRegridContext, co assert isinstance(chem_regrid_context, ChemRegridContext) if context_for_test.use_scrip: - assert chem_regrid_context.scrip_path == context_for_test.root_path / "scrip.nc" + assert chem_regrid_context.input_mesh_path == context_for_test.root_path / "scrip.nc" else: - assert chem_regrid_context.scrip_path is None + assert chem_regrid_context.input_mesh_path is None if context_for_test.use_dst: assert chem_regrid_context.dst_path == context_for_test.root_path / "dst.nc" diff --git a/src/test/test_app/test_chem_regrid/test_chem_regrid_impl.py b/src/test/test_app/test_chem_regrid/test_chem_regrid_impl.py index 9a0edea..4cfd533 100644 --- 
a/src/test/test_app/test_chem_regrid/test_chem_regrid_impl.py +++ b/src/test/test_app/test_chem_regrid/test_chem_regrid_impl.py @@ -3,8 +3,9 @@ import pytest +from regrid_wrapper.app.chem_regrid.chem_regrid_context import ChemRegridContext from regrid_wrapper.app.chem_regrid.chem_regrid_impl import main -from regrid_wrapper.app.chem_regrid.context import ChemRegridContext, DatasetName +from regrid_wrapper.app.chem_regrid.dataset.context import DatasetName from test.test_app.test_chem_regrid.conftest import DatasetTestContext @@ -12,10 +13,10 @@ def test_mock_chem_regrid_impl_rave_integration(chem_regrid_context: ChemRegridContext) -> None: if chem_regrid_context.dataset_name != DatasetName.RAVE: pytest.skip("test only for RAVE dataset") - # Mock RaveToMpasRegridProcessor to avoid actual regridding + # Mock ChemRegridProcessor to avoid actual regridding with ( - patch("regrid_wrapper.app.chem_regrid.chem_regrid_impl.RaveToMpasRegridProcessor") as mock_processor_class, - patch("regrid_wrapper.app.chem_regrid.chem_regrid_impl.RaveToMpasRegridContext") as _, + patch("regrid_wrapper.app.chem_regrid.chem_regrid_impl.ChemRegridProcessor") as mock_processor_class, + patch("regrid_wrapper.app.chem_regrid.chem_regrid_impl.AbstractDatasetRegridContext") as _, ): mock_processor = MagicMock() mock_processor_class.return_value = mock_processor @@ -23,12 +24,6 @@ def test_mock_chem_regrid_impl_rave_integration(chem_regrid_context: ChemRegridC # Run main main(chem_regrid_context) - # Verify the loop ran 25 times - # processor is initialized once (processor is None for the first pass) - # then updated 24 times - assert mock_processor_class.call_count == 1 - assert mock_processor.run.call_count == 25 - @pytest.mark.mpi def test_chem_regrid_impl_rave_integration( diff --git a/src/test/test_app/test_chem_regrid/test_chem_regrid_rrfs.py b/src/test/test_app/test_chem_regrid/test_chem_regrid_rrfs.py index 3423830..bb0dd48 100644 --- 
a/src/test/test_app/test_chem_regrid/test_chem_regrid_rrfs.py +++ b/src/test/test_app/test_chem_regrid/test_chem_regrid_rrfs.py @@ -6,7 +6,7 @@ from regrid_wrapper.app.chem_regrid import chem_regrid_rrfs from regrid_wrapper.app.chem_regrid.chem_regrid_rrfs import ChemRegridEnv -from regrid_wrapper.app.chem_regrid.context import DatasetName +from regrid_wrapper.app.chem_regrid.dataset.context import DatasetName def test_chem_regrid_env_from_env_vars() -> None: diff --git a/src/test/test_app/test_chem_regrid/test_dataset.py b/src/test/test_app/test_chem_regrid/test_dataset.py new file mode 100644 index 0000000..d515e47 --- /dev/null +++ b/src/test/test_app/test_chem_regrid/test_dataset.py @@ -0,0 +1,12 @@ +import pytest + +from regrid_wrapper.app.chem_regrid.chem_regrid_context import ChemRegridContext +from regrid_wrapper.app.chem_regrid.dataset.config.model import ChemRegridDataset +from regrid_wrapper.app.chem_regrid.dataset.context import DatasetName + + +@pytest.mark.parametrize("dataset_name", DatasetName) +def test_from_key(dataset_name: DatasetName) -> None: + yaml_path = ChemRegridContext.model_fields["datasets_yml_path"].default + ds = ChemRegridDataset.from_key(yaml_path, dataset_name) + assert ds.key == dataset_name From 1586b62fbe8211dd3c83b684d293d5516bd13df6 Mon Sep 17 00:00:00 2001 From: Ben Koziol Date: Fri, 24 Apr 2026 08:57:58 -0600 Subject: [PATCH 3/4] move ngfs logic to dedicated module; my across whole project --- README.md | 6 + .../app/chem_regrid/chem_regrid_context.py | 3 + .../app/chem_regrid/chem_regrid_impl.py | 351 ++++-------------- .../app/chem_regrid/dataset/context/base.py | 46 ++- .../app/chem_regrid/dataset/context/ngfs.py | 218 ++++++++++- .../app/chem_regrid/dataset/context/pecm.py | 14 +- .../app/chem_regrid/dataset/context/rave.py | 28 +- .../app/chem_regrid/dataset/src_field.py | 22 +- 8 files changed, 354 insertions(+), 334 deletions(-) diff --git a/README.md b/README.md index a366c90..c2a0a84 100644 --- a/README.md +++ 
b/README.md @@ -81,6 +81,12 @@ cd /opt/project && \ mpirun -n 8 pytest -m mpi src/test ``` +It is recommended to run `pre-commit` hooks when developing: + +```shell +cd /opt/project && pre-commit run --all-files +``` + # Adding a New Dataset To add a new dataset to the regridding pipeline, follow these steps: diff --git a/src/regrid_wrapper/app/chem_regrid/chem_regrid_context.py b/src/regrid_wrapper/app/chem_regrid/chem_regrid_context.py index 6c29977..43833f0 100644 --- a/src/regrid_wrapper/app/chem_regrid/chem_regrid_context.py +++ b/src/regrid_wrapper/app/chem_regrid/chem_regrid_context.py @@ -9,6 +9,9 @@ class ChemRegridContext(RwBaseModel): + """This is the API class for the regridding implementation. These fields may be customized by + users.""" + dataset_name: DatasetName workdir: Path input_dir: Path diff --git a/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py b/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py index 4123401..4d4d09c 100755 --- a/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py +++ b/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py @@ -1,7 +1,3 @@ -# mypy: ignore-errors - -import glob -from datetime import datetime, timezone from pathlib import Path from typing import Iterable, Literal @@ -17,67 +13,23 @@ AbstractDatasetRegridContext, InterpMethod, ) +from regrid_wrapper.app.chem_regrid.dataset.context.ngfs import run_ngfs_regridding from regrid_wrapper.app.chem_regrid.dataset.src_field import SrcField -from regrid_wrapper.context.comm import COMM, reconcile_bounds +from regrid_wrapper.context.comm import reconcile_bounds from regrid_wrapper.esmpy.field_wrapper import ( + Dimension, + DimensionCollection, FieldWrapper, GridSpec, GridWrapper, + MeshWrapper, NcToField, NcToGrid, - copy_nc_variable, open_nc, set_variable_data, - HasNcAttrsType, - copy_nc_variable, load_variable_data, MeshWrapper, ) -# -def create_ngfs_sparse_mesh(lat_1d, lon_1d, resolution=0.01): - """ - Creates an esmpy.Mesh dynamically from 1-D point 
source data. - Calculates the 4 corners of a square cell of size `resolution` - around each center point in memory. - This is the best approach since NGFS data are point-source (1-D), - but we rarely have more than 1000 fires in the domain, so we - can afford to keep this in memory instead of creating a file. - """ - - num_cells = len(lat_1d) - if num_cells == 0: - return None - - num_nodes = num_cells * 4 - d = resolution / 2.0 - - node_lons = np.column_stack([lon_1d - d, lon_1d + d, lon_1d + d, lon_1d - d]).flatten() - - node_lats = np.column_stack([lat_1d - d, lat_1d - d, lat_1d + d, lat_1d + d]).flatten() - - node_coords = np.empty(num_nodes * 2, dtype=np.float64) - node_coords[0::2] = node_lons - node_coords[1::2] = node_lats - - node_ids = np.arange(1, num_nodes + 1, dtype=np.int32) - node_owners = np.full(num_nodes, COMM.rank, dtype=np.int32) - - element_ids = np.arange(1, num_cells + 1, dtype=np.int32) - element_types = np.full(num_cells, esmpy.MeshElemType.QUAD, dtype=np.int32) - - # CRITICAL FIX: esmpy expects 0-based indexing for connectivity! 
- element_conn = np.arange(0, num_nodes, dtype=np.int32) - - # Explicitly set spherical coordinates - mesh = esmpy.Mesh(parametric_dim=2, spatial_dim=2, coord_sys=esmpy.CoordSys.SPH_DEG) - - mesh.add_nodes(node_count=num_nodes, node_ids=node_ids, node_coords=node_coords, node_owners=node_owners) - - mesh.add_elements(element_count=num_cells, element_ids=element_ids, element_types=element_types, element_conn=element_conn) - - return mesh - - class FileDesc(BaseModel): path: Path origin: Literal["src", "dst"] @@ -91,7 +43,7 @@ def __init__(self, context: AbstractDatasetRegridContext) -> None: self.context = context self._regridder: esmpy.Regrid | None = None - self._dst_field: FieldWrapper | None = None + self._dst_fwrap: FieldWrapper | None = None self._src_gwrap: GridWrapper | None = None def initialize(self) -> None: @@ -111,33 +63,60 @@ def initialize(self) -> None: if self._dst_mesh is None: CR_LOGGER.info("create destination mesh") - # dst_mesh = esmpy.Mesh( - # filename=str(self.context.input_mesh_path), filetype=esmpy.FileFormat.SCRIP - # ) self._dst_mesh = esmpy.Mesh( filename=str(self.context.input_mesh_path), filetype=esmpy.FileFormat.UGRID, meshname="grid_topology" ) dst_mesh = self._dst_mesh - local_bounds = reconcile_bounds((0, self._dst_mesh.size_owned[1])) - self._dst_field = self._create_dst_field_(dst_mesh) + self._dst_fwrap = self._create_dst_fwrap_(dst_mesh) self._regridder = self._create_regridder_(src_fwrap) - def _create_dst_field_(self, dst_mesh: esmpy.Mesh) -> esmpy.Field: + def _create_dst_fwrap_(self, dst_mesh: esmpy.Mesh) -> FieldWrapper: CR_LOGGER.info("create destination field") - # Check for extra dims beyond lat/lon + local_bounds = reconcile_bounds((0, self.get_dst_mesh().size_owned[1])) + cells_dim = Dimension( + name=("nCells",), + size=self.context.num_cells, + lower=local_bounds[0], + upper=local_bounds[1], + staggerloc=esmpy.MeshLoc.ELEMENT, + coordinate_type="element", + ) + dims = [cells_dim] ndbounds = [] - if 
self.context.level_out_size > 0: + if self.context.level_out_size is not None and self.context.level_out_size > 0: + if self.context.level_out_name is None: + raise ValueError("level_out_name must be specified if level_out_size > 0") + level_dim = Dimension( + name=self.context.level_out_name, + size=self.context.level_out_size, + staggerloc=esmpy.StaggerLoc.CENTER, + coordinate_type="level", + lower=0, + upper=self.context.level_out_size, + ) + dims.append(level_dim) ndbounds.append(self.context.level_out_size) - if self.context.time_size > 0: + if self.context.time_size is not None and self.context.time_size > 0: ndbounds.append(self.context.time_size) + time_dim = Dimension( + name=("Time",), + size=self.context.time_size, + staggerloc=esmpy.StaggerLoc.CENTER, + coordinate_type="time", + lower=0, + upper=self.context.time_size, + ) + dims.append(time_dim) kwargs = {} if ndbounds: kwargs["ndbounds"] = tuple(ndbounds) - return esmpy.Field(dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, **kwargs) + esmpy_field = esmpy.Field(dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, **kwargs) + gwrap = MeshWrapper(value=dst_mesh, dims=DimensionCollection(value=(cells_dim,))) + return FieldWrapper(value=esmpy_field, gwrap=gwrap, dims=DimensionCollection(value=tuple(dims))) def _create_regridder_(self, src_fwrap: FieldWrapper) -> esmpy.RegridFromFile | esmpy.Regrid: CR_LOGGER.info("create regridder") @@ -145,7 +124,7 @@ def _create_regridder_(self, src_fwrap: FieldWrapper) -> esmpy.RegridFromFile | CR_LOGGER.info("create regridder from file") regridder = esmpy.RegridFromFile( srcfield=src_fwrap.value, - dstfield=self._dst_field.value, + dstfield=self.get_dst_fwrap().value, filename=str(self.context.weight_path), ) else: @@ -156,13 +135,12 @@ def _create_regridder_(self, src_fwrap: FieldWrapper) -> esmpy.RegridFromFile | InterpMethod.BILINEAR: esmpy.RegridMethod.BILINEAR, InterpMethod.NEAREST_STOD: esmpy.RegridMethod.NEAREST_STOD, } - # Default to NEAREST_STOD if not 
found in map (preserving original behavior) regrid_method = method_map[self.context.InterpMethod] CR_LOGGER.info(f"using {regrid_method} interp") regridder = esmpy.Regrid( srcfield=src_fwrap.value, - dstfield=self._dst_field, + dstfield=self.get_dst_fwrap().value, regrid_method=regrid_method, unmapped_action=esmpy.UnmappedAction.IGNORE, ignore_degenerate=True, @@ -175,7 +153,7 @@ def run(self) -> None: CR_LOGGER.info("apply regridding") CR_LOGGER.info("create output file") - self.create_output_file() + self.context.create_output_file() for src_field in self.context.src_fields: self._regrid_src_field(src_field) @@ -202,17 +180,15 @@ def _regrid_src_field(self, src_field: SrcField) -> None: regridder = self.get_regridder() src_fwrap = self.create_src_field_wrapper(field_name=src_field.name) - dst_field = self.get_dst_field() - dst_field.data.fill(0.0) - regridder(src_fwrap.value, dst_field) + dst_fwrap = self.get_dst_fwrap() + dst_fwrap.data.fill(0.0) + regridder(src_fwrap.value, dst_fwrap.value) - local_bounds = (dst_field.lower_bounds[0], dst_field.upper_bounds[0]) - reconciled_bounds = reconcile_bounds(local_bounds) - dims = src_field.create_dimension_collection(reconciled_bounds) + dims = src_field.create_dimension_collection(dst_fwrap.gwrap.dims.value[0].bounds) CR_LOGGER.debug(f"{dims=}") CR_LOGGER.info("writing field to netcdf") with open_nc(self.context.new_dst_path, mode="a") as ds: - transformed_data = self.context.transform_regridded_data(src_field, dst_field.data, ds, reconciled_bounds, dims) + transformed_data = self.context.transform_regridded_data(src_field, dst_fwrap, ds, dims) CR_LOGGER.info(f"creating variable {src_field.name=}") var = ds.createVariable( @@ -227,8 +203,8 @@ def _regrid_src_field(self, src_field: SrcField) -> None: CR_LOGGER.info(f"setting variable data {src_field.name=}") set_variable_data( var, - dims, - src_field.reshape_field_data(transformed_data), + dst_fwrap.dims, + transformed_data, collective=True, ) 
CR_LOGGER.info(f"finished writing field to netcdf {src_field.name=}") @@ -237,33 +213,11 @@ def _regrid_src_field(self, src_field: SrcField) -> None: self.context.post_regrid_processing(src_field, regridder, self, dims) - def create_output_file(self): - if self.context.rank == 0: - with open_nc(self.context.new_dst_path, mode="w", clobber=True, parallel=False) as dst_nc: - dst_nc.createDimension("nCells", self.context.num_cells) - if self.context.level_out_name is not None: - dst_nc.createDimension(self.context.level_out_name, self.context.level_out_size) - dst_nc.createDimension("StrLen", 64) - if self.context.time_size > 1: - dst_nc.createDimension("Time", self.context.time_size) - elif self.context.time_size == 1: - if "Time" not in dst_nc.dimensions: - dst_nc.createDimension("Time") - else: - CR_LOGGER.debug("Not creating a time dimension") - dst_nc.setncattr("created_at", str(datetime.now(timezone.utc))) - dst_nc.setncattr("src_path", str(self.context.src_path)) - dst_nc.setncattr("dst_path", str(self.context.dst_path)) - - with open_nc(self.context.dst_path, mode="r", parallel=False) as src_nc: - for varname in self.context.var_names_to_copy_to_output_file: - copy_nc_variable(src_nc, dst_nc, varname, copy_data=True) - def finalize(self) -> None: CR_LOGGER.info("finalizing") - self._regridder.destroy() - self._dst_field.value.destroy() - self._src_gwrap.value.destroy() + self.get_regridder().destroy() + self.get_dst_fwrap().value.destroy() + self.get_src_gwrap().value.destroy() # TODO: There could be an option to destroy the destination mesh when finalizing. However, # it is more efficient to leave it since the destination is not variable at this point. 
# self._dst_mesh.destroy() @@ -323,171 +277,20 @@ def get_src_gwrap(self) -> GridWrapper: raise ValueError return self._src_gwrap - def get_dst_field(self) -> FieldWrapper: - if self._dst_field is None: + def get_dst_fwrap(self) -> FieldWrapper: + if self._dst_fwrap is None: raise ValueError - return self._dst_field + return self._dst_fwrap def get_regridder(self) -> esmpy.Regrid: if self._regridder is None: raise ValueError return self._regridder - def init_destination_only(self) -> None: - """Loads the heavy MPAS destination mesh once for dynamic NGFS processing.""" - CR_LOGGER.info("Initializing MPAS Destination Mesh (Once)") - esmpy.Manager(debug=True) - - # if not self.context.input_mesh_path.exists() and self.context.rank == 0: - # CR_LOGGER.info("writing mpas scrip grid") - # mpas_desc = MpasCellMeshDescriptor( - # str(self.context.dst_path), self.context.mesh_name + ".init" - # ) - # mpas_desc.to_scrip(str(self.context.input_mesh_path)) - - CR_LOGGER.info("create destination mesh") - dst_mesh = esmpy.Mesh(filename=str(self.context.input_mesh_path), filetype=esmpy.FileFormat.UGRID, meshname="grid_topology") - - # Create destination field (using logic from your original initialize method) - ndbounds = None - if self.context.level_out_size > 1 and self.context.time_size > 1: - ndbounds = (self.context.level_out_size, self.context.time_size) - elif self.context.level_out_size > 1 and self.context.time_size == 1: - ndbounds = (self.context.level_out_size,) - elif self.context.level_out_size == 1 and self.context.time_size > 1: - ndbounds = (self.context.time_size,) - - self._dst_field = esmpy.Field(dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=ndbounds) - - def process_ngfs_file(self, file_path: Path, resolution: float = 0.01) -> None: - """Dynamically builds a mesh for NGFS points, regrids, and writes the output.""" - CR_LOGGER.info(f"Processing NGFS file: {file_path}") - - # 1. 
Read NGFS Coordinates AND Area - with open_nc(file_path, mode="r") as ds: - lats = ds.variables["lat"][:].filled(np.nan) - lons = ds.variables["lon"][:].filled(np.nan) - - # Read the NGFS area (in km2) - if "GRID_AREA" in ds.variables: - grid_area = ds.variables["GRID_AREA"][:].filled(np.nan) - else: - CR_LOGGER.warning("GRID_AREA not found! Defaulting to 1.0 km2.") - grid_area = np.ones_like(lats) - - # Filter out NaNs - valid = ~np.isnan(lats) & ~np.isnan(lons) & ~np.isnan(grid_area) - lats = lats[valid] - lons = lons[valid] - grid_area = grid_area[valid] - - # CRITICAL FIX: Convert -180/180 to 0/360 to match MPAS grid - lons = lons % 360.0 - - if len(lats) == 0: - CR_LOGGER.warning("No valid fires in file.") - return - - # 2. Build Sparse Source Mesh - src_mesh = create_ngfs_sparse_mesh(lats, lons, resolution) - if src_mesh is None: - return - - # 3. Create Output NetCDF File (Header Info) - if self.context.rank == 0: - with open_nc(self.context.new_dst_path, mode="w", clobber=True, parallel=False) as dst_nc: - dst_nc.createDimension("nCells", self.context.num_cells) - dst_nc.createDimension(self.context.level_out_name, self.context.level_out_size) - dst_nc.createDimension("StrLen", 64) - if self.context.time_size > 1: - dst_nc.createDimension("Time", self.context.time_size) - elif self.context.time_size == 1: - dst_nc.createDimension("Time") - dst_nc.setncattr("created_at", str(datetime.now(timezone.utc))) - dst_nc.setncattr("src_path", str(self.context.src_path)) - dst_nc.setncattr("dst_path", str(self.context.dst_path)) - - # Copy base MPAS variables - with open_nc(self.context.dst_path, mode="r", parallel=False) as src_nc: - for varname in ("latCell", "lonCell", "areaCell", "xland", "xtime"): - copy_nc_variable(src_nc, dst_nc, varname, copy_data=True) - - # 4. 
Process Each Variable - for src_field in self.context.src_fields: - CR_LOGGER.info(f"regridding NGFS {src_field.name=}") - - # Create Source Field dynamically - src_field = esmpy.Field(src_mesh, name=src_field.name, meshloc=esmpy.MeshLoc.ELEMENT) - - # Map MPAS expected name to NGFS actual name - if src_field.name == "PM25": - ngfs_var_name = "EMIS_PM25" - else: - ngfs_var_name = src_field.name - - # Load the raw data - with open_nc(file_path, mode="r") as ds: - if ngfs_var_name in ds.variables: - raw_data = ds.variables[ngfs_var_name][:].filled(0.0)[valid] - else: - CR_LOGGER.warning(f"Variable {ngfs_var_name} not found! Skipping.") - continue - - # --------------------------------------------------------- - # UNIT CONVERSIONS (Identical to RAVE logic) - # --------------------------------------------------------- - if src_field.name in ("PM25", "TPM"): - # Convert from kg/hr to ug/m2/s (1e3 handles the km2 to m2 and kg to ug ratio) - src_data = np.where(raw_data < 0.0, 0.0, raw_data * 1.0e3 / grid_area / 3600.0) - elif src_field.name in ("FRE", "FRP_MEAN"): - # For FRE, FRP: MW to W (1e6) cancels out with km2 to m2 (1e6) - src_data = np.where(raw_data < 0.0, 0.0, raw_data / grid_area) - else: - src_data = np.where(raw_data < 0.0, 0.0, raw_data) - - src_field.data[:] = src_data - - # Create Dynamic Regridder - regridder = esmpy.Regrid( - srcfield=src_field, - dstfield=self._dst_field, - regrid_method=esmpy.RegridMethod.CONSERVE, - unmapped_action=esmpy.UnmappedAction.IGNORE, - ) - - # Apply Regridding - self._dst_field.data.fill(0.0) - regridder(src_field, self._dst_field) - - # Write to Output NetCDF - local_bounds = (self._dst_field.lower_bounds[0], self._dst_field.upper_bounds[0]) - reconciled_bounds = reconcile_bounds(local_bounds) - dims = src_field.create_dimension_collection(reconciled_bounds) - - with open_nc(self.context.new_dst_path, mode="a") as ds: - var = ds.createVariable( - src_field.name, # Keep it as standard name in output! 
- src_field.dtype, - [dim.name[0] for dim in dims.value], - fill_value=src_field.fill_value, - ) - for k, v in src_field.attrs.items(): - setattr(var, k, v) - - # Multiply by areaCell for Power/Energy variables (back to total W in cell) - if src_field.name in ("FRP_MEAN", "FRE"): - area = np.asarray(ds.variables["areaCell"]) - area_subset = area[reconciled_bounds[0] : reconciled_bounds[1]] - set_variable_data(var, dims, src_field.reshape_field_data(self._dst_field.data * area_subset), collective=True) - else: - set_variable_data(var, dims, src_field.reshape_field_data(self._dst_field.data), collective=True) - - # Clean up memory - regridder.destroy() - src_field.destroy() - - # Clean up mesh - src_mesh.destroy() + def get_dst_mesh(self) -> esmpy.Mesh: + if self._dst_mesh is None: + raise ValueError + return self._dst_mesh def run_regridding(ctx: AbstractDatasetRegridContext) -> None: @@ -552,32 +355,6 @@ def main(ctx: ChemRegridContext) -> None: ) if ctx.dataset_name == "NGFS": - processor = ChemRegridProcessor(context=regrid_context) - - for date_to_process in regrid_context.dates_needed: - # Construct the filename (Adjust the prefix 'ngfs_' if your files are named differently) - # print("GAF debug: attempting to read: " + input_dir + "/NGFS_v0.31_" + date_to_process + "_0p01.nc") - ngfs_paths = glob.glob(str(ctx.input_dir) + "/NGFS_v0.31_0p01_" + date_to_process + "0000.nc") - - if not ngfs_paths: - print(f"ERROR: Missing NGFS file for {date_to_process}. 
Skipping.") - exit(1) - # TODO: perhaps add a helper similarly as I added for RAVE to search for the latest - # available file in case that the current datetime does not exist - continue - - ngfs_path = Path(ngfs_paths[0]) - new_dst_path = Path(str(ctx.output_dir) + "/" + ctx.mesh_name + "-NGFS-" + date_to_process + ".nc") - print(f"GAF reading NGFS file: {ngfs_path}") - - # Update context paths for the current hour - processor.context.src_path = ngfs_path - processor.context.new_dst_path = new_dst_path - - # Execute the dynamic regridding for this specific hour's fires - # Note that resolution is hard coded... - processor.process_ngfs_file(ngfs_path, resolution=0.01) - - CR_LOGGER.info("NGFS success") + run_ngfs_regridding(regrid_context) else: run_regridding(regrid_context) diff --git a/src/regrid_wrapper/app/chem_regrid/dataset/context/base.py b/src/regrid_wrapper/app/chem_regrid/dataset/context/base.py index 1aa9b86..b0110bb 100644 --- a/src/regrid_wrapper/app/chem_regrid/dataset/context/base.py +++ b/src/regrid_wrapper/app/chem_regrid/dataset/context/base.py @@ -1,6 +1,6 @@ import glob from abc import ABC, abstractmethod -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from enum import StrEnum, unique from functools import cached_property from pathlib import Path @@ -8,7 +8,7 @@ import esmpy import numpy as np -from pydantic import BaseModel +from pydantic import BaseModel, model_validator from regrid_wrapper.app.chem_regrid import CR_LOGGER from regrid_wrapper.app.chem_regrid.dataset.src_field import SrcField @@ -17,6 +17,7 @@ DimensionCollection, FieldWrapper, HasNcAttrsType, + copy_nc_variable, open_nc, ) @@ -90,9 +91,9 @@ class AbstractDatasetRegridContext(ABC, BaseModel): level_in_name: str | None # level_in_size: int level_out_name: str | None - level_out_size: int + level_out_size: int | None time_name: str | None - time_size: int + time_size: int | None cycle: str ebb_dcycle: int input_dir: Path @@ -224,13 
+225,12 @@ def _get_src_search_path(self, this_time: datetime) -> str: def transform_regridded_data( self, src_field: SrcField, - dst_field_data: np.ndarray, + dst_fwrap: FieldWrapper, ds: Any, - reconciled_bounds: tuple[int, int], dims: DimensionCollection, ) -> np.ndarray: """Hook for dataset-specific data transformations after regridding but before writing.""" - return dst_field_data + return dst_fwrap.data def post_regrid_processing( self, @@ -241,3 +241,35 @@ def post_regrid_processing( ) -> None: """Hook for dataset-specific operations after a field has been regridded and written.""" pass + + def create_output_file(self) -> None: + if self.rank == 0: + with open_nc(self.new_dst_path, mode="w", clobber=True, parallel=False) as dst_nc: + dst_nc.createDimension("nCells", self.num_cells) + if self.level_out_name is not None: + dst_nc.createDimension(self.level_out_name, self.level_out_size) + dst_nc.createDimension("StrLen", 64) + if self.time_size is not None and self.time_size > 1: + dst_nc.createDimension("Time", self.time_size) + elif self.time_size == 1: + if "Time" not in dst_nc.dimensions: + dst_nc.createDimension("Time") + else: + CR_LOGGER.debug("Not creating a time dimension") + dst_nc.setncattr("created_at", str(datetime.now(timezone.utc))) + dst_nc.setncattr("src_path", str(self.src_path)) + dst_nc.setncattr("dst_path", str(self.dst_path)) + + with open_nc(self.dst_path, mode="r", parallel=False) as src_nc: + for varname in self.var_names_to_copy_to_output_file: + copy_nc_variable(src_nc, dst_nc, varname, copy_data=True) + + @model_validator(mode="after") + def _validate_model(self) -> "AbstractDatasetRegridContext": + level_values = [self.level_out_name, self.level_out_size] + if any(level_values) and not all(level_values): + raise ValueError("level_out_name and level_out_size must be specified together") + time_values = [self.time_name, self.time_size] + if any(time_values) and not all(time_values): + raise ValueError("time_name and time_size must 
be specified together") + return self diff --git a/src/regrid_wrapper/app/chem_regrid/dataset/context/ngfs.py b/src/regrid_wrapper/app/chem_regrid/dataset/context/ngfs.py index 6ecc5b9..9ad2d4e 100644 --- a/src/regrid_wrapper/app/chem_regrid/dataset/context/ngfs.py +++ b/src/regrid_wrapper/app/chem_regrid/dataset/context/ngfs.py @@ -1,12 +1,228 @@ -from datetime import timedelta +import glob +from datetime import datetime, timedelta, timezone from functools import cached_property +from pathlib import Path from typing import Iterator +import esmpy +import numpy as np + from regrid_wrapper.app.chem_regrid import CR_LOGGER from regrid_wrapper.app.chem_regrid.dataset.context.base import ( AbstractDatasetRegridContext, RegridFilePair, ) +from regrid_wrapper.context.comm import COMM, reconcile_bounds +from regrid_wrapper.esmpy.field_wrapper import FieldWrapper, copy_nc_variable, open_nc, set_variable_data + + +def create_ngfs_sparse_mesh(lat_1d: np.ndarray, lon_1d: np.ndarray, resolution: float = 0.01) -> esmpy.Mesh: + """ + Creates an esmpy.Mesh dynamically from 1-D point source data. + Calculates the 4 corners of a square cell of size `resolution` + around each center point in memory. + This is the best approach since NGFS data are point-source (1-D), + but we rarely have more than 1000 fires in the domain, so we + can afford to keep this in memory instead of creating a file. 
+ """ + + num_cells = len(lat_1d) + if num_cells == 0: + raise ValueError("must have at least one cell in the mesh") + + num_nodes = num_cells * 4 + d = resolution / 2.0 + + node_lons = np.column_stack([lon_1d - d, lon_1d + d, lon_1d + d, lon_1d - d]).flatten() + + node_lats = np.column_stack([lat_1d - d, lat_1d - d, lat_1d + d, lat_1d + d]).flatten() + + node_coords = np.empty(num_nodes * 2, dtype=np.float64) + node_coords[0::2] = node_lons + node_coords[1::2] = node_lats + + node_ids = np.arange(1, num_nodes + 1, dtype=np.int32) + node_owners = np.full(num_nodes, COMM.rank, dtype=np.int32) + + element_ids = np.arange(1, num_cells + 1, dtype=np.int32) + element_types = np.full(num_cells, esmpy.MeshElemType.QUAD, dtype=np.int32) + + # CRITICAL FIX: esmpy expects 0-based indexing for connectivity! + element_conn = np.arange(0, num_nodes, dtype=np.int32) + + # Explicitly set spherical coordinates + mesh = esmpy.Mesh(parametric_dim=2, spatial_dim=2, coord_sys=esmpy.CoordSys.SPH_DEG) + + mesh.add_nodes(node_count=num_nodes, node_ids=node_ids, node_coords=node_coords, node_owners=node_owners) + + mesh.add_elements(element_count=num_cells, element_ids=element_ids, element_types=element_types, element_conn=element_conn) + + return mesh + + +def process_ngfs_file( + ctx: AbstractDatasetRegridContext, file_path: Path, dst_fwrap: FieldWrapper, resolution: float = 0.01 +) -> None: + """Dynamically builds a mesh for NGFS points, regrids, and writes the output.""" + CR_LOGGER.info(f"Processing NGFS file: {file_path}") + + # 1. Read NGFS Coordinates AND Area + with open_nc(file_path, mode="r") as ds: + lats = ds.variables["lat"][:].filled(np.nan) + lons = ds.variables["lon"][:].filled(np.nan) + + # Read the NGFS area (in km2) + if "GRID_AREA" in ds.variables: + grid_area = ds.variables["GRID_AREA"][:].filled(np.nan) + else: + CR_LOGGER.warning("GRID_AREA not found! 
Defaulting to 1.0 km2.") + grid_area = np.ones_like(lats) + + # Filter out NaNs + valid = ~np.isnan(lats) & ~np.isnan(lons) & ~np.isnan(grid_area) + lats = lats[valid] + lons = lons[valid] + grid_area = grid_area[valid] + + # CRITICAL FIX: Convert -180/180 to 0/360 to match MPAS grid + lons = lons % 360.0 + + if len(lats) == 0: + CR_LOGGER.warning("No valid fires in file.") + return + + # 2. Build Sparse Source Mesh + src_mesh = create_ngfs_sparse_mesh(lats, lons, resolution) + + # 3. Create Output NetCDF File (Header Info) + if ctx.rank == 0: + with open_nc(ctx.new_dst_path, mode="w", clobber=True, parallel=False) as dst_nc: + dst_nc.createDimension("nCells", ctx.num_cells) + if ctx.level_out_name is None: + raise ValueError("level_out_name must be set for NGFS regridding") + dst_nc.createDimension(ctx.level_out_name, ctx.level_out_size) + dst_nc.createDimension("StrLen", 64) + if ctx.time_size is not None and ctx.time_size > 1: + dst_nc.createDimension("Time", ctx.time_size) + elif ctx.time_size == 1: + dst_nc.createDimension("Time") + dst_nc.setncattr("created_at", str(datetime.now(timezone.utc))) + dst_nc.setncattr("src_path", str(ctx.src_path)) + dst_nc.setncattr("dst_path", str(ctx.dst_path)) + + # Copy base MPAS variables + with open_nc(ctx.dst_path, mode="r", parallel=False) as src_nc: + for varname in ("latCell", "lonCell", "areaCell", "xland", "xtime"): + copy_nc_variable(src_nc, dst_nc, varname, copy_data=True) + + # 4. 
Process Each Variable + for src_field in ctx.src_fields: + CR_LOGGER.info(f"regridding NGFS {src_field.name=}") + + # Create Source Field dynamically + src_field = esmpy.Field(src_mesh, name=src_field.name, meshloc=esmpy.MeshLoc.ELEMENT) + + # Map MPAS expected name to NGFS actual name + if src_field.name == "PM25": + ngfs_var_name = "EMIS_PM25" + else: + ngfs_var_name = src_field.name + + # Load the raw data + with open_nc(file_path, mode="r") as ds: + if ngfs_var_name in ds.variables: + raw_data = ds.variables[ngfs_var_name][:].filled(0.0)[valid] + else: + CR_LOGGER.warning(f"Variable {ngfs_var_name} not found! Skipping.") + continue + + # --------------------------------------------------------- + # UNIT CONVERSIONS (Identical to RAVE logic) + # --------------------------------------------------------- + if src_field.name in ("PM25", "TPM"): + # Convert from kg/hr to ug/m2/s (1e3 handles the km2 to m2 and kg to ug ratio) + src_data = np.where(raw_data < 0.0, 0.0, raw_data * 1.0e3 / grid_area / 3600.0) + elif src_field.name in ("FRE", "FRP_MEAN"): + # For FRE, FRP: MW to W (1e6) cancels out with km2 to m2 (1e6) + src_data = np.where(raw_data < 0.0, 0.0, raw_data / grid_area) + else: + src_data = np.where(raw_data < 0.0, 0.0, raw_data) + + src_field.data[:] = src_data + + # Create Dynamic Regridder + regridder = esmpy.Regrid( + srcfield=src_field, + dstfield=dst_fwrap.value, + regrid_method=esmpy.RegridMethod.CONSERVE, + unmapped_action=esmpy.UnmappedAction.IGNORE, + ) + + # Apply Regridding + dst_fwrap.data.fill(0.0) + regridder(src_field, dst_fwrap) + + # Write to Output NetCDF + local_bounds = (dst_fwrap.value.lower_bounds[0], dst_fwrap.value.upper_bounds[0]) + reconciled_bounds = reconcile_bounds(local_bounds) + dims = src_field.create_dimension_collection(reconciled_bounds) + + with open_nc(ctx.new_dst_path, mode="a") as ds: + var = ds.createVariable( + src_field.name, # Keep it as standard name in output! 
+ src_field.dtype, + [dim.name[0] for dim in dims.value], + fill_value=src_field.fill_value, + ) + for k, v in src_field.attrs.items(): + setattr(var, k, v) + + # Multiply by areaCell for Power/Energy variables (back to total W in cell) + if src_field.name in ("FRP_MEAN", "FRE"): + area = np.asarray(ds.variables["areaCell"]) + area_subset = area[reconciled_bounds[0] : reconciled_bounds[1]] + set_variable_data(var, dims, src_field.reshape_field_data(dst_fwrap.data * area_subset), collective=True) + else: + set_variable_data(var, dims, src_field.reshape_field_data(dst_fwrap.data), collective=True) + + # Clean up memory + regridder.destroy() + src_field.destroy() + + # Clean up mesh + src_mesh.destroy() + + +def run_ngfs_regridding(regrid_context: AbstractDatasetRegridContext) -> None: + from regrid_wrapper.app.chem_regrid.chem_regrid_impl import ChemRegridProcessor + + processor = ChemRegridProcessor(context=regrid_context) + + for date_to_process in regrid_context.dates_needed: + # Construct the filename (Adjust the prefix 'ngfs_' if your files are named differently) + # print("GAF debug: attempting to read: " + input_dir + "/NGFS_v0.31_" + date_to_process + "_0p01.nc") + ngfs_paths = glob.glob(str(regrid_context.input_dir) + "/NGFS_v0.31_0p01_" + date_to_process + "0000.nc") + + if not ngfs_paths: + print(f"ERROR: Missing NGFS file for {date_to_process}. 
Skipping.") + exit(1) + # TODO: perhaps add a helper similarly as I added for RAVE to search for the latest + # available file in case that the current datetime does not exist + continue + + ngfs_path = Path(ngfs_paths[0]) + new_dst_path = Path(str(regrid_context.output_dir) + "/" + regrid_context.mesh_name + "-NGFS-" + date_to_process + ".nc") + print(f"GAF reading NGFS file: {ngfs_path}") + + # Update context paths for the current hour + processor.context.src_path = ngfs_path + processor.context.new_dst_path = new_dst_path + + # Execute the dynamic regridding for this specific hour's fires + # Note that resolution is hard coded... + process_ngfs_file(regrid_context, ngfs_path, processor.get_dst_fwrap(), resolution=0.01) + + CR_LOGGER.info("NGFS success") class NGFS_DatasetRegridContext(AbstractDatasetRegridContext): diff --git a/src/regrid_wrapper/app/chem_regrid/dataset/context/pecm.py b/src/regrid_wrapper/app/chem_regrid/dataset/context/pecm.py index 83c6796..b753207 100644 --- a/src/regrid_wrapper/app/chem_regrid/dataset/context/pecm.py +++ b/src/regrid_wrapper/app/chem_regrid/dataset/context/pecm.py @@ -38,16 +38,14 @@ def post_regrid_processing( CR_LOGGER.info("renaming and combining tree fields") src_fwrap_enl = processor.create_src_field_wrapper(field_name="ENL_POLL") - dst_field_enl = processor.get_dst_field() + dst_field_enl = processor.get_dst_fwrap() dst_field_enl.data.fill(0.0) - regridder(src_fwrap_enl.value, dst_field_enl) - data_enl = src_field.reshape_field_data(dst_field_enl.data).copy() + regridder(src_fwrap_enl.value, dst_field_enl.value) src_fwrap_dbl = processor.create_src_field_wrapper(field_name="DBL_POLL") - dst_field_dbl = processor.get_dst_field() + dst_field_dbl = processor.get_dst_fwrap() dst_field_dbl.data.fill(0.0) - regridder(src_fwrap_dbl.value, dst_field_dbl) - data_dbl = src_field.reshape_field_data(dst_field_dbl.data) + regridder(src_fwrap_dbl.value, dst_field_dbl.value) var = ds.createVariable( "TREE_POLL", @@ -59,8 +57,8 @@ 
def post_regrid_processing( setattr(var, k, v) set_variable_data( var, - dims, - data_enl + data_dbl, + dst_field_enl.dims, + dst_field_enl.data + dst_field_dbl.data, collective=True, ) src_fwrap_enl.value.destroy() diff --git a/src/regrid_wrapper/app/chem_regrid/dataset/context/rave.py b/src/regrid_wrapper/app/chem_regrid/dataset/context/rave.py index 5bcd830..a90ca23 100644 --- a/src/regrid_wrapper/app/chem_regrid/dataset/context/rave.py +++ b/src/regrid_wrapper/app/chem_regrid/dataset/context/rave.py @@ -17,6 +17,7 @@ DimensionCollection, FieldWrapper, NcToField, + load_variable_data, open_nc, set_variable_data, ) @@ -114,17 +115,18 @@ def dates_needed(self) -> list[str]: def transform_regridded_data( self, src_field: SrcField, - dst_field_data: np.ndarray, + dst_field: FieldWrapper, ds: Any, - reconciled_bounds: tuple[int, int], dims: DimensionCollection, ) -> np.ndarray: if src_field.name in ("FRP_MEAN", "FRE"): # Multiply FRE/FRP by output area so it is back to W or J*s - area = np.asarray(ds.variables["areaCell"]) - area_subset = area[reconciled_bounds[0] : reconciled_bounds[1]].reshape(dims.shape_local) - return dst_field_data * area_subset - return dst_field_data + area_dims = DimensionCollection(value=(dims.get("nCells"),)) + area_subset = load_variable_data(ds.variables["areaCell"], area_dims) + area_subset = area_subset.reshape(dst_field.dims.shape_local) + dst_field_data = dst_field.data * area_subset + return dst_field_data + return dst_field.data def post_regrid_processing( self, @@ -139,15 +141,13 @@ def post_regrid_processing( src_fwrap_ttl = processor.create_src_field_wrapper(field_name="TPM") src_fwrap_p25 = processor.create_src_field_wrapper(field_name="PM25") - dst_field_ttl = processor.get_dst_field() + dst_field_ttl = processor.get_dst_fwrap() dst_field_ttl.data.fill(0.0) - regridder(src_fwrap_ttl.value, dst_field_ttl) - data1 = src_field.reshape_field_data(dst_field_ttl.data).copy() + regridder(src_fwrap_ttl.value, dst_field_ttl.value) - 
dst_field_p25 = processor.get_dst_field() + dst_field_p25 = processor.get_dst_fwrap() dst_field_p25.data.fill(0.0) - regridder(src_fwrap_p25.value, dst_field_p25) - data2 = src_field.reshape_field_data(dst_field_p25.data) + regridder(src_fwrap_p25.value, dst_field_p25.value) # use the same src_field metadata for PM10 var = ds.createVariable( @@ -159,10 +159,10 @@ def post_regrid_processing( for k, v in src_field.attrs.items(): setattr(var, k, v) - data3 = data1 - data2 + data3 = dst_field_ttl.data - dst_field_p25.data set_variable_data( var, - dims, + dst_field_ttl.dims, data3, collective=True, ) diff --git a/src/regrid_wrapper/app/chem_regrid/dataset/src_field.py b/src/regrid_wrapper/app/chem_regrid/dataset/src_field.py index 2fd520a..897c90c 100644 --- a/src/regrid_wrapper/app/chem_regrid/dataset/src_field.py +++ b/src/regrid_wrapper/app/chem_regrid/dataset/src_field.py @@ -2,7 +2,6 @@ from typing import Any import esmpy -import numpy as np from pydantic import BaseModel from regrid_wrapper.esmpy.field_wrapper import Dimension, DimensionCollection @@ -59,27 +58,16 @@ def create_ncells_dimension(self, bounds: tuple[int, int]) -> Dimension: def create_dimension_collection(self, ncells_bounds: tuple[int, int]) -> DimensionCollection: """Creates a collection of dimensions based on the field's shape.""" + + ncells_dim = self.create_ncells_dimension(ncells_bounds) dims = [] if self.level_out_size == 0: if self.time_size > 0: dims.append(self.time_dimension) - dims.append(self.create_ncells_dimension(ncells_bounds)) + dims.append(ncells_dim) else: - dims.append(self.create_ncells_dimension(ncells_bounds)) - dims.append(self.nklevel_dimension) if self.time_size > 0: dims.append(self.time_dimension) + dims.append(ncells_dim) + dims.append(self.nklevel_dimension) return DimensionCollection(value=tuple(dims)) - - def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - """Reshapes the field data to match the expected output dimensions.""" - if self.level_out_size 
== 0: - if self.time_size == 0: - return target.reshape(-1) - else: - return target.reshape(self.time_size, -1) - else: - if self.time_size == 0: - return target.reshape(-1, self.level_out_size) - else: - return target.reshape(-1, self.level_out_size, self.time_size) From 9c6593423cc23e1dcaa4b1e802f8d2d6a4d74bbe Mon Sep 17 00:00:00 2001 From: Ben Koziol Date: Fri, 24 Apr 2026 10:46:05 -0600 Subject: [PATCH 4/4] enable ci on PR --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 768e4c8..51e0720 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,8 +3,8 @@ name: CI on: push: # branches: [ main ] -# pull_request: -# branches: [ main ] + pull_request: +# branches: [ mpas_aero_v0 ] jobs: lint-and-test: