diff --git a/README.md b/README.md index 69fa758..9f762c1 100644 --- a/README.md +++ b/README.md @@ -4,45 +4,42 @@ ## What is this? -`mxalign` is an `xarray`-based package designed for the alignment and verification of meteorological datasets. It standardizes operations across datasets by attaching properties along three main axes: -- **Space:** Grid or point-based data -- **Time:** Forecasts, observations, or climatology -- **Uncertainty:** Deterministic, ensemble, or quantile forecasts +`mxalign` is an `xarray`-based package for aligning meteorological datasets. It operates on datasets that carry **traits** — metadata attributes that describe the nature of a dataset along three axes: -Currently, `mxalign` also acts as a full execution engine. It can load datasets (e.g., Anemoi inference outputs, observation datasets), apply transformations, align datasets in both space and time to match a reference, safely broadcast NaNs, and execute verification metrics on scaled Dask clusters (Local or Slurm). +`mxalign` is an `xarray`-based package for aligning meteorological datasets. It operates on datasets that carry **traits** — metadata attributes that describe the nature of a dataset along three axes: +- **Space:** `grid` or `point` +- **Time:** `forecast`, `observation`, or `climatology` +- **Uncertainty:** `deterministic`, `ensemble`, or `quantile` -> ⚠️ **Roadmap & Future Architecture Changes (planned for v0.2.0):** -> Currently, `mxalign` handles both alignment and the execution of the verification tooling pipeline, including loading and validation. In the upcoming `v0.2.0` release, this architecture will be refactored: -> - **Loading** will be split out into [`mlwp-data-loaders`](https://github.com/mlwp-tools/mlwp-data-loaders). -> - **Validation** of loaded `xr.Dataset`s will be moved to [`mlwp-data-specs`](https://github.com/mlwp-tools/mlwp-data-specs) (which will contain the requirements for each of the dataset traits and the validation logic). -> - **Execution** of the full verification pipeline (loading, transformations, alignment, and verification) from configuration files may be moved to a separate package in future releases. -> - **Tests** will be added to `mxalign` (building on test datasets already integrated into `mlwp-data-loaders`) that ensure that all alignment operations work correctly (Testing notebook execution inside `mxalign` is explicitly excluded from the current roadmap). +These traits are defined and validated by [`mlwp-data-specs`](https://github.com/mlwp-tools/mlwp-data-specs) and attached to datasets by [`mlwp-data-loaders`](https://github.com/mlwp-tools/mlwp-data-loaders). `mxalign` reads them to infer how datasets should be aligned, without needing to know how they were loaded. + +`mxalign` currently supports alignment in **space** and **time**. Alignment along the **uncertainty** axis (e.g. ensemble to deterministic) is planned for a future release. ## Python API -`mxalign` provides building blocks for manual alignment, transformations, and interpolations of `xarray` datasets. This is ideal for interactive use in Jupyter notebooks or custom Python scripts. +`mxalign` provides building blocks for spatial and temporal alignment of `xarray` datasets. This is ideal for interactive use in Jupyter notebooks or custom Python scripts. ```python -import xarray as xr -from mxalign import load, align_space, align_time, transform +import mlwp_data_loaders as dl +import mxalign as mx -# Load datasets (using registered loaders) -ds_obs = load(name="observations_loader", files=["obs.nc"]) -ds_fcst = load(name="anemoi_inference", files=["forecast.nc"]) +# Load datasets — traits are attached by the loader +ds_obs = dl.load("observations_loader", files=["obs.nc"]) +ds_fcst = dl.load("anemoi_inference", files=["forecast.nc"]) # Align the forecast spatially to match the observation reference -ds_fcst_aligned_space = align_space(ds_fcst, reference=ds_obs, method="interpolation") +ds_fcst_aligned = mx.align_space(ds_fcst, reference=ds_obs, method="interpolation") # Align datasets temporally -datasets = {"obs": ds_obs, "fcst": ds_fcst_aligned_space} -aligned_datasets = align_time(datasets, method="intersection") +datasets = {"obs": ds_obs, "fcst": ds_fcst_aligned} +aligned_datasets = mx.align_time(datasets, method="intersection") ``` For a more comprehensive interactive example, check out the [introductory notebook](./examples/introduction.ipynb). ## Executing via a Configuration -For full verification pipeline execution, `mxalign` uses a YAML configuration file. This allows you to declaratively define how datasets are loaded, transformed, aligned, and verified. +`mxalign` can drive a full verification pipeline from a YAML configuration file, orchestrating dataset loading (via `mlwp-data-loaders`), transformations, alignment, and verification. ### Configuration Contents diff --git a/pyproject.toml b/pyproject.toml index 1016cdd..6e76b7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,11 @@ jobqueue = [ requires = ["hatchling"] build-backend = "hatchling.build" +[tool.pytest.ini_options] +testpaths = ["tests"] + [dependency-groups] dev = [ "ipykernel>=7.2.0", + "pytest>=8.0", ] diff --git a/src/mxalign/__init__.py b/src/mxalign/__init__.py index 5a76c8d..61576c3 100644 --- a/src/mxalign/__init__.py +++ b/src/mxalign/__init__.py @@ -1,6 +1,3 @@ -from .properties.properties import Properties, Time, Space, Uncertainty -from .loaders.loader import load -from .loaders.registry import available_loaders, register_loader from .transformations.transform import transform from .transformations.registry import available_transformations, register_transformation from .interpolations.interpolate import interpolate @@ -9,18 +6,10 @@ from .align.space import align_space from . import accessors -from . import loaders from . import transformations from . import interpolations __all__ = [ - "Properties", - "Time", - "Space", - "Uncertainty", - "load", - "available_loaders", - "register_loader", "transform", "available_transformations", "register_transformation", diff --git a/src/mxalign/accessors/__init__.py b/src/mxalign/accessors/__init__.py index a0833a3..30d84d3 100644 --- a/src/mxalign/accessors/__init__.py +++ b/src/mxalign/accessors/__init__.py @@ -1,7 +1,3 @@ -from . import space -from . import time +from . import mx -__all__ = [ - "space", - "time", -] +__all__ = ["mx"] diff --git a/src/mxalign/accessors/mx.py b/src/mxalign/accessors/mx.py new file mode 100644 index 0000000..02692fa --- /dev/null +++ b/src/mxalign/accessors/mx.py @@ -0,0 +1,231 @@ +import numpy as np +import xarray as xr +import cartopy.crs as ccrs + +from mlwp_data_specs.api import SPACE_TRAIT_ATTR, TIME_TRAIT_ATTR +from mlwp_data_specs.specs.traits.spatial_coordinate import Space +from mlwp_data_specs.specs.traits.time_coordinate import Time + +from ..utils.projections import create_cartopy_crs, BUILTIN +from . import time as _time +from . import space as _space + + +@xr.register_dataset_accessor("mx") +class MxAccessor: + def __init__(self, ds): + self._space = Space(ds.attrs[SPACE_TRAIT_ATTR]) + self._time = Time(ds.attrs[TIME_TRAIT_ATTR]) + self._ds = ds + + # --- Space predicates --- + + def is_grid(self): + return self._space == Space.GRID + + def is_point(self): + return self._space == Space.POINT + + # --- Time predicates --- + + def is_forecast(self): + return self._time == Time.FORECAST + + def is_observation(self): + return self._time == Time.OBSERVATION + + # --- Space operations --- + + def add_crs(self, crs): + if self.is_point(): + raise ValueError("Cannot add CRS to a point dataset") + if isinstance(crs, str): + try: + crs = BUILTIN[crs.lower()] + except KeyError: + raise ValueError(f"crs: {crs} not found in supported projections") + if isinstance(crs, dict): + crs = create_cartopy_crs( + projection=crs["projection"], + kws_projection=crs["kws_projection"], + kws_globe=crs.get("kws_globe", None), + ) + return self._ds.assign_attrs({"crs": crs}) + + def add_grid_mapping(self, grid_mapping: str | dict): + if self.is_point(): + raise ValueError("Cannot add grid mapping to a point dataset") + if isinstance(grid_mapping, str): + try: + grid_mapping = BUILTIN[grid_mapping.lower()]["kws_grid"] + except KeyError: + raise ValueError( + f"grid mapping: {grid_mapping} not found in supported mappings" + ) + return self._ds.assign_attrs({"grid_mapping": grid_mapping}) + + def add_xy(self, crs=None): + if crs is not None: + self._ds = self.add_crs(crs) + + crs = self._ds.attrs.get("crs", None) + if crs is None: + raise ValueError("No CRS provided and no CRS found in dataset attributes") + + if {"longitude", "latitude"}.issubset(self._ds.dims): + raise ValueError( + "Cannot add x/y coordinates to a GRID dataset that has longitude/latitude dimensions" + ) + elif {"xc", "yc"}.issubset(self._ds.coords): + return self._ds + else: + xyz = crs.transform_points( + x=self._ds["longitude"].values, + y=self._ds["latitude"].values, + src_crs=ccrs.PlateCarree(), + ) + + if self.is_grid(): + return self._ds.assign_coords( + xc=("grid_index", xyz[:, 0]), yc=("grid_index", xyz[:, 1]) + ) + elif self.is_point(): + return self._ds.assign_coords( + xc=("point_index", xyz[:, 0]), yc=("point_index", xyz[:, 1]) + ) + else: + raise ValueError("Dataset does not have expected spatial properties") + + def is_stacked(self): + if {"xc", "yc"}.issubset(self._ds.dims) or {"longitude", "latitude"}.issubset( + self._ds.dims + ): + return False + elif "grid_index" in self._ds.dims: + return True + else: + raise ValueError("Dataset does not have expected dimensions for GRID") + + def stack(self): + if self.is_point(): + raise ValueError("POINT datasets cannot be stacked") + if self.is_stacked(): + return self._ds + else: + if {"xc", "yc"}.issubset(self._ds.dims): + dims_to_stack = ["yc", "xc"] + elif {"lat", "lon"}.issubset(self._ds.dims): + dims_to_stack = ["lat", "lon"] + else: + raise ValueError("Could not find correct dimensions to stack") + return self._ds.stack({"grid_index": dims_to_stack}).reset_index("grid_index") + + def unstack(self, crs=None, **kwargs): + if self.is_point(): + raise ValueError("POINT datasets cannot be unstacked") + if not self.is_stacked(): + return self._ds + else: + if crs: + self._ds = self.add_crs(crs) + kws_mindex = dict.fromkeys(["nx", "ny", "lon_ll", "lat_ll", "dx", "dy"]) + for key in kws_mindex.keys(): + value = kwargs.get(key, None) + if value is None: + try: + value = self._ds.attrs["grid_mapping"][key] + except KeyError: + raise KeyError( + f"Did not find a value for {key} in dataset attributes, please provide it as an argument" + ) + kws_mindex[key] = value + + mindex = self._create_multiindex(**kws_mindex) + mcoords = xr.Coordinates.from_pandas_multiindex(mindex, "grid_index") + ds_mindex = self._ds.assign_coords(mcoords) + ds_mindex.attrs["grid_mapping"] = kws_mindex + return ds_mindex.unstack() + + def _create_multiindex(self, nx, ny, lon_ll, lat_ll, dx, dy, **kwargs): + from pandas import MultiIndex + + if self._ds.sizes["grid_index"] != nx * ny: + raise ValueError( + f"Size of grid_index ({self._ds.sizes['grid_index']}) does not match nx*ny ({nx * ny})" + ) + + crs = self._ds.attrs["crs"] + x_ll, y_ll = crs.transform_point(x=lon_ll, y=lat_ll, src_crs=ccrs.PlateCarree()) + + xc = x_ll + np.arange(nx) * dx + yc = y_ll + np.arange(ny) * dy + + return MultiIndex.from_product([yc, xc], names=["yc", "xc"]) + + # --- Time operations --- + + def add_valid_time(self): + if self.is_forecast(): + return _time._add_valid_time(self._ds) + return self._ds + + # --- Alignment --- + + def align_time_with(self, ds2, lead_time="shortest"): + """Align this dataset's time axis to match ds2. + + Always uses "reference" semantics: self is reindexed to ds2's time + coordinates, with NaN-fill for times not present in self. ds2 is never + modified. For symmetric inner-join behaviour across multiple datasets use + the module-level ``align_time`` function instead. + + Parameters + ---------- + ds2 : xr.Dataset + The reference dataset to align to. + lead_time : str or timedelta or list + For Forecast→Observation: "shortest" | "longest" | specific value or list. + For Forecast→Forecast: "reference" | "intersection" | "union" (default "reference"). + Ignored for observation→* cases. + """ + if self.is_forecast() and ds2.mx.is_observation(): + return _time.align_forecast_to_observation(self._ds, ds2, lead_time=lead_time) + elif self.is_observation() and ds2.mx.is_forecast(): + return _time.align_observation_to_forecast(self._ds, ds2) + elif self.is_observation() and ds2.mx.is_observation(): + return _time.align_observation_to_observation(self._ds, ds2) + elif self.is_forecast() and ds2.mx.is_forecast(): + ff_lead_time = ( + lead_time if lead_time in ("reference", "intersection", "union") else "reference" + ) + return _time.align_forecast_to_forecast(self._ds, ds2, lead_time=ff_lead_time) + else: + raise ValueError("Cannot align datasets with unknown time properties") + + def align_space_with(self, ds2, **kwargs): + """Align this dataset's spatial grid to match ds2. + + Always uses "reference" semantics: self is interpolated or reindexed to + ds2's spatial coordinates. ds2 is never modified. + + Parameters + ---------- + ds2 : xr.Dataset + The reference dataset to align to. + method : str + Interpolation method for grid→point alignment. One of "xarray" or + "delaunay" (default "xarray"). Ignored for grid→grid. + **kwargs + Passed through to the interpolator. + """ + if self.is_grid(): + if ds2.mx.is_grid(): + return _space.align_grid_grid(self._ds, ds2, **kwargs) + elif ds2.mx.is_point(): + return _space.align_grid_point(self._ds, ds2, **kwargs) + elif self.is_point(): + if ds2.mx.is_point(): + raise NotImplementedError("Point-to-point alignment not implemented") + elif ds2.mx.is_grid(): + raise NotImplementedError("Point-to-grid alignment not implemented") + raise ValueError("Datasets do not have compatible spatial properties") diff --git a/src/mxalign/accessors/space.py b/src/mxalign/accessors/space.py index 4817f8b..15aed4d 100644 --- a/src/mxalign/accessors/space.py +++ b/src/mxalign/accessors/space.py @@ -1,205 +1,30 @@ -import xarray as xr -import cartopy.crs as ccrs import numpy as np -from ..properties.properties import Space -from ..properties.utils import properties_from_attrs - -from ..utils.projections import create_cartopy_crs, BUILTIN - -# Tolerance in degrees that the coordinates of two grids can differ while still being interpreted as the same grid. -# 0.0001 degrees ~ 10m at 45 deg latitude +# Tolerance in degrees that the coordinates of two grids can differ while still +# being interpreted as the same grid. 0.0001 degrees ~ 10m at 45 deg latitude. COORD_TOLERANCE = 0.0001 -@xr.register_dataset_accessor("space") -class SpaceAccessor: - def __init__(self, ds): - self._space = properties_from_attrs(ds).space - self._ds = ds - - def is_grid(self): - return self._space == Space.GRID - - def is_point(self): - return self._space == Space.POINT - - def add_crs(self, crs): - if self.is_point(): - raise ValueError("Cannot add CRS to a point dataset") - if isinstance(crs, str): - try: - crs = BUILTIN[crs.lower()] - except KeyError: - raise ValueError("crs: {crs} not found in supported projections") - if isinstance(crs, dict): - crs = create_cartopy_crs( - projection=crs["projection"], - kws_projection=crs["kws_projection"], - kws_globe=crs.get("kws_globe", None), - ) - return self._ds.assign_attrs({"crs": crs}) - - def add_grid_mapping(self, grid_mapping: str | dict): - if self.is_point(): - raise ValueError("Cannot add grid mapping to a point dataset") - if isinstance(grid_mapping, str): - try: - grid_mapping = BUILTIN[grid_mapping.lower()]["kws_grid"] - except KeyError: - raise ValueError( - "grid mapping: {grid_mapping} not found in supported mappings" - ) - return self._ds.assign_attrs({"grid_mapping": grid_mapping}) - - def add_xy(self, crs=None): - if crs is not None: - self._ds = self.add_crs(crs) - - crs = self._ds.attrs.get("crs", None) - - if crs is None: - raise ValueError("No CRS provided and no CRS found in dataset attributes") - - if {"longitude", "latitude"}.issubset(self._ds.dims): - raise ValueError( - "Cannot add x/y coordinates to a GRID dataset that has longitude/latitude dimensions" - ) - elif {"xc", "yc"}.issubset(self._ds.coords): - return self._ds - else: - xyz = crs.transform_points( - x=self._ds["longitude"].values, - y=self._ds["latitude"].values, - src_crs=ccrs.PlateCarree(), - ) - - if self.is_grid(): - ds_out = self._ds.assign_coords( - xc=("grid_index", xyz[:, 0]), yc=("grid_index", xyz[:, 1]) - ) - elif self.is_point(): - ds_out = self._ds.assign_coords( - xc=("point_index", xyz[:, 0]), yc=("point_index", xyz[:, 1]) - ) - else: - raise ValueError("Dataset does not have expected spatial properties") - - return ds_out - - def is_stacked(self): - if {"xc", "yc"}.issubset(self._ds.dims) or {"longitude", "latitude"}.issubset( - self._ds.dims - ): - return False - elif "grid_index" in self._ds.dims: - return True - else: - raise ValueError("Dataset does not have expected dimensions for GRID") - - def stack(self): - if self.is_point(): - raise ValueError("POINT datasets cannot be stacked") - if self.is_stacked(): - return self._ds - else: - if {"xc", "yc"}.issubset(self._ds.dims): - dims_to_stack = ["yc", "xc"] - elif {"lat", "lon"}.issubset(self._ds.dims): - dims_to_stack = ["lat", "lon"] - else: - raise ValueError("Could not find correct dimensions to stack") - return self._ds.stack({"grid_index": dims_to_stack}).reset_index("grid_index") - - def unstack(self, crs=None, **kwargs): - if self.is_point(): - raise ValueError("POINT datasets cannot be unstacked") - if not self.is_stacked(): - return self._ds - else: - if crs: - self.add_crs(crs) - kws_mindex = dict.fromkeys(["nx", "ny", "lon_ll", "lat_ll", "dx", "dy"]) - for key in kws_mindex.keys(): - value = kwargs.get(key, None) - if value is None: - try: - value = self._ds.attrs["grid_mapping"][key] - except KeyError: - raise KeyError( - f"Did not find a value for {key} in the dataset attributes, please provide it as an argument" - ) - kws_mindex[key] = value - - mindex = self._create_multiindex(**kws_mindex) - mcoords = xr.Coordinates.from_pandas_multiindex(mindex, "grid_index") - ds_mindex = self._ds.assign_coords(mcoords) - ds_mindex.attrs["grid_mapping"] = kws_mindex - return ds_mindex.unstack() - - def _create_multiindex(self, nx, ny, lon_ll, lat_ll, dx, dy, **kwargs): - from pandas import MultiIndex - - if self._ds.sizes["grid_index"] != nx * ny: - raise ValueError( - f"Size of grid_index ({self._ds.sizes['grid_index']}) does not match product of nx and ny ({nx * ny})" - ) - - crs = self._ds.attrs["crs"] - x_ll, y_ll = crs.transform_point(x=lon_ll, y=lat_ll, src_crs=ccrs.PlateCarree()) - - xc = x_ll + np.arange(nx) * dx - yc = y_ll + np.arange(ny) * dy - - mindex = MultiIndex.from_product([yc, xc], names=["yc", "xc"]) - - return mindex - - def align_with(self, ds, **kwargs): - if self.is_grid(): - if ds.space.is_grid(): - return _align_grid_grid(self._ds, ds, **kwargs) - elif ds.space.is_point(): - return _align_grid_point(self._ds, ds, **kwargs) - elif self.is_point(): - if ds.space.is_point(): - return _align_point_point(self._ds, ds, **kwargs) - elif ds.space.is_grid(): - return _align_point_grid(self._ds, ds, **kwargs) - else: - raise ValueError("Datasets do not have compatible spatial properties") - - -def _align_grid_grid(ds1, ds2, **kwargs): +def align_grid_grid(ds1, ds2, **kwargs): if np.array_equal( ds1["longitude"].values, ds2["longitude"].values ) and np.array_equal(ds1["latitude"].values, ds2["latitude"].values): - return ds1, ds2 + return ds1 elif np.allclose( ds1["longitude"].values, ds2["longitude"].values, atol=COORD_TOLERANCE ) and np.allclose( ds1["latitude"].values, ds2["latitude"].values, atol=COORD_TOLERANCE ): print( - f"Some lat-lon coordinates differ. But the difference is smaller than {COORD_TOLERANCE} degrees, considering both grids as equal" + f"Some lat-lon coordinates differ but within {COORD_TOLERANCE}°, treating as equal" ) - return ds1, ds2 + return ds1 else: raise NotImplementedError("Regridding not implemented") -def _align_grid_point(ds1, ds2, **kwargs): - from ..interpolations.interpolate import interpolate - - method = kwargs.pop("method", "xarray") - ds1 = interpolate(ds1, ds2, method, **kwargs) - - return ds1, ds2 - - -def _align_point_point(ds1, ds2, **kwargs): - raise NotImplementedError("Point selection not implemented") - +def align_grid_point(ds1, ds2, method="xarray", **kwargs): + from ..interpolations.registry import get_interpolation -def _align_point_grid(ds1, ds2, **kwargs): - raise NotImplementedError("Gridding of Point datanot implemented") + interp_cls = get_interpolation(method) + return interp_cls(ds2, **kwargs).interpolate(ds1.copy()) diff --git a/src/mxalign/accessors/time.py b/src/mxalign/accessors/time.py index 60f254e..d957c66 100644 --- a/src/mxalign/accessors/time.py +++ b/src/mxalign/accessors/time.py @@ -1,180 +1,95 @@ -import xarray as xr import numpy as np +import pandas as pd +import xarray as xr + +from mlwp_data_specs.api import TIME_TRAIT_ATTR +from mlwp_data_specs.specs.traits.time_coordinate import Time -from ..properties.properties import Time -from ..properties.utils import properties_from_attrs, update_time_property - - -@xr.register_dataset_accessor("time") -class TimeAccessor: - def __init__(self, ds): - self._time = properties_from_attrs(ds).time - self._ds = ds - - def is_forecast(self): - return self._time == Time.FORECAST - - def is_observation(self): - return self._time == Time.OBSERVATION - - def add_valid_time(self): - if self.is_forecast(): - valid_time = ( - self._ds["reference_time"].values[:, np.newaxis] - + self._ds["lead_time"].values - ) - ds_out = self._ds.assign_coords( - {"valid_time": (["reference_time", "lead_time"], valid_time)} - ) - else: - ds_out = self._ds - return ds_out - - def align_with(self, ds, **kwargs): - if self.is_forecast(): - if ds.time.is_forecast(): - return _align_forecast_forecast(self._ds, ds, **kwargs) - elif ds.time.is_observation(): - return _align_forecast_observation(self._ds, ds, **kwargs) - elif self.is_observation(): - if ds.time.is_observation(): - return _align_observation_observation(self._ds, ds, **kwargs) - elif ds.time.is_forecast(): - return _align_observation_forecast(self._ds, ds, **kwargs) - else: - raise ValueError("Datasets do not have compatible temporal properties") - - -def _align_forecast_forecast(ds1, ds2, only_common=False): - # Align the reference times - common_reference_times = ds1.indexes["reference_time"].intersection( - ds2.indexes["reference_time"] + +def _add_valid_time(ds_fcst): + valid_time = ( + ds_fcst["reference_time"].values[:, np.newaxis] + ds_fcst["lead_time"].values ) - ds1_aligned = ds1.sel(reference_time=common_reference_times) - ds2_aligned = ds2.sel(reference_time=common_reference_times) - - # Align the lead times - if only_common: - common_lead_times = ds1_aligned.indexes["lead_time"].intersection( - ds2_aligned.indexes["lead_time"] - ) - ds1_aligned = ds1_aligned.sel(lead_time=common_lead_times) - ds2_aligned = ds2_aligned.sel(lead_time=common_lead_times) - else: - non_aligning_dims = (set(ds1.dims) | set(ds2.dims)) - set(["lead_time"]) - ds1_aligned, ds2_aligned = xr.align( - ds1_aligned, ds2_aligned, join="outer", exclude=non_aligning_dims - ) - ds1_aligned = ds1_aligned.time.add_valid_time() - ds2_aligned = ds2_aligned.time.add_valid_time() - return ds1_aligned, ds2_aligned - - -def _align_forecast_observation( - ds_forecast, ds_observation, only_common=False, lead_time="start-min" -): - ds_forecast = ds_forecast.time.add_valid_time() - - # Check if reference_times are continuous - reference_time_diff = ds_forecast.reference_time.diff("reference_time").values - if not (reference_time_diff[0] == reference_time_diff).all(): - raise NotImplementedError( - "Aligning a forecast with non-continuous reference times with an observation is not implemented." - ) - if lead_time == "start-min": - min_diff = reference_time_diff[0] - ds_forecast_reduced = ds_forecast.where( - ds_forecast.lead_time < min_diff, drop=True - ) - elif lead_time == "start-max": - max_diff = ds_forecast.lead_time.max().values - reference_times = np.arange( - ds_forecast.reference_time.min().values, - ds_forecast.reference_time.max().values, - max_diff, - dtype="datetime64[ns]", - ) - ds_forecast_reduced = ds_forecast.sel(reference_time=reference_times) - else: - raise ValueError( - "Invalid value for lead_time. Expected 'start-min' or 'start-max'." - ) - - ds_forecast_stacked = ( - ds_forecast_reduced.stack(time=["reference_time", "lead_time"]) - .reset_index("time") - .swap_dims({"time": "valid_time"}) - .transpose("valid_time", ...) + return ds_fcst.assign_coords( + {"valid_time": (["reference_time", "lead_time"], valid_time)} ) - if only_common: - ds_forecast_aligned, ds_observation_aligned = xr.align( - ds_forecast_stacked, - ds_observation, - join="inner", - exclude=set(ds_forecast_stacked.coords) - | set(ds_observation.coords) - set(["valid_time"]), - ) - else: - ds_forecast_aligned, ds_observation_aligned = xr.align( - ds_forecast_stacked, - ds_observation, - join="outer", - exclude=set(ds_forecast_stacked.coords) - | set(ds_observation.coords) - set(["valid_time"]), - ) - ds_forecast_aligned = update_time_property(ds_forecast_aligned, Time.OBSERVATION) - return ds_forecast_aligned, ds_observation_aligned - - -def _align_observation_observation(ds1, ds2, only_common=False): - exclude = (set(ds1.dims) | set(ds2.dims)) - set(["valid_time"]) - if only_common: - ds1_aligned, ds2_aligned = xr.align(ds1, ds2, join="inner", exclude=exclude) + + +def align_forecast_to_observation(ds_fcst, ds_obs, lead_time="shortest"): + ds_with_vt = _add_valid_time(ds_fcst) + ds_stacked = ds_with_vt.stack(time=["reference_time", "lead_time"]).reset_index("time") + + vt_vals = ds_stacked.valid_time.values + lt_vals = ds_stacked.lead_time.values + + if lead_time in ("shortest", "longest"): + df = pd.DataFrame({"vt": vt_vals, "lt": lt_vals}) + agg = "min" if lead_time == "shortest" else "max" + is_extreme = (df.groupby("vt")["lt"].transform(agg) == df["lt"]).values + # Among entries that match the extreme lead_time, keep first per valid_time + # (handles ties: same vt + same extreme lt appearing via different ref_times) + extreme_positions = np.where(is_extreme)[0] + _, first_in_group = np.unique(vt_vals[extreme_positions], return_index=True) + positions = extreme_positions[first_in_group] + elif isinstance(lead_time, (list, np.ndarray)): + lt_set = set(np.asarray(lead_time).tolist()) + seen_vt = set() + positions = [] + for i, (vt, lt) in enumerate(zip(vt_vals, lt_vals)): + if lt in lt_set and vt not in seen_vt: + positions.append(i) + seen_vt.add(vt) + positions = np.array(positions) else: - ds1_aligned, ds2_aligned = xr.align(ds1, ds2, join="outer", exclude=exclude) - return ds1_aligned, ds2_aligned - - -def _align_observation_forecast(ds_observation, ds_forecast, only_common=False): - ds_forecast_cut = ds_forecast.time.add_valid_time() - if ( - ds_forecast_cut.reference_time.min().values - < ds_observation.valid_time.min().values - ): - ds_forecast_cut = ds_forecast_cut.sel( - reference_time=slice(ds_observation.valid_time.min().values, None) - ) - if ds_forecast_cut.valid_time.max().values > ds_observation.valid_time.max().values: - # The forecast time-step/lead times might not always align with the maximum observation time - valid_diff = ( - ds_forecast_cut["valid_time"] - (ds_observation["valid_time"].max()) - ).isel(lead_time=-1) - last_valid_index = ( - np.abs(valid_diff.where(valid_diff <= 0, drop=True)).argmin().values - ) - max_reference_time = ds_forecast_cut.isel(reference_time=last_valid_index)[ - "reference_time" - ].values - - # max_reference_time = ds_observation.valid_time.max().values - (ds_forecast_cut.lead_time.max().values - shift) - ds_forecast_cut = ds_forecast_cut.sel( - reference_time=slice(None, max_reference_time) - ) - - ds_observation_aligned = ds_observation.sel(valid_time=ds_forecast_cut.valid_time) - ds_observation_aligned = ds_observation_aligned.transpose( - "reference_time", "lead_time", ... + # single lead_time value — filter directly + positions = np.where(lt_vals == lead_time)[0] + + ds_1d = ds_stacked.isel(time=positions) + ds_1d = ds_1d.swap_dims({"time": "valid_time"}) + ds_1d = ds_1d.drop_vars( + [v for v in ["reference_time", "lead_time", "time"] if v in ds_1d.coords] ) - ds_observation_aligned = update_time_property(ds_observation_aligned, Time.FORECAST) - if only_common: - return ds_observation_aligned, ds_forecast_cut + ds_1d = ds_1d.transpose("valid_time", ...) + + ds_1d = ds_1d.reindex(valid_time=ds_obs.valid_time) + ds_1d.attrs[TIME_TRAIT_ATTR] = Time.OBSERVATION.value + return ds_1d + + +def align_observation_to_forecast(ds_obs, ds_fcst): + ds_fcst_with_vt = _add_valid_time(ds_fcst) + valid_time_2d = ds_fcst_with_vt["valid_time"] # shape (reference_time, lead_time) + + # Reindex obs onto all unique fcst valid_times (NaN-fills fcst valid_times not in obs) + fcst_vt_flat = np.unique(valid_time_2d.values.ravel()) + obs_reindexed = ds_obs.reindex(valid_time=fcst_vt_flat) + + # sel with a 2D DataArray indexer broadcasts 1D obs → (reference_time, lead_time) + ds_out = obs_reindexed.sel(valid_time=valid_time_2d) + + ds_out.attrs[TIME_TRAIT_ATTR] = Time.FORECAST.value + return ds_out + + +def align_observation_to_observation(ds1, ds2): + return ds1.reindex(valid_time=ds2.valid_time) + + +def align_forecast_to_forecast(ds1, ds2, lead_time="reference"): + ds_out = ds1.reindex(reference_time=ds2.reference_time) + + if lead_time == "reference": + ds_out = ds_out.reindex(lead_time=ds2.lead_time) + elif lead_time == "intersection": + common_lt = np.intersect1d(ds_out.lead_time.values, ds2.lead_time.values) + ds_out = ds_out.sel(lead_time=common_lt) + elif lead_time == "union": + all_lt = np.union1d(ds_out.lead_time.values, ds2.lead_time.values) + ds_out = ds_out.reindex(lead_time=all_lt) else: - ds_observation_aligned, ds_forecast_aligned = xr.align( - ds_observation_aligned, - ds_forecast.time.add_valid_time(), - join="outer", - exclude=(set(ds_observation_aligned.coords) | set(ds_forecast_cut.coords)) - - set(["reference_time", "lead_time"]), - ) - ds_observation_aligned["valid_time"] = ds_forecast_aligned["valid_time"] - return ds_observation_aligned, ds_forecast_aligned + raise ValueError(f"Unknown lead_time option for F→F alignment: {lead_time!r}") + + # Refresh valid_time + if "valid_time" in ds_out.coords: + ds_out = ds_out.drop_vars("valid_time") + return _add_valid_time(ds_out) diff --git a/src/mxalign/align/space.py b/src/mxalign/align/space.py index b18586a..e4dab55 100644 --- a/src/mxalign/align/space.py +++ b/src/mxalign/align/space.py @@ -10,7 +10,7 @@ def align_space(datasets, reference, **kwargs): else: keys = None - datasets = [ds.space.align_with(reference, **kwargs)[0] for ds in datasets] + datasets = [ds.mx.align_space_with(reference, **kwargs) for ds in datasets] if keys is None: if len(datasets) == 1: diff --git a/src/mxalign/align/time.py b/src/mxalign/align/time.py index 4653e17..542c0a2 100644 --- a/src/mxalign/align/time.py +++ b/src/mxalign/align/time.py @@ -2,61 +2,35 @@ def align_time( - datasets: list[xr.Dataset] | dict[str, xr.Dataset], return_as: str = "forecast" + datasets: list[xr.Dataset] | dict[str, xr.Dataset], + reference: str | xr.Dataset, + **kwargs, ): - if isinstance(datasets, (xr.Dataset, xr.DataArray)): - datasets = [datasets] + """Align all datasets temporally to a reference dataset. + + Each non-reference dataset is aligned by calling ``ds.mx.align_time_with(ref_ds)``. + Extra kwargs are forwarded to ``align_time_with`` (e.g. ``lead_time``, ``join``). + + Parameters + ---------- + datasets : list or dict of xr.Dataset + reference : str or xr.Dataset + Key into *datasets* dict, or an xr.Dataset to align to. + """ if isinstance(datasets, dict): - keys = datasets.keys() - datasets = datasets.values() + keys = list(datasets.keys()) + ds_list = list(datasets.values()) + ref_ds = datasets[reference] if isinstance(reference, str) else reference else: + ds_list = [datasets] if isinstance(datasets, xr.Dataset) else list(datasets) keys = None + ref_ds = reference - if return_as != "forecast": - NotImplementedError( - "Currently only temporal alignment return forecast structure is supported." - ) - - # Get the first forecast to start building the valid times - valid_times_fcst = None - valid_times_obs = None - first_fcst = True - first_obs = True - for ds in datasets: - if ds.time.is_forecast(): - if first_fcst: - valid_times_fcst = ds.time.add_valid_time()["valid_time"].to_dataset( - name="valid_times" - ) - valid_times_fcst = valid_times_fcst.assign_attrs(ds.attrs) - first_fcst = False - else: - _ds = ds.time.add_valid_time()["valid_time"].to_dataset( - name="valid_times" - ) - _ds = _ds.assign_attrs(ds.attrs) - _, valid_times_fcst = _ds.time.align_with(valid_times_fcst) - elif ds.time.is_observation(): - if first_obs: - valid_times_obs = ds["valid_time"].to_dataset(name="valid_times") - valid_times_obs = valid_times_obs.assign_attrs(ds.attrs) - first_obs = False - else: - _ds = ds["valid_time"].to_dataset(name="valid_times") - _ds = _ds.assign_attrs(ds.attrs) - _, valid_times_obs = _ds.time.align_with(valid_times_obs) + aligned = [ + ds if ds is ref_ds else ds.mx.align_time_with(ref_ds, **kwargs) + for ds in ds_list + ] - if (valid_times_obs is None) and (valid_times_fcst is None): - raise ValueError("No observations or forecasts found") - elif valid_times_fcst is None: - valid_times = valid_times_obs - elif valid_times_obs is None: - valid_times = valid_times_fcst - else: - _, valid_times = valid_times_obs.time.align_with(valid_times_fcst) - - datasets = [ds.time.align_with(valid_times)[0] for ds in datasets] - if keys is None: - return datasets - else: - return {key: value for (key, value) in zip(keys, datasets)} + if keys is not None: + return dict(zip(keys, aligned)) + return aligned[0] if len(aligned) == 1 else aligned diff --git a/src/mxalign/interpolations/base.py b/src/mxalign/interpolations/base.py index bcba616..5134a44 100644 --- a/src/mxalign/interpolations/base.py +++ b/src/mxalign/interpolations/base.py @@ -1,6 +1,6 @@ import xarray as xr -from ..properties.properties import Space -from ..properties.utils import update_space_property +from mlwp_data_specs.specs.traits.spatial_coordinate import Space +from ..utils.traits import update_space_trait class BaseInterpolator: @@ -21,7 +21,7 @@ def interpolate( self, source_dataset: xr.Dataset | xr.DataArray ) -> xr.Dataset | xr.DataArray: ds_out = self._interpolate(source_dataset) - return update_space_property(ds_out, self.target_space) + return update_space_trait(ds_out, self.target_space) def _interpolate( self, source_dataset: xr.Dataset | xr.DataArray diff --git a/src/mxalign/interpolations/delaunay.py b/src/mxalign/interpolations/delaunay.py index 164c255..3227c34 100644 --- a/src/mxalign/interpolations/delaunay.py +++ b/src/mxalign/interpolations/delaunay.py @@ -9,7 +9,8 @@ from .base import BaseInterpolator from .registry import register_interpolator -from ..properties.properties import Space + +from mlwp_data_specs.specs.traits.spatial_coordinate import Space @register_interpolator @@ -81,7 +82,7 @@ def _interpolate(self, source_dataset): latitude=self.target_dataset["latitude"], longitude=self.target_dataset["longitude"], ) - ds_out.attrs["properties"] = source_dataset.attrs["properties"] + ds_out.attrs.update(source_dataset.attrs) return ds_out diff --git a/src/mxalign/interpolations/xarray.py b/src/mxalign/interpolations/xarray.py index e3a8d29..403029c 100644 --- a/src/mxalign/interpolations/xarray.py +++ b/src/mxalign/interpolations/xarray.py @@ -1,6 +1,6 @@ from .base import BaseInterpolator from .registry import register_interpolator -from ..properties.properties import Space +from mlwp_data_specs.specs.traits.spatial_coordinate import Space import xarray as xr @@ -17,9 +17,9 @@ def _interpolate(self, source_dataset): ds_out = self._interpolate_from_latlon(source_dataset) else: - if source_dataset.space.is_stacked(): + if source_dataset.mx.is_stacked(): try: - source_dataset = source_dataset.space.unstack() + source_dataset = source_dataset.mx.unstack() except ValueError: raise ValueError( "Cannot unstack dataset, dataset must be unstacked to use xarray interpolation" @@ -46,11 +46,10 @@ def _interpolate_from_xcyc(self, source_dataset): y = xr.DataArray(xyz[:, 1], dims="point_index") ds_out = source_dataset.interp(xc=x, yc=y, **self.options) - # ).assing_coords( - # longitude=self.target_dataset["longitude"], - # latitude=self.target_dataset["latitude"] - # ) - + ds_out = ds_out.assign_coords( + longitude=("point_index", self.target_dataset["longitude"].values), + latitude=("point_index", self.target_dataset["latitude"].values), + ) return ds_out def _interpolate_from_latlon(self, source_dataset): diff --git a/src/mxalign/loaders/__init__.py b/src/mxalign/loaders/__init__.py index f04b80e..329bfa3 100644 --- a/src/mxalign/loaders/__init__.py +++ b/src/mxalign/loaders/__init__.py @@ -1,11 +1,13 @@ from . import anemoi_datasets from . import anemoi_inference from . import harp_obstable +from . import ifs_forecast from . import base __all__ = [ "anemoi_datasets", "anemoi_inference", + "ifs_forecast", "harp_obstable", "base", ] diff --git a/src/mxalign/loaders/anemoi_datasets.py b/src/mxalign/loaders/anemoi_datasets.py deleted file mode 100644 index 3ccb646..0000000 --- a/src/mxalign/loaders/anemoi_datasets.py +++ /dev/null @@ -1,92 +0,0 @@ -import numpy as np -import xarray as xr - -from .registry import register_loader -from ..properties.properties import Space, Time, Uncertainty -from .base import BaseLoader - -DROP_VARS = [ - "latitude", - "longitude", - "time", - "cos_julian_day", - "cos_latitude", - "cos_local_time", - "cos_longitude", - "insolation", - "sin_julian_day", - "sin_latitude", - "sin_local_time", - "sin_longitude", -] - -COORDS = dict(longitude="longitudes", latitude="latitudes", valid_time="dates") - -DEFAULTS = {"chunks": "auto"} - - -@register_loader -class AnemoiDatasetsLoader(BaseLoader): - name = "anemoi-datasets" - - space = Space.GRID - time = Time.OBSERVATION - uncertainty = Uncertainty.DETERMINISTIC - - def _load(self): - - if isinstance(self.files, list): - dss = [xr.open_zarr(file, consolidated=False) for file in self.files] - dss_postproc = [_postprocess(ds) for ds in dss] - ds_postproc = xr.concat(dss_postproc, dim="valid_time") - else: - ds = xr.open_zarr(self.files, consolidated=False) - ds_postproc = _postprocess(ds) - - if self.variables: - ds_selected = ds_postproc.sel(variable=self.variables) - else: - ds_selected = ds_postproc - if len(ds_selected["variable"]) > 10: - print( - f"Transforming anemoi-datasets xr.DataArray with {len(ds_postproc['variable'])} variables to xr.Dataset, this might take some time. Consider selecting the relevant variables during loading" - ) - return ds_selected.to_dataset(dim="variable") - - -def _postprocess(dataset: xr.Dataset) -> xr.Dataset: - """Post-process the dataset to add coordinates and drop unused variables. - - Args: - dataset (xr.Dataset): The input dataset to be processed. - - Returns: - xr.Dataset: The processed dataset with assigned coordinates and - attributes. - """ - - # Add coordinates - coords = { - key: dataset[value].astype("datetime64[ns]").load() - if key == "valid_time" - else dataset[value].load() - for key, value in COORDS.items() - } - for key in ("latitude", "longitude"): - coords[key] = coords[key].astype(np.float32) - - coords["variable"] = dataset.attrs["variables"] - coords["valid_time"] = coords["valid_time"].astype("datetime64[ns]") - ds_coords = dataset.assign_coords(coords) - - # Drop unused variables and remove ensemble dimension - drop_vars = [var for var in DROP_VARS if var in coords["variable"]] - - ds_pruned = ( - ds_coords["data"] - .isel(ensemble=0) - .drop_sel(variable=drop_vars) - .swap_dims({"time": "valid_time"}) - .rename({"cell": "grid_index"}) - ) - return ds_pruned diff --git a/src/mxalign/loaders/anemoi_inference.py b/src/mxalign/loaders/anemoi_inference.py deleted file mode 100644 index 68f8799..0000000 --- a/src/mxalign/loaders/anemoi_inference.py +++ /dev/null @@ -1,103 +0,0 @@ -from pathlib import Path -import xarray as xr - -from .registry import register_loader -from ..properties.properties import Space, Time, Uncertainty -from .base import BaseLoader - -DEFAULTS_NETCDF = {"chunks": "auto", "engine": "h5netcdf", "parallel": True} - -DEFAULTS_ZARR = { - "chunks": "auto", - "storage_options": {"anon": True}, -} - - -@register_loader -class AnemoiInferenceLoader(BaseLoader): - name = "anemoi-inference" - - space = Space.GRID - time = Time.FORECAST - uncertainty = Uncertainty.DETERMINISTIC - - def _load(self): - - kwargs = self.kwargs.copy() - - if isinstance(self.files, str): - if Path(self.files).suffix.lower() == ".zarr": - files = self.files - - for k, v in DEFAULTS_ZARR.items(): - kwargs[k] = self.kwargs.get(k, v) - - loader = _open_zarr - else: - files = [self.files] - - for k, v in DEFAULTS_NETCDF.items(): - kwargs[k] = self.kwargs.get(k, v) - - loader = _open_mf_dataset - else: - files = self.files - if Path(files[0]).suffix.lower() == ".zarr": - for k, v in DEFAULTS_ZARR.items(): - kwargs[k] = self.kwargs.get(k, v) - kwargs["engine"] = "zarr" - - else: - for k, v in DEFAULTS_NETCDF.items(): - kwargs[k] = self.kwargs.get(k, v) - - loader = _open_mf_dataset - - ds = loader(files, **kwargs) - return ds - - -def _open_mf_dataset(files, **kwargs): - - times = xr.open_dataset(files[0], engine=kwargs["engine"], chunks=kwargs["chunks"])[ - "time" - ].values - lead_times = times - times[0] - - ds = xr.open_mfdataset(files, preprocess=_preprocess, **kwargs) - - ds_out = ( - ds.assign_coords({"lead_time": ("time", lead_times)}) - .rename_dims({"values": "grid_index"}) - .swap_dims({"time": "lead_time"}) - ) - - return ds_out - - -def _open_zarr(files, **kwargs): - - ds = xr.open_zarr(files, **kwargs) - times = ds["time"].values - lead_times = times - times[0] - - ds_out = _preprocess(ds) - - ds_out = ( - ds_out.assign_coords({"lead_time": ("time", lead_times)}) - .rename_dims({"values": "grid_index"}) - .swap_dims({"time": "lead_time"}) - ) - - return ds_out - - -def _preprocess(ds): - ds_out = ( - ds.set_coords(["longitude", "latitude"]) - .expand_dims("reference_time") - .assign_coords({"reference_time": ("reference_time", [ds["time"].values[0]])}) - .drop_vars("time") - ) - - return ds_out diff --git a/src/mxalign/loaders/base.py b/src/mxalign/loaders/base.py index 3ab8bd9..f6f779b 100644 --- a/src/mxalign/loaders/base.py +++ b/src/mxalign/loaders/base.py @@ -47,8 +47,8 @@ def _select_variables(self, ds): return ds[self.variables] def _add_grid_mapping(self, ds): - ds = ds.space.add_crs(self.grid_mapping) - ds = ds.space.add_grid_mapping(self.grid_mapping) + ds = ds.mx.add_crs(self.grid_mapping) + ds = ds.mx.add_grid_mapping(self.grid_mapping) return ds def _get_properties(self, ds): diff --git a/src/mxalign/loaders/harp_obstable.py b/src/mxalign/loaders/harp_obstable.py deleted file mode 100644 index fc04533..0000000 --- a/src/mxalign/loaders/harp_obstable.py +++ /dev/null @@ -1,81 +0,0 @@ -import sqlite3 -import pandas as pd - -from .registry import register_loader -from ..properties.properties import Space, Time, Uncertainty -from .base import BaseLoader - -COORDS = { - "longitude": "lon", - "latitude": "lat", - "valid_time": "validdate", - "code": "SID", - "altitude": "elev", -} - - -@register_loader -class ObstableLoader(BaseLoader): - name = "harp-obstable" - - space = Space.POINT - time = Time.OBSERVATION - uncertainty = Uncertainty.DETERMINISTIC - - def _load(self): - if isinstance(self.files, list) and len(self.files > 1): - raise NotImplementedError( - "Reading from multiple SQLite-files not implemented" - ) - - conn = sqlite3.connect(self.files) - - if self.variables is None: - # Retrieve all variables - variables = [ - var - for var in pd.read_sql_query( - "SELECT * FROM SYNOP LIMIT 0", conn - ).columns - if var not in COORDS.values() - ] - print(variables) - else: - variables = self.variables - - # Read the SIDs - codes = pd.read_sql( - "SELECT SID as code, MIN(lat) AS latitude, MIN(lon) AS longitude, elev as altitude FROM SYNOP GROUP BY SID", - conn, - index_col="code", - ).to_xarray() - - print(codes) - # Read the data - query = f""" - SELECT SID as code, validdate as valid_time, {", ".join(variables)} - FROM SYNOP - """ - print(query) - df = pd.read_sql( - query, - conn, - index_col=["code", "valid_time"], - parse_dates={"valid_time": {"unit": "s"}}, - ) - print(df) - - ds = df.to_xarray() - lon_values = codes["longitude"].sel(code=ds["code"]).values - lat_values = codes["latitude"].sel(code=ds["code"]).values - alt_values = codes["altitude"].sel(code=ds["code"]).values - - ds = ds.assign_coords( - longitude=("code", lon_values), - latitude=("code", lat_values), - altitude=("code", alt_values), - ) - - return ds.rename_dims({"code": "point_index"}).transpose( - "valid_time", "point_index" - ) diff --git a/src/mxalign/loaders/ifs_forecast.py b/src/mxalign/loaders/ifs_forecast.py new file mode 100644 index 0000000..db02942 --- /dev/null +++ b/src/mxalign/loaders/ifs_forecast.py @@ -0,0 +1,67 @@ +import xarray as xr + +from .registry import register_loader +from ..properties.properties import Space, Time, Uncertainty +from .base import BaseLoader + + +@register_loader +class IFSForecastLoader(BaseLoader): + try: + import cfgrib + except Exception: + raise ImportError("Please install the cfgrib package to load IFS-Forecasts") + + name = "ifs-forecast" + + space = Space.GRID + time = Time.FORECAST + uncertainty = None + + def _load(self): + kwargs = self.kwargs.copy() + files = [self.files] if isinstance(self.files, str) else self.files + + ds = xr.open_mfdataset( + files, + combine="nested", + concat_dim="time", + chunks={ + "time": 1, + "step": -1, + "values": -1, + }, + **kwargs, + ) + + ds.coords["longitude"] = (ds.coords["longitude"] + 180.0) % 360.0 - 180.0 + + rename_dims = { + "time": "reference_time", + "step": "lead_time", + "values": "grid_index", + } + rename_vars = { + "time": "reference_time", + "step": "lead_time", + } + + if "number" in ds.dims and "number" in ds.coords: + rename_dims["number"] = "member" + rename_vars["number"] = "member" + else: + ds = ds.drop_vars("number") + + ds = ds.rename_dims({k: v for k, v in rename_dims.items() if k in ds.dims}) + ds = ds.rename_vars({k: v for k, v in rename_vars.items() if k in ds.variables}) + + if "surface" in ds.variables: + ds = ds.drop_vars("surface") + + if "member" in ds.dims: + self.uncertainty = Uncertainty.ENSEMBLE + elif "quantile" in ds.dims: + self.uncertainty = Uncertainty.QUANTILE + else: + self.uncertainty = Uncertainty.DETERMINISTIC + return ds.transpose("reference_time", "lead_time", ...) diff --git a/src/mxalign/loaders/loader.py b/src/mxalign/loaders/loader.py deleted file mode 100644 index 86c03e8..0000000 --- a/src/mxalign/loaders/loader.py +++ /dev/null @@ -1,8 +0,0 @@ -from .registry import get_loader - - -def load(name, files, variables=None, grid_mapping=None, **kwargs): - loader_cls = get_loader(name) - loader = loader_cls(files, variables, grid_mapping, **kwargs) - - return loader.load() diff --git a/src/mxalign/loaders/registry.py b/src/mxalign/loaders/registry.py deleted file mode 100644 index 505200e..0000000 --- a/src/mxalign/loaders/registry.py +++ /dev/null @@ -1,17 +0,0 @@ -_LOADERS = {} - - -def register_loader(cls): - _LOADERS[cls.name] = cls - return cls - - -def available_loaders(): - return list(_LOADERS.keys()) - - -def get_loader(name): - try: - return _LOADERS[name] - except KeyError: - raise ValueError(f"Unknown loader: {name}") diff --git a/src/mxalign/properties/__init__.py b/src/mxalign/properties/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/mxalign/properties/properties.py b/src/mxalign/properties/properties.py deleted file mode 100644 index a39be11..0000000 --- a/src/mxalign/properties/properties.py +++ /dev/null @@ -1,25 +0,0 @@ -from enum import Enum -from dataclasses import dataclass - - -class Space(str, Enum): - GRID = "grid" - POINT = "point" - - -class Time(str, Enum): - FORECAST = "forecast" - OBSERVATION = "observation" - - -class Uncertainty(str, Enum): - DETERMINISTIC = "deterministic" - ENSEMBLE = "ensemble" - QUANTILE = "quantile" - - -@dataclass(frozen=True) -class Properties: - space: Space - time: Time - uncertainty: Uncertainty = Uncertainty.DETERMINISTIC diff --git a/src/mxalign/properties/specs.py b/src/mxalign/properties/specs.py deleted file mode 100644 index b6d89bc..0000000 --- a/src/mxalign/properties/specs.py +++ /dev/null @@ -1,54 +0,0 @@ -from dataclasses import dataclass, field -from typing import Callable -from .properties import Space, Time, Uncertainty - - -@dataclass -class PropertySpec: - dim_variants: list[set[str]] = field(default_factory=list) - required_coords: set[str] = field(default_factory=set) - optional_dims: set[str] = field(default_factory=set) - optional_coords: set[str] = field(default_factory=set) - validators: list[Callable] = field(default_factory=list) - - -SPACE_SPECS = { - Space.GRID: PropertySpec( - dim_variants=[ - {"xc", "yc"}, - {"grid_index"}, - {"longitude", "latitude"}, - ], - required_coords={"longitude", "latitude"}, - optional_coords={"xc", "yc"}, - optional_dims={"member"}, - ), - Space.POINT: PropertySpec( - dim_variants=[ - {"point_index"}, - ], - required_coords={"longitude", "latitude"}, - optional_coords={"code", "elevation", "name", "country"}, - ), -} -TIME_SPECS = { - Time.FORECAST: PropertySpec( - dim_variants=[{"reference_time", "lead_time"}], - required_coords={"reference_time", "lead_time"}, - optional_coords={"valid_time"}, - ), - Time.OBSERVATION: PropertySpec( - dim_variants=[{"valid_time"}], - required_coords={"valid_time"}, - ), -} - -UNCERTAINTY_SPECS = { - Uncertainty.DETERMINISTIC: PropertySpec(), - Uncertainty.ENSEMBLE: PropertySpec( - dim_variants=[{"member"}], required_coords={"member"} - ), - Uncertainty.QUANTILE: PropertySpec( - dim_variants=[{"quantile"}], required_coords={"quantile"} - ), -} diff --git a/src/mxalign/properties/utils.py b/src/mxalign/properties/utils.py deleted file mode 100644 index 53555d3..0000000 --- a/src/mxalign/properties/utils.py +++ /dev/null @@ -1,43 +0,0 @@ -from .properties import Properties, Space, Time, Uncertainty -from .validation import validate_time_dataset, validate_space_dataset - - -def properties_to_attrs(prop: Properties) -> dict: - return { - "space": prop.space.value, - "time": prop.time.value, - "uncertainty": prop.uncertainty.value, - } - - -def properties_from_attrs(ds) -> Properties: - attrs = ds.attrs.get("properties", {}) - return Properties( - space=Space(attrs["space"]), - time=Time(attrs["time"]), - uncertainty=Uncertainty(attrs.get("uncertainty", Uncertainty.DETERMINISTIC)), - ) - - -def update_space_property(ds, prop: Space): - old_props = properties_from_attrs(ds) - new_props = Properties( - space=prop, - time=old_props.time, - uncertainty=old_props.uncertainty, - ) - validate_space_dataset(ds, new_props) - ds.attrs["properties"] = properties_to_attrs(new_props) - return ds - - -def update_time_property(ds, prop: Time): - old_props = properties_from_attrs(ds) - new_props = Properties( - space=old_props.space, - time=prop, - uncertainty=old_props.uncertainty, - ) - validate_time_dataset(ds, new_props) - ds.attrs["properties"] = properties_to_attrs(new_props) - return ds diff --git a/src/mxalign/properties/validation.py b/src/mxalign/properties/validation.py deleted file mode 100644 index dbc6d34..0000000 --- a/src/mxalign/properties/validation.py +++ /dev/null @@ -1,48 +0,0 @@ -from .specs import SPACE_SPECS, TIME_SPECS, UNCERTAINTY_SPECS - - -def _validate_dims(ds, variants): - if not variants: - return - - ds_dims = set(ds.dims) - - for variant in variants: - if variant.issubset(ds_dims): - return - - raise ValueError(f"Dataset dims {ds_dims} do not match allowed variants {variants}") - - -def _validate_coords(ds, required_coords, axis): - missing = required_coords - set(ds.coords) - if missing: - raise ValueError(f"{axis}: missing required coordinates {missing}") - - -# TIME -def validate_time_dataset(ds, properties): - time_spec = TIME_SPECS[properties.time.value] - _validate_dims(ds, time_spec.dim_variants) - _validate_coords(ds, time_spec.required_coords, "time") - - -# SPACE -def validate_space_dataset(ds, properties): - space_spec = SPACE_SPECS[properties.space.value] - _validate_dims(ds, space_spec.dim_variants) - _validate_coords(ds, space_spec.required_coords, "space") - validate_time_dataset(ds, properties) - - -# UNCERTAINTY -def validate_uncertainty_dataset(ds, properties): - uncertainty_spec = UNCERTAINTY_SPECS[properties.uncertainty.value] - _validate_dims(ds, uncertainty_spec.dim_variants) - _validate_coords(ds, uncertainty_spec.required_coords, "uncertainty") - - -def validate_dataset(ds, properties): - validate_time_dataset(ds, properties) - validate_space_dataset(ds, properties) - validate_uncertainty_dataset(ds, properties) diff --git a/src/mxalign/runner.py b/src/mxalign/runner.py index aff4398..9fc5739 100644 --- a/src/mxalign/runner.py +++ b/src/mxalign/runner.py @@ -74,7 +74,7 @@ def align(self): # align in time if config_align_time: - self.align_time(config_align_time) + self.align_time(config_align_time, reference_name=reference) else: print("Skipping temporal alignment") @@ -148,8 +148,9 @@ def verify(self): method = config.pop("method") save_metrics(method, self.metrics, **config) - def align_time(self, config): - self.datasets = align_time(self.datasets, **config) + def align_time(self, config, reference_name=None): + config = {k: v for k, v in config.items() if k != "method"} + self.datasets = align_time(self.datasets, reference=reference_name, **config) def align_space(self, reference, config): ds_ref = self.datasets[reference] @@ -160,8 +161,8 @@ def align_space(self, reference, config): def get_spatial_alignment(ds, reference): - if reference.space.is_point() and ds.space.is_grid(): + if reference.mx.is_point() and ds.mx.is_grid(): return "interpolation" - if reference.space.is_grid() and ds.space.is_grid(): + if reference.mx.is_grid() and ds.mx.is_grid(): return "regrid" return "null" diff --git a/src/mxalign/utils/save.py b/src/mxalign/utils/save.py index 1698282..b56b19c 100644 --- a/src/mxalign/utils/save.py +++ b/src/mxalign/utils/save.py @@ -4,7 +4,7 @@ class DatasetPath: def __init__(self, name, ds): self.name = name - if ds.time.is_forecast(): + if ds.mx.is_forecast(): years = ds["reference_time"].groupby(ds["reference_time"].dt.year).count() self.year = int(years.isel(year=years.argmax())["year"].values) ds_month = ds.sel(reference_time=ds.reference_time.dt.year == self.year) @@ -23,7 +23,7 @@ def __init__(self, name, ds): .count() ) self.day = int(days.isel(day=days.argmax())["day"].values) - elif ds.time.is_observation(): + elif ds.mx.is_observation(): years = ds["valid_time"].groupby(ds["valid_time"].dt.year).count() self.year = int(years.isel(year=years.argmax())["year"].values) ds_month = ds.sel(valid_time=ds.valid_time.dt.year == self.year) diff --git a/src/mxalign/utils/traits.py b/src/mxalign/utils/traits.py new file mode 100644 index 0000000..cb6b1f9 --- /dev/null +++ b/src/mxalign/utils/traits.py @@ -0,0 +1,37 @@ +from mlwp_data_specs.api import ( + TIME_TRAIT_ATTR, + SPACE_TRAIT_ATTR, + UNCERTAINTY_TRAIT_ATTR, +) + +from mlwp_data_specs.specs.traits.spatial_coordinate import Space +from mlwp_data_specs.specs.traits.spatial_coordinate import ( + validate_dataset as validate_space_dataset, +) + +from mlwp_data_specs.specs.traits.time_coordinate import Time +from mlwp_data_specs.specs.traits.time_coordinate import ( + validate_dataset as validate_time_dataset, +) + +from mlwp_data_specs.specs.traits.uncertainty import ( + validate_dataset as validate_uncertainty_dataset, +) + + +def update_space_trait(ds, new_trait: Space): + validate_space_dataset(ds, trait=new_trait) + ds.attrs[SPACE_TRAIT_ATTR] = new_trait.value + return ds + + +def update_time_trait(ds, new_trait: Time): + validate_time_dataset(ds, trait=new_trait) + ds.attrs[TIME_TRAIT_ATTR] = new_trait.value + return ds + + +def update_uncertainty_trait(ds, new_trait: Time): + validate_uncertainty_dataset(ds, trait=new_trait) + ds.attrs[UNCERTAINTY_TRAIT_ATTR] = new_trait.value + return ds diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1409430 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,51 @@ +import numpy as np +import pytest +import xarray as xr + +import mxalign # registers ds.mx accessor # noqa: F401 + +# --------------------------------------------------------------------------- +# Shared time coordinates +# --------------------------------------------------------------------------- + +T0 = np.datetime64("2020-01-01T00:00", "ns") +H6 = np.timedelta64(6, "h") + +REFERENCE_TIMES = np.array([T0, T0 + H6, T0 + 2 * H6]) +LEAD_TIMES = np.array([np.timedelta64(h, "h") for h in [0, 6, 12, 18]]) + +# Forecast value convention: values[i, j] = float(i * 10 + j) +# Row 0 (ref=T0): 0, 1, 2, 3 +# Row 1 (ref=T0+6h): 10, 11, 12, 13 +# Row 2 (ref=T0+12h): 20, 21, 22, 23 +FORECAST_VALUES = np.array([[float(i * 10 + j) for j in range(4)] for i in range(3)]) + +# Observation covers T0-6h … T0+24h (7 steps) +OBS_TIMES = np.array([T0 + i * H6 for i in range(-1, 5)]) # T0-6h … T0+24h +OBS_VALUES = np.arange(len(OBS_TIMES), dtype=float) * 10.0 # 0, 10, 20, 30, 40, 50 + + +def _props(time: str) -> dict: + return { + "mlwp_time_trait": time, + "mlwp_space_trait": "point", + "mlwp_uncertainty_trait": "deterministic", + } + + +@pytest.fixture +def ds_fcst() -> xr.Dataset: + return xr.Dataset( + {"temp": (["reference_time", "lead_time"], FORECAST_VALUES)}, + coords={"reference_time": REFERENCE_TIMES, "lead_time": LEAD_TIMES}, + attrs=_props("forecast"), + ) + + +@pytest.fixture +def ds_obs() -> xr.Dataset: + return xr.Dataset( + {"temp": ("valid_time", OBS_VALUES)}, + coords={"valid_time": OBS_TIMES}, + attrs=_props("observation"), + ) diff --git a/tests/test_align_space.py b/tests/test_align_space.py new file mode 100644 index 0000000..a621299 --- /dev/null +++ b/tests/test_align_space.py @@ -0,0 +1,154 @@ +"""Tests for ds.mx.align_space_with() covering all spatial alignment cases. + +Fixtures: + + ds_grid — 3×3 lat/lon grid, temp[i, j] = lat[i] + lon[j] + lat: 0°, 1°, 2° lon: 0°, 1°, 2° + + ds_point — 3 observation points: + point 0: (lat=0.0, lon=0.0) — exact grid node → expected temp=0.0 + point 1: (lat=0.5, lon=0.5) — interior → expected temp=1.0 + point 2: (lat=1.0, lon=1.0) — exact grid node → expected temp=2.0 + +Pure interpolation logic (values, coords) is tested in test_interpolations.py. +These tests focus on trait propagation and accessor dispatch. +""" + +import numpy as np +import pytest +import xarray as xr + +import mxalign # registers ds.mx accessor # noqa: F401 + + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + + +def _props(space, time="forecast"): + return { + "mlwp_space_trait": space, + "mlwp_time_trait": time, + "mlwp_uncertainty_trait": "deterministic", + } + + +@pytest.fixture +def ds_grid(): + lats = np.array([0.0, 1.0, 2.0]) + lons = np.array([0.0, 1.0, 2.0]) + temp = lats[:, np.newaxis] + lons[np.newaxis, :] + return xr.Dataset( + {"temp": (["latitude", "longitude"], temp)}, + coords={"latitude": lats, "longitude": lons}, + attrs=_props("grid"), + ) + + +@pytest.fixture +def ds_point(): + return xr.Dataset( + {"temp": ("point_index", np.array([0.0, 1.0, 2.0]))}, + coords={ + "latitude": ("point_index", np.array([0.0, 0.5, 1.0])), + "longitude": ("point_index", np.array([0.0, 0.5, 1.0])), + }, + attrs=_props("point", "observation"), + ) + + +# --------------------------------------------------------------------------- +# Case 1: Grid → Grid +# --------------------------------------------------------------------------- + + +class TestGridToGrid: + def test_identical_grids_return_self(self, ds_grid): + result = ds_grid.mx.align_space_with(ds_grid) + + assert result is ds_grid + + def test_within_tolerance_treated_as_equal(self, ds_grid): + ds_grid2 = ds_grid.assign_coords( + latitude=ds_grid.latitude + 1e-5, + longitude=ds_grid.longitude + 1e-5, + ).assign_attrs(_props("grid")) + result = ds_grid.mx.align_space_with(ds_grid2) + + assert result is ds_grid + + def test_different_grids_raise(self, ds_grid): + ds_grid2 = ds_grid.assign_coords( + latitude=ds_grid.latitude + 10.0, + ).assign_attrs(_props("grid")) + + with pytest.raises(NotImplementedError): + ds_grid.mx.align_space_with(ds_grid2) + + def test_result_stays_grid(self, ds_grid): + result = ds_grid.mx.align_space_with(ds_grid) + + assert result.mx.is_grid() + assert not result.mx.is_point() + + +# --------------------------------------------------------------------------- +# Case 2: Grid → Point (xarray interpolator) +# --------------------------------------------------------------------------- + + +class TestGridToPoint: + def test_result_has_point_trait(self, ds_grid, ds_point): + result = ds_grid.mx.align_space_with(ds_point, method="xarray") + + assert result.mx.is_point() + assert not result.mx.is_grid() + + def test_result_has_point_index_dim(self, ds_grid, ds_point): + result = ds_grid.mx.align_space_with(ds_point, method="xarray") + + assert "point_index" in result.dims + + def test_result_has_target_latlon_coords(self, ds_grid, ds_point): + result = ds_grid.mx.align_space_with(ds_point, method="xarray") + + np.testing.assert_array_equal( + result["latitude"].values, ds_point["latitude"].values + ) + np.testing.assert_array_equal( + result["longitude"].values, ds_point["longitude"].values + ) + + def test_interpolated_values_xarray(self, ds_grid, ds_point): + result = ds_grid.mx.align_space_with(ds_point, method="xarray") + + assert result["temp"].isel(point_index=0).item() == pytest.approx(0.0) + assert result["temp"].isel(point_index=1).item() == pytest.approx(1.0) + assert result["temp"].isel(point_index=2).item() == pytest.approx(2.0) + + def test_interpolated_values_delaunay(self, ds_grid, ds_point): + ds_stacked = ds_grid.stack(grid_index=["latitude", "longitude"]).reset_index( + "grid_index" + ) + ds_stacked.attrs.update(_props("grid")) + result = ds_stacked.mx.align_space_with(ds_point, method="delaunay") + + assert result["temp"].isel(point_index=0).item() == pytest.approx(0.0) + assert result["temp"].isel(point_index=1).item() == pytest.approx(1.0) + assert result["temp"].isel(point_index=2).item() == pytest.approx(2.0) + + +# --------------------------------------------------------------------------- +# Case 3: Point → * (not implemented) +# --------------------------------------------------------------------------- + + +class TestPointAlignmentNotImplemented: + def test_point_to_grid_raises(self, ds_point, ds_grid): + with pytest.raises(NotImplementedError): + ds_point.mx.align_space_with(ds_grid) + + def test_point_to_point_raises(self, ds_point): + with pytest.raises(NotImplementedError): + ds_point.mx.align_space_with(ds_point) diff --git a/tests/test_align_time.py b/tests/test_align_time.py new file mode 100644 index 0000000..628c875 --- /dev/null +++ b/tests/test_align_time.py @@ -0,0 +1,334 @@ +"""Tests for ds.mx.align_time_with() covering all four alignment cases. + +Fixtures (defined in conftest.py): + + ds_fcst — 3 reference_times × 4 lead_times, values[i,j] = float(i*10 + j) + ref: T0, T0+6h, T0+12h + lead: 0h, 6h, 12h, 18h + + ds_obs — 6 valid_times from T0-6h to T0+24h (step 6h) + values: 0, 10, 20, 30, 40, 50 (each 10 apart for easy reading) + +Valid-time coverage from ds_fcst: + T0 → only (ref=T0, lead=0h) = 0 + T0+6h → (T0,6h)=1 or (T0+6h,0h)=10 + T0+12h → (T0,12h)=2 or (T0+6h,6h)=11 or (T0+12h,0h)=20 + T0+18h → (T0,18h)=3 or (T0+6h,12h)=12 or (T0+12h,6h)=21 + T0+24h → (T0+6h,18h)=13 or (T0+12h,12h)=22 + T0+30h → only (T0+12h,18h)=23 [not in obs] +""" + +import numpy as np +import pytest +import xarray as xr + +T0 = np.datetime64("2020-01-01T00:00", "ns") +H6 = np.timedelta64(6, "h") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _props(time: str) -> dict: + return { + "mlwp_time_trait": time, + "mlwp_space_trait": "point", + "mlwp_uncertainty_trait": "deterministic", + } + + +def obs(valid_times, values): + return xr.Dataset( + {"temp": ("valid_time", np.asarray(values, dtype=float))}, + coords={"valid_time": np.asarray(valid_times)}, + attrs=_props("observation"), + ) + + +def fcst(reference_times, lead_times, values): + return xr.Dataset( + {"temp": (["reference_time", "lead_time"], np.asarray(values, dtype=float))}, + coords={ + "reference_time": np.asarray(reference_times), + "lead_time": np.asarray(lead_times), + }, + attrs=_props("forecast"), + ) + + +# --------------------------------------------------------------------------- +# Case 1: Forecast → Observation +# --------------------------------------------------------------------------- + + +class TestForecastToObservation: + def test_shortest_lead_time(self, ds_fcst, ds_obs): + result = ds_fcst.mx.align_time_with(ds_obs, lead_time="shortest") + + assert result.mx.is_observation() + assert list(result.valid_time.values) == list(ds_obs.valid_time.values) + + # T0-6h has no forecast coverage → NaN + assert np.isnan(result["temp"].sel(valid_time=T0 - H6).item()) + + # For each covered time, shortest lead_time wins + assert result["temp"].sel(valid_time=T0).item() == 0.0 # only (T0, 0h) + assert ( + result["temp"].sel(valid_time=T0 + H6).item() == 10.0 + ) # (T0+6h, 0h) beats (T0, 6h) + assert ( + result["temp"].sel(valid_time=T0 + 2 * H6).item() == 20.0 + ) # (T0+12h, 0h) is shortest + assert ( + result["temp"].sel(valid_time=T0 + 3 * H6).item() == 21.0 + ) # (T0+12h, 6h) is shortest + assert ( + result["temp"].sel(valid_time=T0 + 4 * H6).item() == 22.0 + ) # (T0+12h, 12h) is shortest + + def test_longest_lead_time(self, ds_fcst, ds_obs): + result = ds_fcst.mx.align_time_with(ds_obs, lead_time="longest") + + assert result.mx.is_observation() + assert np.isnan(result["temp"].sel(valid_time=T0 - H6).item()) + + assert result["temp"].sel(valid_time=T0).item() == 0.0 # only one entry + assert ( + result["temp"].sel(valid_time=T0 + H6).item() == 1.0 + ) # (T0, 6h) beats (T0+6h, 0h) + assert ( + result["temp"].sel(valid_time=T0 + 2 * H6).item() == 2.0 + ) # (T0, 12h) is longest + assert ( + result["temp"].sel(valid_time=T0 + 3 * H6).item() == 3.0 + ) # (T0, 18h) is longest + assert ( + result["temp"].sel(valid_time=T0 + 4 * H6).item() == 13.0 + ) # (T0+6h, 18h) beats (T0+12h, 12h) + + def test_specific_lead_time(self, ds_fcst, ds_obs): + lt = np.timedelta64(6, "h") + result = ds_fcst.mx.align_time_with(ds_obs, lead_time=lt) + + assert result.mx.is_observation() + # Only T0+6h, T0+12h, T0+18h are produced by lead_time=6h + assert np.isnan(result["temp"].sel(valid_time=T0 - H6).item()) + assert np.isnan(result["temp"].sel(valid_time=T0).item()) + assert result["temp"].sel(valid_time=T0 + H6).item() == 1.0 # (T0, 6h) + assert result["temp"].sel(valid_time=T0 + 2 * H6).item() == 11.0 # (T0+6h, 6h) + assert result["temp"].sel(valid_time=T0 + 3 * H6).item() == 21.0 # (T0+12h, 6h) + assert np.isnan(result["temp"].sel(valid_time=T0 + 4 * H6).item()) + + def test_nan_filled_for_times_not_in_forecast(self, ds_fcst, ds_obs): + result = ds_fcst.mx.align_time_with(ds_obs, lead_time="shortest") + # T0-6h is in obs but never produced by any (ref_time, lead_time) pair + assert np.isnan(result["temp"].sel(valid_time=T0 - H6).item()) + + def test_result_has_observation_property(self, ds_fcst, ds_obs): + result = ds_fcst.mx.align_time_with(ds_obs, lead_time="shortest") + assert result.mx.is_observation() + assert not result.mx.is_forecast() + + +# --------------------------------------------------------------------------- +# Case 2: Observation → Forecast +# --------------------------------------------------------------------------- + + +class TestObservationToForecast: + def test_values_placed_at_correct_positions(self, ds_obs, ds_fcst): + result = ds_obs.mx.align_time_with(ds_fcst) + + assert result.mx.is_forecast() + assert set(result.dims) == {"reference_time", "lead_time"} + + # obs values: T0→10, T0+6h→20, T0+12h→30, T0+18h→40, T0+24h→50 + assert ( + result["temp"] + .sel(reference_time=T0, lead_time=np.timedelta64(0, "h")) + .item() + == 10.0 + ) + assert ( + result["temp"] + .sel(reference_time=T0, lead_time=np.timedelta64(6, "h")) + .item() + == 20.0 + ) + assert ( + result["temp"] + .sel(reference_time=T0 + H6, lead_time=np.timedelta64(6, "h")) + .item() + == 30.0 + ) + assert ( + result["temp"] + .sel(reference_time=T0 + H6, lead_time=np.timedelta64(18, "h")) + .item() + == 50.0 + ) + + def test_nan_where_obs_missing(self, ds_obs, ds_fcst): + result = ds_obs.mx.align_time_with(ds_fcst) + + # T0+30h is not in obs; it appears at (T0+12h, lead=18h) + assert np.isnan( + result["temp"] + .sel( + reference_time=T0 + 2 * H6, + lead_time=np.timedelta64(18, "h"), + ) + .item() + ) + + def test_obs_value_repeated_for_shared_valid_times(self, ds_obs, ds_fcst): + result = ds_obs.mx.align_time_with(ds_fcst) + + # T0+12h appears at (T0,12h), (T0+6h,6h), (T0+12h,0h) — all should equal 30.0 + assert ( + result["temp"] + .sel(reference_time=T0, lead_time=np.timedelta64(12, "h")) + .item() + == 30.0 + ) + assert ( + result["temp"] + .sel(reference_time=T0 + H6, lead_time=np.timedelta64(6, "h")) + .item() + == 30.0 + ) + assert ( + result["temp"] + .sel(reference_time=T0 + 2 * H6, lead_time=np.timedelta64(0, "h")) + .item() + == 30.0 + ) + + def test_result_has_forecast_property(self, ds_obs, ds_fcst): + result = ds_obs.mx.align_time_with(ds_fcst) + assert result.mx.is_forecast() + assert not result.mx.is_observation() + + +# --------------------------------------------------------------------------- +# Case 3: Observation → Observation +# --------------------------------------------------------------------------- + + +class TestObservationToObservation: + @pytest.fixture + def ds_obs1(self): + times = np.array([T0, T0 + H6, T0 + 2 * H6, T0 + 3 * H6]) + return obs(times, [0.0, 1.0, 2.0, 3.0]) + + @pytest.fixture + def ds_obs2(self): + times = np.array([T0 + H6, T0 + 2 * H6, T0 + 3 * H6, T0 + 4 * H6]) + return obs(times, [10.0, 20.0, 30.0, 40.0]) + + def test_reindexes_to_ds2_valid_times(self, ds_obs1, ds_obs2): + result = ds_obs1.mx.align_time_with(ds_obs2) + + assert list(result.valid_time.values) == list(ds_obs2.valid_time.values) + # T0+24h is in ds2 but not ds1 → NaN + assert np.isnan(result["temp"].sel(valid_time=T0 + 4 * H6).item()) + # Overlapping times retain ds1 values + assert result["temp"].sel(valid_time=T0 + H6).item() == 1.0 + assert result["temp"].sel(valid_time=T0 + 2 * H6).item() == 2.0 + assert result["temp"].sel(valid_time=T0 + 3 * H6).item() == 3.0 + + def test_ds1_only_times_are_dropped(self, ds_obs1, ds_obs2): + result = ds_obs1.mx.align_time_with(ds_obs2) + assert T0 not in result.valid_time.values # only in ds1 + + def test_result_stays_observation(self, ds_obs1, ds_obs2): + result = ds_obs1.mx.align_time_with(ds_obs2) + assert result.mx.is_observation() + + +# --------------------------------------------------------------------------- +# Case 4: Forecast → Forecast +# --------------------------------------------------------------------------- + + +class TestForecastToForecast: + @pytest.fixture + def ds_fcst2(self): + ref = np.array([T0 + H6, T0 + 2 * H6, T0 + 3 * H6]) + lead = np.array([np.timedelta64(6, "h"), np.timedelta64(12, "h")]) + values = np.zeros((3, 2)) + return fcst(ref, lead, values) + + def test_reindexes_to_ds2_reference_times(self, ds_fcst, ds_fcst2): + result = ds_fcst.mx.align_time_with(ds_fcst2, lead_time="reference") + + np.testing.assert_array_equal( + result.reference_time.values, ds_fcst2.reference_time.values + ) + + def test_lead_time_reference_drops_ds1_only_leads(self, ds_fcst, ds_fcst2): + result = ds_fcst.mx.align_time_with(ds_fcst2, lead_time="reference") + + np.testing.assert_array_equal( + result.lead_time.values, ds_fcst2.lead_time.values + ) + assert np.timedelta64(0, "h") not in result.lead_time.values + assert np.timedelta64(18, "h") not in result.lead_time.values + + def test_nan_for_ref_times_not_in_ds1(self, ds_fcst, ds_fcst2): + result = ds_fcst.mx.align_time_with(ds_fcst2, lead_time="reference") + + # T0+18h is in ds_fcst2 but not ds_fcst → all NaN + assert np.isnan(result["temp"].sel(reference_time=T0 + 3 * H6).values).all() + + def test_values_preserved_for_common_ref_times(self, ds_fcst, ds_fcst2): + result = ds_fcst.mx.align_time_with(ds_fcst2, lead_time="reference") + + assert ( + result["temp"] + .sel(reference_time=T0 + H6, lead_time=np.timedelta64(6, "h")) + .item() + == 11.0 + ) + assert ( + result["temp"] + .sel(reference_time=T0 + H6, lead_time=np.timedelta64(12, "h")) + .item() + == 12.0 + ) + assert ( + result["temp"] + .sel(reference_time=T0 + 2 * H6, lead_time=np.timedelta64(6, "h")) + .item() + == 21.0 + ) + assert ( + result["temp"] + .sel(reference_time=T0 + 2 * H6, lead_time=np.timedelta64(12, "h")) + .item() + == 22.0 + ) + + def test_lead_time_intersection_keeps_common_leads(self, ds_fcst, ds_fcst2): + result = ds_fcst.mx.align_time_with(ds_fcst2, lead_time="intersection") + + expected_lead = np.array([np.timedelta64(6, "h"), np.timedelta64(12, "h")]) + np.testing.assert_array_equal(result.lead_time.values, expected_lead) + assert np.timedelta64(0, "h") not in result.lead_time.values + assert np.timedelta64(18, "h") not in result.lead_time.values + + def test_lead_time_union_keeps_all_leads(self, ds_fcst, ds_fcst2): + result = ds_fcst.mx.align_time_with(ds_fcst2, lead_time="union") + + expected_lead = np.array([np.timedelta64(h, "h") for h in [0, 6, 12, 18]]) + np.testing.assert_array_equal(result.lead_time.values, expected_lead) + + def test_result_has_valid_time_coord(self, ds_fcst, ds_fcst2): + result = ds_fcst.mx.align_time_with(ds_fcst2, lead_time="reference") + assert "valid_time" in result.coords + + def test_result_stays_forecast(self, ds_fcst, ds_fcst2): + result = ds_fcst.mx.align_time_with(ds_fcst2, lead_time="reference") + assert result.mx.is_forecast() diff --git a/tests/test_interpolations.py b/tests/test_interpolations.py new file mode 100644 index 0000000..6c2fe3d --- /dev/null +++ b/tests/test_interpolations.py @@ -0,0 +1,130 @@ +"""Tests for XarrayInterpolator and DelaunayInterpolator core logic. + +Both interpolators take a stacked or lat/lon-dim source GRID dataset and +produce a POINT dataset aligned to target point locations. + +Fixtures use temp = lat + lon, a linear function, so both bilinear (xarray) +and linear barycentric (Delaunay) interpolation reproduce it exactly. + +Trait / accessor-level tests live in test_align_space.py. +""" + +import numpy as np +import pytest +import xarray as xr + +from mxalign.interpolations.delaunay import DelaunayInterpolator +from mxalign.interpolations.xarray import XarrayInterpolator + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def grid_latlon(): + """3×3 regular lat/lon grid. temp[i, j] = lat[i] + lon[j].""" + lats = np.array([0.0, 1.0, 2.0]) + lons = np.array([0.0, 1.0, 2.0]) + temp = lats[:, np.newaxis] + lons[np.newaxis, :] + return xr.Dataset( + {"temp": (["latitude", "longitude"], temp)}, + coords={"latitude": lats, "longitude": lons}, + ) + + +@pytest.fixture +def grid_stacked(grid_latlon): + """grid_latlon stacked to grid_index (required by DelaunayInterpolator).""" + return grid_latlon.stack(grid_index=["latitude", "longitude"]).reset_index("grid_index") + + +@pytest.fixture +def target_points(): + """3 target points: two at grid nodes, one at an interior location.""" + return xr.Dataset( + coords={ + "latitude": ("point_index", np.array([0.0, 0.5, 1.0])), + "longitude": ("point_index", np.array([0.0, 0.5, 1.0])), + }, + ) + + +# --------------------------------------------------------------------------- +# XarrayInterpolator +# --------------------------------------------------------------------------- + + +class TestXarrayInterpolator: + def test_values_at_grid_nodes(self, grid_latlon, target_points): + result = XarrayInterpolator(target_points)._interpolate(grid_latlon) + + # lat=0, lon=0 → 0+0=0; lat=1, lon=1 → 1+1=2 + assert result["temp"].isel(point_index=0).item() == pytest.approx(0.0) + assert result["temp"].isel(point_index=2).item() == pytest.approx(2.0) + + def test_value_at_interior_point(self, grid_latlon, target_points): + result = XarrayInterpolator(target_points)._interpolate(grid_latlon) + + # lat=0.5, lon=0.5 → 0.5+0.5=1.0 (exact for bilinear on linear function) + assert result["temp"].isel(point_index=1).item() == pytest.approx(1.0) + + def test_output_has_point_index_dim(self, grid_latlon, target_points): + result = XarrayInterpolator(target_points)._interpolate(grid_latlon) + + assert "point_index" in result.dims + + def test_output_has_latlon_coords_from_target(self, grid_latlon, target_points): + result = XarrayInterpolator(target_points)._interpolate(grid_latlon) + + np.testing.assert_array_equal(result["latitude"].values, target_points["latitude"].values) + np.testing.assert_array_equal(result["longitude"].values, target_points["longitude"].values) + + +# --------------------------------------------------------------------------- +# DelaunayInterpolator +# --------------------------------------------------------------------------- + + +class TestDelaunayInterpolator: + def test_values_at_grid_nodes(self, grid_stacked, target_points): + result = DelaunayInterpolator(target_points)._interpolate(grid_stacked) + + assert result["temp"].isel(point_index=0).item() == pytest.approx(0.0) + assert result["temp"].isel(point_index=2).item() == pytest.approx(2.0) + + def test_value_at_interior_point(self, grid_stacked, target_points): + result = DelaunayInterpolator(target_points)._interpolate(grid_stacked) + + # Barycentric interpolation is exact for any linear function + assert result["temp"].isel(point_index=1).item() == pytest.approx(1.0) + + def test_output_has_point_index_dim(self, grid_stacked, target_points): + result = DelaunayInterpolator(target_points)._interpolate(grid_stacked) + + assert "point_index" in result.dims + + def test_output_has_latlon_coords_from_target(self, grid_stacked, target_points): + result = DelaunayInterpolator(target_points)._interpolate(grid_stacked) + + np.testing.assert_array_equal(result["latitude"].values, target_points["latitude"].values) + np.testing.assert_array_equal(result["longitude"].values, target_points["longitude"].values) + + def test_weight_matrix_is_cached(self, grid_stacked, target_points): + interp = DelaunayInterpolator(target_points) + interp._interpolate(grid_stacked) + interp._interpolate(grid_stacked) + + assert len(interp._W_cache) == 1 + + def test_outside_convex_hull_is_nan(self, grid_stacked): + far_point = xr.Dataset( + coords={ + "latitude": ("point_index", np.array([10.0])), + "longitude": ("point_index", np.array([10.0])), + }, + ) + result = DelaunayInterpolator(far_point)._interpolate(grid_stacked) + + assert np.isnan(result["temp"].isel(point_index=0).item())