Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 14 additions & 112 deletions chlorophyll/chl_climatology_and_fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,104 +33,17 @@

import numpy as np
import xarray as xr
import regionmask
from scipy import ndimage
import scipy.sparse as sp
import scipy.sparse.linalg as spla
from distributed import Client

path_root = Path(__file__).parents[1]
sys.path.append(str(path_root))

from scripts_common import get_provenance_metadata, md5sum
from regrid_common import fill_ocean_horiz
from scripts_common import get_provenance_metadata

xr.set_options(keep_attrs=True)


def fill_missing_data(field, wet_mask, maxiter=0):
"""
Fill missing ocean values using a sparse Laplacian solve.
Adapted from https://github.com/adcroft/interp_and_fill/blob/main/Interpolate%20and%20fill%20SeaWIFS.ipynb

Parameters
----------
field : numpy.ndarray
Input data containing missing data
wet_mask : numpy.ndarray
Wet cell mask (0 land, 1 ocean)

Returns
-------
numpy.ma.array
Data array with missing ocean points filled.
"""

def _process_neighbour(n, jn, in_):
"""Process neighbour at (jn, in_) for row n."""
if wet_mask[jn, in_] <= 0:
return

ld[n] -= 1
idx = ind[jn, in_]

if idx >= 0:
A[n, idx] = 1.0
else:
b[n] -= field[jn, in_]

nj, ni = field.shape
missing_mask = np.isnan(field)
field = np.where(missing_mask, 0, field)

# Index lookup for missing points
missing_j, missing_i = np.where(missing_mask & (wet_mask > 0))
n_missing = missing_j.size
ind = np.full(field.shape, -1, dtype=int)
ind[missing_j, missing_i] = np.arange(n_missing)

# Sparse matrix in LIL format (fast incremental building)
A = sp.lil_matrix((n_missing, n_missing))
b = np.zeros(n_missing)
ld = np.zeros(n_missing)

# Build matrix row-by-row
for n in range(n_missing):
j = missing_j[n]
i = missing_i[n]

im1 = (i - 1) % ni
ip1 = (i + 1) % ni
jm1 = j - 1 if j > 0 else 0
jp1 = j + 1 if j < nj - 1 else nj - 1

if j > 0:
_process_neighbour(n, jm1, i)
_process_neighbour(n, j, im1)
_process_neighbour(n, j, ip1)
if j < nj - 1:
_process_neighbour(n, jp1, i)

# Tri-polar fold
if j == nj - 1:
fold_i = ni - 1 - i
_process_neighbour(n, j, fold_i)

# Set leading diagonal
b[ld >= 0] = 0.0
stabilizer = 1e-14
diag_vals = ld - stabilizer
A[np.arange(n_missing), np.arange(n_missing)] = diag_vals

# Convert to CSR and solve
A = A.tocsr()
x = spla.spsolve(A, b)

# Fill the missing values
field[missing_j, missing_i] = x

return np.where(wet_mask, field, np.NaN)


