diff --git a/climate_api/data/datasets/era5_land.yaml b/climate_api/data/datasets/era5_land.yaml index 7f3b5a24..133d2a71 100644 --- a/climate_api/data/datasets/era5_land.yaml +++ b/climate_api/data/datasets/era5_land.yaml @@ -8,15 +8,20 @@ sync_availability: latest_available_function: climate_api.providers.availability.lagged_latest_available lag_hours: 120 - ingestion: + ingestion: function: dhis2eo.data.destine.era5_land.hourly.download default_params: variables: ['t2m'] + transforms: + - climate_api.transforms.convert_units units: kelvin convert_units: degC resolution: 9 km x 9 km source: ERA5-Land Reanalysis source_url: https://earthdatahub.destine.eu/collections/era5/datasets/reanalysis-era5-land + display: + colormap: rdbu_r + range: [15.0, 40.0] - id: era5land_precipitation_hourly name: Total precipitation (ERA5-Land) @@ -32,12 +37,18 @@ function: dhis2eo.data.destine.era5_land.hourly.download default_params: variables: ['tp'] - pre_process: ['deaccumulate_era5'] + transforms: + - climate_api.transforms.deaccumulate_era5 + - climate_api.transforms.convert_units units: m convert_units: mm resolution: 9 km x 9 km source: ERA5-Land Reanalysis source_url: https://earthdatahub.destine.eu/collections/era5/datasets/reanalysis-era5-land + display: + colormap: blues + range: [0.0, 5.0] + nodata: 0.0 - id: era5land_temperature_daily name: Daily mean 2m temperature (ERA5-Land) diff --git a/climate_api/data_manager/services/downloader.py b/climate_api/data_manager/services/downloader.py index 3c93772e..e294b144 100644 --- a/climate_api/data_manager/services/downloader.py +++ b/climate_api/data_manager/services/downloader.py @@ -143,6 +143,7 @@ def build_dataset_zarr(dataset: dict[str, Any], *, start: str | None = None, end dims = [lon_dim, lat_dim] ds = _select_time_range(ds, dataset=dataset, start=start, end=end) + ds = _run_transforms(ds, dataset) xmin = ds[lon_dim].min().item() xmax = ds[lon_dim].max().item() @@ -243,6 +244,16 @@ def _select_time_range( return 
selected
 
 
+def _run_transforms(ds: xr.Dataset, dataset: dict[str, Any]) -> xr.Dataset:
+    for entry in dataset.get("transforms", []):
+        func_path = entry if isinstance(entry, str) else entry["function"]
+        params = {} if isinstance(entry, str) else entry.get("params", {})
+        func = _get_dynamic_function(func_path)
+        logger.info("Applying transform %s to dataset %s", func_path, dataset.get("id", "?"))
+        ds = func(ds, dataset, **params)
+    return ds
+
+
 def _compute_time_space_chunks(
     ds: xr.Dataset,
     dataset: dict[str, Any],
diff --git a/src/climate_api/transforms/__init__.py b/src/climate_api/transforms/__init__.py
new file mode 100644
index 00000000..e5988537
--- /dev/null
+++ b/src/climate_api/transforms/__init__.py
@@ -0,0 +1,13 @@
+"""Built-in dataset transform functions for the transforms pipeline.
+
+Each function has the signature:
+    (ds: xr.Dataset, dataset: dict[str, Any]) -> xr.Dataset
+
+Functions can be referenced by their dotted module path in the dataset YAML
+``transforms`` list, the same way ``ingestion.function`` works.
+"""
+
+from .deaccumulate import deaccumulate_era5
+from .unit_conversion import convert_units
+
+__all__ = ["convert_units", "deaccumulate_era5"]
diff --git a/src/climate_api/transforms/deaccumulate.py b/src/climate_api/transforms/deaccumulate.py
new file mode 100644
index 00000000..fd46dec8
--- /dev/null
+++ b/src/climate_api/transforms/deaccumulate.py
@@ -0,0 +1,22 @@
+"""Deaccumulation transforms for ERA5 accumulated fields."""
+
+from typing import Any
+
+import xarray as xr
+
+
+def deaccumulate_era5(ds: xr.Dataset, dataset: dict[str, Any]) -> xr.Dataset:
+    """Convert ERA5 accumulated fields to per-step values by forward differencing.
+
+    ERA5 stores precipitation and other flux variables as accumulations from the
+    start of the forecast step. This subtracts consecutive steps so each value
+    represents the amount in that step alone, then clips negative artefacts.
+    """
+    varname = dataset["variable"]
+    da = ds[varname]
+    time_dim = next(d for d in da.dims if "time" in d)
+    diff = da.diff(dim=time_dim)
+    diff = diff.clip(min=0)
+    # Drop the first step positionally (nothing to diff against); isel needs no time coordinate.
+    ds = ds.isel({time_dim: slice(1, None)})
+    return ds.assign({varname: diff.assign_attrs(da.attrs)})
diff --git a/src/climate_api/transforms/unit_conversion.py b/src/climate_api/transforms/unit_conversion.py
new file mode 100644
index 00000000..65045ff8
--- /dev/null
+++ b/src/climate_api/transforms/unit_conversion.py
@@ -0,0 +1,39 @@
+"""Unit conversion transform: scale + offset applied to the dataset variable."""
+
+import logging
+from typing import Any
+
+import xarray as xr
+
+logger = logging.getLogger(__name__)
+
+# (from_units, to_units) -> (display_label, scale, offset)
+# Applied as: converted = original * scale + offset
+_CONVERSIONS: dict[tuple[str, str], tuple[str, float, float]] = {
+    ("kelvin", "degc"): ("degC", 1.0, -273.15),
+    ("m", "mm"): ("mm", 1000.0, 0.0),
+}
+
+
+def convert_units(ds: xr.Dataset, dataset: dict[str, Any]) -> xr.Dataset:
+    """Convert the dataset variable from ``units`` to ``convert_units``.
+
+    Reads ``units`` and ``convert_units`` from the dataset template dict.
+    Returns the dataset unchanged if either field is absent or the conversion
+    is not registered in ``_CONVERSIONS``.
+    """
+    convert_to = dataset.get("convert_units")
+    if not convert_to:
+        return ds
+    units = dataset.get("units", "")
+    key = (units.lower(), convert_to.lower())
+    conversion = _CONVERSIONS.get(key)
+    if conversion is None:
+        logger.warning("No unit conversion registered for %s -> %s; skipping", units, convert_to)
+        return ds
+    label, scale, offset = conversion
+    varname = dataset["variable"]
+    logger.info("Converting %s from %s to %s", varname, units, label)
+    da = ds[varname]
+    converted = da * scale + offset if scale != 1.0 else da + offset
+    return ds.assign({varname: converted.assign_attrs({**da.attrs, "units": label})})
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
new file mode 100644
index 00000000..d1ab1dd6
--- /dev/null
+++ b/tests/test_transforms.py
@@ -0,0 +1,101 @@
+import numpy as np
+import xarray as xr
+
+from climate_api.transforms import convert_units, deaccumulate_era5
+
+
+def _ds(varname: str, values: list[float], time_steps: int = 1) -> xr.Dataset:
+    if time_steps > 1:
+        data = np.array(values, dtype=float).reshape(time_steps, -1)
+        return xr.Dataset({varname: xr.DataArray(data, dims=["time", "x"])})
+    return xr.Dataset({varname: xr.DataArray(np.array(values, dtype=float))})
+
+
+class TestConvertUnits:
+    def test_kelvin_to_celsius(self):
+        ds = _ds("t2m", [273.15, 293.15, 313.15])
+        result = convert_units(ds, {"variable": "t2m", "units": "kelvin", "convert_units": "degC"})
+        np.testing.assert_allclose(result["t2m"].values, [0.0, 20.0, 40.0])
+        assert result["t2m"].attrs["units"] == "degC"
+
+    def test_metres_to_mm(self):
+        ds = _ds("tp", [0.001, 0.005])
+        result = convert_units(ds, {"variable": "tp", "units": "m", "convert_units": "mm"})
+        np.testing.assert_allclose(result["tp"].values, [1.0, 5.0])
+        assert result["tp"].attrs["units"] == "mm"
+
+    def test_no_convert_units_field_is_noop(self):
+        ds = _ds("t2m", [300.0])
+        result = convert_units(ds, {"variable": "t2m", "units": "kelvin"})
+        np.testing.assert_array_equal(result["t2m"].values, ds["t2m"].values)
+
+    def test_unknown_conversion_is_noop(self):
+        ds = _ds("x", [1.0])
+        result = convert_units(ds, {"variable": "x", "units": "foo", "convert_units": "bar"})
+        np.testing.assert_array_equal(result["x"].values, ds["x"].values)
+
+    def test_preserves_existing_attrs(self):
+        ds = xr.Dataset({"t2m": xr.DataArray([300.0], attrs={"long_name": "temperature", "units": "K"})})
+        result = convert_units(ds, {"variable": "t2m", "units": "kelvin", "convert_units": "degC"})
+        assert result["t2m"].attrs["long_name"] == "temperature"
+
+
+class TestDeaccumulateEra5:
+    def test_differences_along_time(self):
+        ds = _ds("tp", [0.0, 1.0, 3.0, 6.0], time_steps=4)
+        result = deaccumulate_era5(ds, {"variable": "tp"})
+        assert result.sizes["time"] == 3
+        np.testing.assert_array_equal(result["tp"].values.flatten(), [1.0, 2.0, 3.0])
+
+    def test_clips_negative_values(self):
+        ds = _ds("tp", [3.0, 1.0, 4.0], time_steps=3)
+        result = deaccumulate_era5(ds, {"variable": "tp"})
+        assert (result["tp"].values >= 0).all()
+
+    def test_preserves_attrs(self):
+        data = np.array([[0.0], [1.0]])
+        ds = xr.Dataset({"tp": xr.DataArray(data, dims=["time", "x"], attrs={"units": "m"})})
+        result = deaccumulate_era5(ds, {"variable": "tp"})
+        assert result["tp"].attrs["units"] == "m"
+
+
+class TestRunTransformsPipeline:
+    def test_pipeline_via_dotted_path(self):
+        ds = _ds("t2m", [273.15])
+        dataset = {
+            "variable": "t2m",
+            "units": "kelvin",
+            "convert_units": "degC",
+            "transforms": ["climate_api.transforms.convert_units"],
+        }
+        from climate_api.data_manager.services.downloader import _run_transforms
+
+        result = _run_transforms(ds, dataset)
+        np.testing.assert_allclose(result["t2m"].values, [0.0])
+
+    def test_empty_transforms_is_noop(self):
+        ds = _ds("x", [1.0, 2.0])
+        from climate_api.data_manager.services.downloader import _run_transforms
+
+        result = _run_transforms(ds, {"variable": "x", "transforms": []})
+        np.testing.assert_array_equal(result["x"].values, ds["x"].values)
+
+    def test_no_transforms_key_is_noop(self):
+        ds = _ds("x", [1.0])
+        from climate_api.data_manager.services.downloader import _run_transforms
+
+        result = _run_transforms(ds, {"variable": "x"})
+        np.testing.assert_array_equal(result["x"].values, ds["x"].values)
+
+    def test_dict_entry_with_params(self):
+        ds = _ds("t2m", [273.15])
+        dataset = {
+            "variable": "t2m",
+            "units": "kelvin",
+            "convert_units": "degC",
+            "transforms": [{"function": "climate_api.transforms.convert_units"}],
+        }
+        from climate_api.data_manager.services.downloader import _run_transforms
+
+        result = _run_transforms(ds, dataset)
+        np.testing.assert_allclose(result["t2m"].values, [0.0])