Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 18 additions & 21 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,42 @@

## What is this?

`mxalign` is an `xarray`-based package designed for the alignment and verification of meteorological datasets. It standardizes operations across datasets by attaching properties along three main axes:
- **Space:** Grid or point-based data
- **Time:** Forecasts, observations, or climatology
- **Uncertainty:** Deterministic, ensemble, or quantile forecasts
`mxalign` is an `xarray`-based package for aligning meteorological datasets. It operates on datasets that carry **traits** — metadata attributes that describe the nature of a dataset along three axes:

Currently, `mxalign` also acts as a full execution engine. It can load datasets (e.g., Anemoi inference outputs, observation datasets), apply transformations, align datasets in both space and time to match a reference, safely broadcast NaNs, and execute verification metrics on scaled Dask clusters (Local or Slurm).
`mxalign` is an `xarray`-based package for aligning meteorological datasets. It operates on datasets that carry **traits** — metadata attributes that describe the nature of a dataset along three axes:
- **Space:** `grid` or `point`
- **Time:** `forecast`, `observation`, or `climatology`
- **Uncertainty:** `deterministic`, `ensemble`, or `quantile`

> ⚠️ **Roadmap & Future Architecture Changes (planned for v0.2.0):**
> Currently, `mxalign` handles both alignment and the execution of the verification tooling pipeline, including loading and validation. In the upcoming `v0.2.0` release, this architecture will be refactored:
> - **Loading** will be split out into [`mlwp-data-loaders`](https://github.com/mlwp-tools/mlwp-data-loaders).
> - **Validation** of loaded `xr.Dataset`s will be moved to [`mlwp-data-specs`](https://github.com/mlwp-tools/mlwp-data-specs) (which will contain the requirements for each of the dataset traits and the validation logic).
> - **Execution** of the full verification pipeline (loading, transformations, alignment, and verification) from configuration files may be moved to a separate package in future releases.
> - **Tests** will be added to `mxalign` (building on test datasets already integrated into `mlwp-data-loaders`) that ensure that all alignment operations work correctly (Testing notebook execution inside `mxalign` is explicitly excluded from the current roadmap).
These traits are defined and validated by [`mlwp-data-specs`](https://github.com/mlwp-tools/mlwp-data-specs) and attached to datasets by [`mlwp-data-loaders`](https://github.com/mlwp-tools/mlwp-data-loaders). `mxalign` reads them to infer how datasets should be aligned, without needing to know how they were loaded.

`mxalign` currently supports alignment in **space** and **time**. Alignment along the **uncertainty** axis (e.g. ensemble to deterministic) is planned for a future release.

## Python API

`mxalign` provides building blocks for manual alignment, transformations, and interpolations of `xarray` datasets. This is ideal for interactive use in Jupyter notebooks or custom Python scripts.
`mxalign` provides building blocks for spatial and temporal alignment of `xarray` datasets. This is ideal for interactive use in Jupyter notebooks or custom Python scripts.

```python
import xarray as xr
from mxalign import load, align_space, align_time, transform
import mlwp_data_loaders as dl
import mxalign as mx

# Load datasets (using registered loaders)
ds_obs = load(name="observations_loader", files=["obs.nc"])
ds_fcst = load(name="anemoi_inference", files=["forecast.nc"])
# Load datasets — traits are attached by the loader
ds_obs = dl.load("observations_loader", files=["obs.nc"])
ds_fcst = dl.load("anemoi_inference", files=["forecast.nc"])

# Align the forecast spatially to match the observation reference
ds_fcst_aligned_space = align_space(ds_fcst, reference=ds_obs, method="interpolation")
ds_fcst_aligned = mx.align_space(ds_fcst, reference=ds_obs, method="interpolation")

# Align datasets temporally
datasets = {"obs": ds_obs, "fcst": ds_fcst_aligned_space}
aligned_datasets = align_time(datasets, method="intersection")
datasets = {"obs": ds_obs, "fcst": ds_fcst_aligned}
aligned_datasets = mx.align_time(datasets, method="intersection")
```

For a more comprehensive interactive example, check out the [introductory notebook](./examples/introduction.ipynb).

## Executing via a Configuration

For full verification pipeline execution, `mxalign` uses a YAML configuration file. This allows you to declaratively define how datasets are loaded, transformed, aligned, and verified.
`mxalign` can drive a full verification pipeline from a YAML configuration file, orchestrating dataset loading (via `mlwp-data-loaders`), transformations, alignment, and verification.

### Configuration Contents

Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ jobqueue = [
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.pytest.ini_options]
testpaths = ["tests"]

[dependency-groups]
dev = [
"ipykernel>=7.2.0",
"pytest>=8.0",
]
11 changes: 0 additions & 11 deletions src/mxalign/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from .properties.properties import Properties, Time, Space, Uncertainty
from .loaders.loader import load
from .loaders.registry import available_loaders, register_loader
from .transformations.transform import transform
from .transformations.registry import available_transformations, register_transformation
from .interpolations.interpolate import interpolate
Expand All @@ -9,18 +6,10 @@
from .align.space import align_space

from . import accessors
from . import loaders
from . import transformations
from . import interpolations

__all__ = [
"Properties",
"Time",
"Space",
"Uncertainty",
"load",
"available_loaders",
"register_loader",
"transform",
"available_transformations",
"register_transformation",
Expand Down
8 changes: 2 additions & 6 deletions src/mxalign/accessors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from . import space
from . import time
from . import mx

__all__ = [
"space",
"time",
]
__all__ = ["mx"]
231 changes: 231 additions & 0 deletions src/mxalign/accessors/mx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
import numpy as np
import xarray as xr
import cartopy.crs as ccrs

from mlwp_data_specs.api import SPACE_TRAIT_ATTR, TIME_TRAIT_ATTR
from mlwp_data_specs.specs.traits.spatial_coordinate import Space
from mlwp_data_specs.specs.traits.time_coordinate import Time

from ..utils.projections import create_cartopy_crs, BUILTIN
from . import time as _time
from . import space as _space


@xr.register_dataset_accessor("mx")
class MxAccessor:
def __init__(self, ds):
self._space = Space(ds.attrs[SPACE_TRAIT_ATTR])
self._time = Time(ds.attrs[TIME_TRAIT_ATTR])
self._ds = ds

# --- Space predicates ---

def is_grid(self):
return self._space == Space.GRID

def is_point(self):
return self._space == Space.POINT

# --- Time predicates ---

def is_forecast(self):
return self._time == Time.FORECAST

def is_observation(self):
return self._time == Time.OBSERVATION

# --- Space operations ---

def add_crs(self, crs):
if self.is_point():
raise ValueError("Cannot add CRS to a point dataset")
if isinstance(crs, str):
try:
crs = BUILTIN[crs.lower()]
except KeyError:
raise ValueError(f"crs: {crs} not found in supported projections")
if isinstance(crs, dict):
crs = create_cartopy_crs(
projection=crs["projection"],
kws_projection=crs["kws_projection"],
kws_globe=crs.get("kws_globe", None),
)
return self._ds.assign_attrs({"crs": crs})

def add_grid_mapping(self, grid_mapping: str | dict):
if self.is_point():
raise ValueError("Cannot add grid mapping to a point dataset")
if isinstance(grid_mapping, str):
try:
grid_mapping = BUILTIN[grid_mapping.lower()]["kws_grid"]
except KeyError:
raise ValueError(
f"grid mapping: {grid_mapping} not found in supported mappings"
)
return self._ds.assign_attrs({"grid_mapping": grid_mapping})

def add_xy(self, crs=None):
if crs is not None:
self._ds = self.add_crs(crs)

crs = self._ds.attrs.get("crs", None)
if crs is None:
raise ValueError("No CRS provided and no CRS found in dataset attributes")

if {"longitude", "latitude"}.issubset(self._ds.dims):
raise ValueError(
"Cannot add x/y coordinates to a GRID dataset that has longitude/latitude dimensions"
)
elif {"xc", "yc"}.issubset(self._ds.coords):
return self._ds
else:
xyz = crs.transform_points(
x=self._ds["longitude"].values,
y=self._ds["latitude"].values,
src_crs=ccrs.PlateCarree(),
)

if self.is_grid():
return self._ds.assign_coords(
xc=("grid_index", xyz[:, 0]), yc=("grid_index", xyz[:, 1])
)
elif self.is_point():
return self._ds.assign_coords(
xc=("point_index", xyz[:, 0]), yc=("point_index", xyz[:, 1])
)
else:
raise ValueError("Dataset does not have expected spatial properties")

def is_stacked(self):
if {"xc", "yc"}.issubset(self._ds.dims) or {"longitude", "latitude"}.issubset(
self._ds.dims
):
return False
elif "grid_index" in self._ds.dims:
return True
else:
raise ValueError("Dataset does not have expected dimensions for GRID")

def stack(self):
if self.is_point():
raise ValueError("POINT datasets cannot be stacked")
if self.is_stacked():
return self._ds
else:
if {"xc", "yc"}.issubset(self._ds.dims):
dims_to_stack = ["yc", "xc"]
elif {"lat", "lon"}.issubset(self._ds.dims):
dims_to_stack = ["lat", "lon"]
else:
raise ValueError("Could not find correct dimensions to stack")
return self._ds.stack({"grid_index": dims_to_stack}).reset_index("grid_index")

def unstack(self, crs=None, **kwargs):
if self.is_point():
raise ValueError("POINT datasets cannot be unstacked")
if not self.is_stacked():
return self._ds
else:
if crs:
self._ds = self.add_crs(crs)
kws_mindex = dict.fromkeys(["nx", "ny", "lon_ll", "lat_ll", "dx", "dy"])
for key in kws_mindex.keys():
value = kwargs.get(key, None)
if value is None:
try:
value = self._ds.attrs["grid_mapping"][key]
except KeyError:
raise KeyError(
f"Did not find a value for {key} in dataset attributes, please provide it as an argument"
)
kws_mindex[key] = value

mindex = self._create_multiindex(**kws_mindex)
mcoords = xr.Coordinates.from_pandas_multiindex(mindex, "grid_index")
ds_mindex = self._ds.assign_coords(mcoords)
ds_mindex.attrs["grid_mapping"] = kws_mindex
return ds_mindex.unstack()

def _create_multiindex(self, nx, ny, lon_ll, lat_ll, dx, dy, **kwargs):
from pandas import MultiIndex

if self._ds.sizes["grid_index"] != nx * ny:
raise ValueError(
f"Size of grid_index ({self._ds.sizes['grid_index']}) does not match nx*ny ({nx * ny})"
)

crs = self._ds.attrs["crs"]
x_ll, y_ll = crs.transform_point(x=lon_ll, y=lat_ll, src_crs=ccrs.PlateCarree())

xc = x_ll + np.arange(nx) * dx
yc = y_ll + np.arange(ny) * dy

return MultiIndex.from_product([yc, xc], names=["yc", "xc"])

# --- Time operations ---

def add_valid_time(self):
if self.is_forecast():
return _time._add_valid_time(self._ds)
return self._ds

# --- Alignment ---

def align_time_with(self, ds2, lead_time="shortest"):
"""Align this dataset's time axis to match ds2.

Always uses "reference" semantics: self is reindexed to ds2's time
coordinates, with NaN-fill for times not present in self. ds2 is never
modified. For symmetric inner-join behaviour across multiple datasets use
the module-level ``align_time`` function instead.

Parameters
----------
ds2 : xr.Dataset
The reference dataset to align to.
lead_time : str or timedelta or list
For Forecast→Observation: "shortest" | "longest" | specific value or list.
For Forecast→Forecast: "reference" | "intersection" | "union" (default "reference").
Ignored for observation→* cases.
"""
if self.is_forecast() and ds2.mx.is_observation():
return _time.align_forecast_to_observation(self._ds, ds2, lead_time=lead_time)
elif self.is_observation() and ds2.mx.is_forecast():
return _time.align_observation_to_forecast(self._ds, ds2)
elif self.is_observation() and ds2.mx.is_observation():
return _time.align_observation_to_observation(self._ds, ds2)
elif self.is_forecast() and ds2.mx.is_forecast():
ff_lead_time = (
lead_time if lead_time in ("reference", "intersection", "union") else "reference"
)
return _time.align_forecast_to_forecast(self._ds, ds2, lead_time=ff_lead_time)
else:
raise ValueError("Cannot align datasets with unknown time properties")

def align_space_with(self, ds2, **kwargs):
"""Align this dataset's spatial grid to match ds2.

Always uses "reference" semantics: self is interpolated or reindexed to
ds2's spatial coordinates. ds2 is never modified.

Parameters
----------
ds2 : xr.Dataset
The reference dataset to align to.
method : str
Interpolation method for grid→point alignment. One of "xarray" or
"delaunay" (default "xarray"). Ignored for grid→grid.
**kwargs
Passed through to the interpolator.
"""
if self.is_grid():
if ds2.mx.is_grid():
return _space.align_grid_grid(self._ds, ds2, **kwargs)
elif ds2.mx.is_point():
return _space.align_grid_point(self._ds, ds2, **kwargs)
elif self.is_point():
if ds2.mx.is_point():
raise NotImplementedError("Point-to-point alignment not implemented")
elif ds2.mx.is_grid():
raise NotImplementedError("Point-to-grid alignment not implemented")
raise ValueError("Datasets do not have compatible spatial properties")
Loading