def main():
parser = argparse.ArgumentParser(
description=(
Expand Down Expand Up @@ -174,7 +87,7 @@ def main():

print("Calculating the monthly climatology...")

with Client(threads_per_worker=1) as client:
with Client(threads_per_worker=1):
ds = xr.open_mfdataset(
input_files,
chunks={"lat": 1024, "lon": 1024},
Expand All @@ -183,33 +96,22 @@ def main():
)
chl = ds[["CHL"]].groupby("time.month").mean("time").compute()

print("Filling missing data...")

# Create land mask, eroded to ensure we have values at wet cells near coasts
land = (
regionmask.defined_regions.natural_earth_v5_0_0.land_110.mask(chl).values == 0.0
)
land_eroded = ndimage.binary_erosion(land, structure=np.ones((200, 200)))

# Fill missing data for each month
print("Filling missing data...")
chl_filled = []
for month in range(1, 13):
print(f" Filling month {month}...")

chl_month = chl["CHL"].sel(month=month)

# Remove chl values on land
chl_month = chl_month.where(np.logical_not(land)).values

# Fill missing values in two steps. First, fill the missing wet cells, then
# the eroded land areas. If this is done in one step, high CHL values near
# the coast have a larger weighting leading to larger values in high latitude
# filled regions.
chl_filled = fill_missing_data(chl_month, 1.0 - land)

chl["CHL"].sel(month=month).values[:] = fill_missing_data(
chl_filled, 1.0 - land_eroded
chl_filled.append(
fill_ocean_horiz(
chl["CHL"].sel(month=month),
top_bound="regular",
n_erode=200,
erode_first=False,
)
)

chl["CHL"] = xr.concat(chl_filled, dim="month")

# Add time array
calendar = "gregorian"
times = xr.date_range(
Expand Down
158 changes: 158 additions & 0 deletions regrid_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
import xesmf as xe
import xarray as xr
import argparse
import numpy as np
import regionmask
from scipy import ndimage
import scipy.sparse as sp
import scipy.sparse.linalg as spla

from scripts_common import md5sum

Expand Down Expand Up @@ -250,3 +255,156 @@ def save_output(self):

unlimited_dims = "time" if "time" in forcing_regrid.dims else None
forcing_regrid.to_netcdf(self.output_filename, unlimited_dims=unlimited_dims)


def _fill_missing_horiz(field, mask, top_bound="none"):
"""
Fill missing values using a sparse Laplacian solve.
Adapted from https://github.com/adcroft/interp_and_fill/blob/main/Interpolate%20and%20fill%20SeaWIFS.ipynb

Parameters
----------
field : numpy.ndarray
Input data containing missing data
mask : numpy.ndarray
Fill mask (0 or 1). Missing values are not filled where mask==1
top_bound : {"none", "tripole", "regular"}, optional
Connectivity across the northern boundary. "none" uses only the south, west and east
neighbours at the top row. "tripole" uses tripolar connectivity; "regular" uses a
regular lat/lon pole. Default is "none".

Returns
-------
numpy.ndarray
Data array with missing points filled.
"""

def _process_neighbour(n, jn, in_):
"""Process neighbour at (jn, in_) for row n."""
if mask[jn, in_] <= 0:
return

ld[n] -= 1
idx = ind[jn, in_]

if idx >= 0:
A[n, idx] = 1.0
else:
b[n] -= field[jn, in_]

nj, ni = field.shape

if top_bound not in ("none", "regular", "tripole"):
raise ValueError(
f"top_bound must be one of ('none', 'regular', 'tripole'), got {top_bound}"
)
if top_bound in ("regular", "tripole") and ni % 2 != 0:
raise ValueError(
f"top_bound='{top_bound}' requires an even number of longitude points, got {ni}"
)

missing_mask = np.isnan(field)
field = np.where(missing_mask, 0, field)

# Index lookup for missing points
missing_j, missing_i = np.where(missing_mask & (mask > 0))
n_missing = missing_j.size
ind = np.full(field.shape, -1, dtype=int)
ind[missing_j, missing_i] = np.arange(n_missing)

# Sparse matrix in LIL format (fast incremental building)
A = sp.lil_matrix((n_missing, n_missing))
b = np.zeros(n_missing)
ld = np.zeros(n_missing)

# Build matrix row-by-row
for n in range(n_missing):
j = missing_j[n]
i = missing_i[n]

im1 = (i - 1) % ni
ip1 = (i + 1) % ni
jm1 = j - 1 if j > 0 else 0
jp1 = j + 1 if j < nj - 1 else nj - 1

if j > 0:
_process_neighbour(n, jm1, i)
_process_neighbour(n, j, im1)
_process_neighbour(n, j, ip1)
if j < nj - 1:
_process_neighbour(n, jp1, i)

# Top boundary
if (top_bound != "none") and (j == nj - 1):
if top_bound == "tripole":
fold_i = ni - 1 - i
elif top_bound == "regular":
fold_i = (i + ni // 2) % ni
_process_neighbour(n, j, fold_i)

# Set leading diagonal
b[ld >= 0] = 0.0
stabilizer = 1e-14
diag_vals = ld - stabilizer
A[np.arange(n_missing), np.arange(n_missing)] = diag_vals

# Convert to CSR and solve
A = A.tocsr()
x = spla.spsolve(A, b)

# Fill the missing values
field[missing_j, missing_i] = x

return np.where(mask, field, np.nan)


def fill_ocean_horiz(da, top_bound="none", n_erode=0, erode_first=True):
"""
Fill missing ocean values using a sparse Laplacian solve.
The land mask is determined from Natural Earth v5.0.0 and can be eroded using n_erode.

Parameters
----------
da : xr.DataArray
Input data containing missing ocean data (horziontal slice)
top_bound : {"none", "tripole", "regular"}, optional
Connectivity across the northern boundary. "none" uses only the south, west and east
neighbours at the top row. "tripole" uses tripolar connectivity; "regular" uses a
regular lat/lon pole. Default is "none".
n_erode : int
The size of the structure used to erode the land mask. This can be useful for
ensuring there are values at near coastal points, where the land-sea mask may
differ between the input data and the model
erode_first : boolean
If False, fill missing values in two steps. First, fill the missing wet cells, then
the eroded land areas (if n_erode > 0). If this is done in one step, high values
near the coast have a larger weighting leading to larger values in high latitude
filled regions.

Returns
-------
xr.DataArray
Data array with missing ocean points filled
"""

land = (
regionmask.defined_regions.natural_earth_v5_0_0.land_10.mask(da).values == 0.0
)

# Remove any values on land so that they don't influence the fill of ocean values
da = da.where(np.logical_not(land))

if erode_first and (n_erode > 0):
land = ndimage.binary_erosion(land, structure=np.ones((n_erode, n_erode)))

da_filled = _fill_missing_horiz(da, 1.0 - land, top_bound=top_bound)

if (not erode_first) and (n_erode > 0):
land_eroded = ndimage.binary_erosion(
land, structure=np.ones((n_erode, n_erode))
)
da_filled = _fill_missing_horiz(
da_filled, 1.0 - land_eroded, top_bound=top_bound
)

return da.copy(data=da_filled)
Loading
Loading