# Source code for dpyverification.datasources.fewsnetcdf

"""Read and write NetCDF files in a fews compatible format."""

from collections.abc import Generator
from enum import StrEnum
from pathlib import Path
from typing import ClassVar, Self

import numpy as np
import xarray as xr
from numpy.typing import NDArray

from dpyverification.configuration.default.datasources import FewsNetCDFConfig, FewsNetCDFKind
from dpyverification.constants import (
    FORECAST_DATA_TYPES,
    DataType,
    StandardCoord,
    StandardDim,
)
from dpyverification.datasources.base import BaseDatasource

# Public names exported by this module; FewsNetCDFConfig is re-exported from
# dpyverification.configuration.default.datasources.
__all__ = ["FewsNetCDF", "FewsNetCDFConfig"]


class FewsNetcdfDims(StrEnum):
    """Dimension names as they appear in Delft-FEWS NetCDF files."""

    # Renamed to StandardDim.forecast_reference_time by Preprocessor.rename_to_internal.
    analysis_time = "analysis_time"
    # Swapped for StandardDim.station (and dropped) by Preprocessor.rename_to_internal.
    stations = "stations"


class FewsNetcdfCoord(StrEnum):
    """List of dimension names."""

    station_id = "station_id"
    station_names = "station_names"


class Preprocessor:
    """Per-file preprocessing step, used in ``xr.open_mfdataset(preprocess=preprocessor_instance)``.

    Calling an instance on a dataset decodes byte-string station ids, renames the Delft-FEWS
    dims/coords to the internal definitions, restructures the time dims for simulated
    forecasts per forecast reference time, and applies the optional variable, station and
    forecast period filters.
    """

    def __init__(
        self,
        fews_netcdf_kind: FewsNetCDFKind,
        filter_variables: list[str] | None = None,
        filter_stations: list[str] | None = None,
        filter_forecast_periods: list[np.timedelta64] | None = None,
    ) -> None:
        """Store the NetCDF kind and the optional filters.

        Parameters
        ----------
        fews_netcdf_kind : FewsNetCDFKind
            Layout of the FEWS NetCDF files. Decides whether the forecast time restructuring
            and the forecast period filter are applied.
        filter_variables : list[str] | None, optional
            Data variables to keep, by default None (keep all).
        filter_stations : list[str] | None, optional
            Station ids to keep, by default None (keep all).
        filter_forecast_periods : list[np.timedelta64] | None, optional
            Forecast periods to keep, only applied for simulated forecasts per forecast
            reference time, by default None (keep all).
        """
        self.fews_netcdf_kind = fews_netcdf_kind
        self.variables = filter_variables
        self.stations = filter_stations
        self.forecast_periods = filter_forecast_periods

    @staticmethod
    def convert_byte_string_coord_to_utf8(
        dataset: xr.Dataset,
        coords: list[FewsNetcdfCoord],
    ) -> xr.Dataset:
        """Decode byte-string values of the given coords to UTF-8 strings.

        Values that are not ``bytes`` are kept as is. The coords are replaced through item
        assignment on the passed dataset, which is also returned.
        """
        for coord in coords:
            dataset[coord] = xr.DataArray(
                [  # type:ignore[misc]
                    v.decode("utf-8") if isinstance(v, bytes) else v  # type:ignore[misc]
                    for v in dataset[coord].to_numpy()  # type:ignore[misc]
                ],
                dims=dataset[coord].dims,
            )
        return dataset

    @staticmethod
    def rename_to_internal(
        dataset: xr.Dataset,
    ) -> xr.Dataset:
        """Rename dims, coords to internal definition.

        - ``station_id`` becomes ``StandardCoord.station`` and ``station_names`` (when present)
          becomes ``StandardCoord.station_name``.
        - The ``stations`` dim is swapped for ``StandardDim.station`` and any leftover
          ``stations`` variable is dropped.
        - ``analysis_time`` (when present) becomes ``StandardDim.forecast_reference_time``.
        """
        # Rename station coords/dims. ``station_names`` is not always present, and
        # ``Dataset.rename`` raises for names that do not exist, so only add it when found.
        rename_map: dict[str, str] = {
            FewsNetcdfCoord.station_id: StandardCoord.station.name,  # type:ignore[misc]
        }
        if FewsNetcdfCoord.station_names in dataset:
            rename_map[FewsNetcdfCoord.station_names] = StandardCoord.station_name.name
        dataset = dataset.rename(rename_map)  # type:ignore[misc]

        # Swap stations dim and drop stations dim.
        dataset = dataset.swap_dims(
            {FewsNetcdfDims.stations: StandardDim.station},  # type:ignore[misc]
        )

        if FewsNetcdfDims.stations in dataset:
            # ``Dataset.drop`` is deprecated for dropping variables; use ``drop_vars``.
            dataset = dataset.drop_vars(FewsNetcdfDims.stations)

        # Only the case when retrieving full forecasts (per forecast reference time)
        if StandardCoord.station_name.name in dataset:
            dataset = dataset.set_coords(StandardCoord.station_name.name)

        # Rename analysis_time for simulations
        if FewsNetcdfDims.analysis_time in dataset:
            dataset = dataset.rename(
                {
                    FewsNetcdfDims.analysis_time: StandardDim.forecast_reference_time,  # type:ignore[misc]
                },
            )
        return dataset

    @staticmethod
    def filter_stations(
        dataset: xr.Dataset,
        stations: list[str],
    ) -> xr.Dataset:
        """Keep only the given station ids.

        The station dim is temporarily swapped for the station coord so stations can be
        selected by id with ``sel``; the original dim is restored afterwards.
        """
        swap_dict = {StandardDim.station: StandardCoord.station.name}
        dataset = dataset.swap_dims(swap_dict)
        dataset = dataset.sel({StandardCoord.station.name: stations})  # type:ignore[misc]
        return dataset.swap_dims(
            {StandardCoord.station.name: StandardDim.station},  # type:ignore[misc]
        )

    @staticmethod
    def set_internal_time_dims_on_forecast(
        dataset: xr.Dataset,
    ) -> xr.Dataset:
        """Transform the FEWS NetCDF time dims/coords to internal.

        The ``time`` dim is replaced by a ``forecast_period`` dim (time minus forecast
        reference time). ``time`` is then re-added as a 2d coord over
        (forecast_period, forecast_reference_time), and every data variable gets a
        forecast_reference_time dim.

        NOTE(review): the ``swapaxes(0, 1)`` below assumes forecast_reference_time is a 1d
        dim (one analysis_time per file) — confirm against the input files.
        """
        forecast_periods = (
            (dataset[StandardDim.time] - dataset[StandardDim.forecast_reference_time])
            # type:ignore[misc]
            .to_numpy()
            .ravel()
        )
        ds = dataset.assign_coords(
            {StandardDim.forecast_period: (StandardDim.time, forecast_periods)},  # type:ignore[misc]
        )

        ds = ds.swap_dims({StandardDim.time: StandardDim.forecast_period}).drop_vars(  # type:ignore[misc]
            StandardDim.time,
        )

        # Re-compute time as 2d matrix along forecast_period and forecast_reference_time
        time_index_2d = (
            (ds[StandardDim.forecast_reference_time] + ds[StandardDim.forecast_period])
            # type:ignore[misc]
            .to_numpy()
            .swapaxes(0, 1)
        )

        # Now assign time
        ds = ds.assign_coords(
            {
                StandardDim.time: (  # type:ignore[misc]
                    (StandardDim.forecast_period, StandardDim.forecast_reference_time),
                    time_index_2d,  # type:ignore[misc]
                ),
            },
        )

        # Set forecast reference time as dim on all vars
        for data_var in ds.data_vars:
            ds[data_var] = ds[data_var].expand_dims(
                {StandardDim.forecast_reference_time: ds[StandardDim.forecast_reference_time]},
            )

        return ds

    def __call__(self, dataset: xr.Dataset) -> xr.Dataset:
        """Apply the preprocessing steps to a single opened dataset and return the result."""
        # Decode byte-string coords
        dataset = Preprocessor.convert_byte_string_coord_to_utf8(
            dataset,
            coords=[FewsNetcdfCoord.station_id],
        )

        # Rename dims/coords to internal definitions
        dataset = Preprocessor.rename_to_internal(
            dataset,
        )

        if self.fews_netcdf_kind == FewsNetCDFKind.simulated_forecast_per_forecast_reference_time:
            # Transform to full info sim dataset
            dataset = self.set_internal_time_dims_on_forecast(
                dataset,
            )

        # Filter variables
        if self.variables is not None:
            dataset = dataset[self.variables]

        # Sometimes time_bnds is included in the dataset, but this is not expected in the internal
        # datamodel and can cause issues with alignment later on, so we drop it if it exists.
        if "time_bnds" in dataset:
            dataset = dataset.drop_vars("time_bnds")

        # Filter stations
        if self.stations is not None:
            dataset = self.filter_stations(dataset, self.stations)

        # Filter forecast periods for simulations
        if (
            self.forecast_periods is not None
            and self.fews_netcdf_kind
            == FewsNetCDFKind.simulated_forecast_per_forecast_reference_time
        ):
            # Filter the relevant forecast_periods to maximize memory efficiency
            selector = {StandardDim.forecast_period: self.forecast_periods}
            dataset = dataset.sel(selector)

        return dataset


def quantiles_to_cdf_data_array(
    sim: xr.DataArray,
    n_thresholds: int = 3000,
    padding_percentage: int = 5,
) -> xr.DataArray:
    """Create a cdf data array from Delft-FEWS NetCDF with quantiles.

    Verification metrics for CDFs may require a 'threshold' dimension to represent an array of
    thresholds of a continuous variable (like the crps_for_cdf function in the scores package).
    Here, the actual data variable represents probabilities of (non)-exceedance for the defined
    thresholds represented by the thresholds coordinate along the threshold dimension.

    In Delft-FEWS, however, such a datamodel is not available. Many users therefore use a standard
    FEWS NetCDF for ensembles, and use the realization dimension to represent probabilities (i.e.
    0.01, 0.02, ... 0.99) and the data variable (i.e. discharge) to represent the thresholds.

    This function converts a standard Delft-FEWS NetCDF to the internal input data structure.
    Because Delft-FEWS NetCDFs always only represent one forecast, the discretization of
    probabilities and thresholds is always relative to the range of values found in one
    specific forecast (i.e. with discharges between 100-200). Because the input to this function
    is a set of FEWS NetCDFs, we need to resample the thresholds so that we can keep an accurate
    representation of each individual forecast CDF in a set of forecasts, but not completely blow
    up the matrix. Given the min and max value found in the dataset (with multiple forecasts),
    n_thresholds will be created by interpolating between min and max. For the set of forecasts, a
    new coordinate will then be created given the interpolated values. In this way, a new threshold
    dimension is created that can represent all values of the continuous variable found in the set
    of forecasts.

    The realization coordinate values themselves are used as the probabilities, so unevenly
    spaced quantile levels (e.g. 0.05, 0.25, 0.5, 0.75, 0.95) are supported. When the largest
    realization value exceeds 1 (e.g. percentages 1-99), all values are divided by the smallest
    power of ten that brings them into [0, 1].

    Parameters
    ----------
    sim : xr.DataArray
        Forecast with realization dimension representing quantiles
    n_thresholds : int, optional
        Number of thresholds to use, by default 3000. Given the range of values found in the
        variable, interpolate between min-max with n_thresholds steps. For example if the range of
        discharge [m3/s] values found in the data is between 0-3000, the resolution of the threshold
        coordinate will correspond to 1 m3/s. Should be tuned according to the variable.
    padding_percentage : int, optional
        The % of padding to apply to the threshold coordinate, by default 5. This padding is applied
        so that the new threshold dimension (containing the variable values) has a wide enough range
        to also capture observed values during verification.


    Returns
    -------
    xr.DataArray
        A new data array with shared threshold dim and coords, compatible
        with the scores package (i.e. scores.probability.crps_cdf)

    Raises
    ------
    ValueError
        If there is no (or an empty) realization dimension, if the realization coordinate or
        a forecast CDF contains NaN (a forecast CDF that is entirely NaN is allowed) or
        decreasing values, or if the probabilities do not lie in [0, 1] after scaling.
    """

    def check_non_decreasing_and_not_nan(arr: NDArray) -> None:
        """Raise ValueError if the array contains NaN or is not non-decreasing."""
        if np.isnan(arr).any():  # type:ignore[misc]
            msg = "NaN values found in input CDF."
            raise ValueError(msg)
        # Check for non-decreasing order
        if not (arr[:-1] <= arr[1:]).all():  # type:ignore[misc]
            msg = "Decreasing values found in input CDF."
            raise ValueError(msg)

    if StandardDim.realization not in sim.dims:
        msg = "No realization dimension found in input CDF."
        raise ValueError(msg)

    realization_index = sim[StandardDim.realization].to_numpy().astype(float)  # type:ignore[misc]
    if realization_index.size == 0:  # type:ignore[misc]
        msg = "Empty realization dimension found in input CDF."
        raise ValueError(msg)
    check_non_decreasing_and_not_nan(realization_index)  # type:ignore[misc]

    # Probabilities should be between 0-1. However, in FEWS NetCDFs, users may define
    #   a different scale, such as between 0-100. Only when the largest value exceeds 1,
    #   scale down by the smallest power of ten that brings it into [0, 1]
    #   (i.e. 99 > 0.99, 100 > 1.0). Values already within [0, 1] are left untouched:
    #   scaling them too would inflate e.g. 0.05 to 0.5, and fail for an all-zero index.
    max_index_value = float(realization_index[-1])  # type:ignore[misc]
    scaling_factor: float = 1.0
    if max_index_value > 1:
        scaling_factor = float(10 ** np.ceil(np.log10(max_index_value)))
    probabilities = realization_index / scaling_factor  # type:ignore[misc]

    min_probability = float(probabilities[0])  # type:ignore[misc]
    max_probability = float(probabilities[-1])  # type:ignore[misc]

    if not (0 <= min_probability <= max_probability <= 1):
        msg = "Probabilities must lie in [0, 1] after scaling."
        raise ValueError(msg)

    # Get the min / max values
    vmin = float(sim.min())
    vmax = float(sim.max())

    # Apply padding to min / max
    width = vmax - vmin
    padded_vmin = vmin - (padding_percentage * 0.01 * width)
    padded_vmax = vmax + (padding_percentage * 0.01 * width)

    # Define the steps and threshold index, for new shared coordinate
    thresholds = np.linspace(padded_vmin, padded_vmax, n_thresholds)

    def interpolate_cdf(cdf: NDArray[np.floating]) -> NDArray:  # type:ignore[misc]
        """Interpolate one forecast CDF (values per quantile) onto the shared thresholds."""
        # If all NaN, return a NaN array
        if np.all(np.isnan(cdf)):  # type:ignore[misc]
            return np.full_like(thresholds, np.nan, dtype=float)  # type:ignore[misc]

        # If not all are NaN, require all not NaN and non-decreasing
        check_non_decreasing_and_not_nan(cdf)  # type:ignore[misc]
        # Use the actual (scaled) quantile levels; below the lowest / above the highest
        # quantile value the probability is 0 / 1.
        return np.interp(
            thresholds,
            cdf,  # type:ignore[misc]
            probabilities,  # type:ignore[misc]
            left=0.0,
            right=1.0,
        )

    result: xr.DataArray = xr.apply_ufunc(
        interpolate_cdf,  # type:ignore[misc]
        sim,
        input_core_dims=[[StandardDim.realization]],  # type:ignore[misc]
        output_core_dims=[["threshold"]],  # type:ignore[misc]
        vectorize=True,
        dask="parallelized",
        output_dtypes=[float],
        output_sizes={"threshold": len(thresholds)},  # type:ignore[misc]
    )

    result = result.assign_coords(threshold=("threshold", thresholds))
    result.attrs.update(  # type:ignore[misc]
        {"data_type": DataType.simulated_forecast_probabilistic},  # type:ignore[misc]
    )
    result.name = sim.name

    return result


def parse_forecast_period_netcdf_files(
    paths: Generator[Path, None, None],
) -> xr.Dataset:
    """Parse NetCDF responses from get timeseries with leadTimes parameter.

    Each file holds the data for a single forecast period, encoded in the filename prefix
    (milliseconds, followed by ``_``). The files are merged along a forecast_period dim,
    sorted by forecast period, renamed to the internal definitions, and get a 2d ``time``
    coord (forecast_reference_time + forecast_period).

    Parameters
    ----------
    paths : Generator[Path, None, None]
        Paths of the NetCDF files to parse.

    Returns
    -------
    xr.Dataset
        Merged dataset with forecast_reference_time and forecast_period dims.

    Raises
    ------
    ValueError
        If ``paths`` yields no files, or a filename prefix is not a decimal integer.
    """

    def preprocess(dataset: xr.Dataset) -> xr.Dataset:
        """
        Preprocess individual files, set forecast_period based on filename.

        When requesting data for a specific forecast period via the FEWS-Webservice,
        the actual forecast period used in the request is not available in the
        response. As a workaround, we prefix the filename with the forecast period
        in milliseconds, access it via the dataset encoding and set it as a dim/coord
        on the dataset.
        """
        filename = Path(dataset.encoding["source"]).name  # type:ignore[misc]

        forecast_period_millis = filename.split("_")[0]
        # isdecimal (unlike isalnum, which also accepts letters) guarantees int() succeeds.
        if not forecast_period_millis.isdecimal():
            msg = (
                "Filename prefix is expected to be a numeric representing the forecast period "
                f"(lead time) in milliseconds. The provided prefix '{forecast_period_millis}' is "
                "not numeric and cannot be converted to a valid forecast period."
            )
            raise ValueError(msg)

        forecast_period = np.timedelta64(int(forecast_period_millis), "ms").astype(
            "timedelta64[ns]",
        )
        # In these files ``time`` is the valid time; the reference time is time - period.
        forecast_reference_times = dataset[StandardDim.time] - forecast_period  # type:ignore[misc]

        # Set the station_name as coord instead of variable
        if FewsNetcdfCoord.station_names in dataset:
            dataset = dataset.set_coords(FewsNetcdfCoord.station_names)

        # Set forecast_reference_time dim / coord
        dataset = dataset.rename({StandardDim.time: StandardDim.forecast_reference_time})  # type:ignore[misc]
        dataset = dataset.assign_coords(
            {
                StandardDim.forecast_reference_time: (  # type:ignore[misc]
                    StandardDim.forecast_reference_time,
                    forecast_reference_times.to_numpy(),  # type:ignore[misc]
                ),
            },
        )

        # Set coord (expected for alignment of individual arrays)
        dataset = dataset.assign_coords(
            {FewsNetcdfDims.stations: dataset[FewsNetcdfCoord.station_id].to_numpy()},  # type:ignore[misc]
        )
        # Assign forecast_period as a dim/coord
        dataset = dataset.expand_dims(StandardDim.forecast_period)
        return dataset.assign_coords(
            {
                StandardCoord.forecast_period.name: (
                    StandardDim.forecast_period,
                    [forecast_period],  # type:ignore[misc]
                ),
            },
        )

    # Create one object
    dataset_list = [preprocess(xr.open_dataset(path)) for path in paths]
    if not dataset_list:
        msg = "No NetCDF files were provided to parse forecast periods from."
        raise ValueError(msg)
    dataset = xr.merge(dataset_list)

    # Sort forecast_period index
    dataset = dataset.sortby(StandardDim.forecast_period)

    # Decode byte-string coords
    dataset = Preprocessor.convert_byte_string_coord_to_utf8(
        dataset,
        coords=[FewsNetcdfCoord.station_id],
    )

    # Rename dims/coords to internal definitions
    dataset = Preprocessor.rename_to_internal(
        dataset,
    )

    # On resulting object, assign time (forecast_reference_time + forecast_period) as 2d coord
    return dataset.assign_coords(
        {  # type:ignore[misc]
            StandardDim.time: (  # type:ignore[misc]
                (StandardDim.forecast_reference_time, StandardDim.forecast_period),
                (
                    dataset[StandardDim.forecast_reference_time]
                    + dataset[StandardDim.forecast_period]
                ).to_numpy(),  # type:ignore[misc]
            ),
        },
    )


class FewsNetCDF(BaseDatasource):
    """For reading data from a NetCDF file produced by Delft-FEWS."""

    kind = "fewsnetcdf"
    config_class = FewsNetCDFConfig
    supported_data_types: ClassVar[set[DataType]] = {
        DataType.observed_historical,
        DataType.simulated_forecast_ensemble,
    }

    def __init__(self, config: FewsNetCDFConfig) -> None:
        """Store the datasource configuration."""
        self.config: FewsNetCDFConfig = config

    @staticmethod
    def convert_dataset_to_dataarray(
        dataset: xr.Dataset,
        source: str,
        data_type: DataType,
    ) -> xr.DataArray:
        """Transform dataset to internal datamodel.

        The data variables are stacked along a new ``variable`` dim (named ``source``). Each
        variable's ``units`` attribute ("unknown" when missing) is kept as a coord on that
        dim, and the dims are transposed to the internal order for forecasts or
        historical data.
        """

        # Extract the variable units from data variables
        def _get_unit(da: xr.DataArray) -> str:
            if "units" not in da.attrs:  # type:ignore[misc]
                return "unknown"
            return da.attrs["units"]  # type:ignore[no-any-return, misc]

        # Same order as the variables stacked by to_dataarray below.
        units = [_get_unit(dataset[name]) for name in dataset]

        # Stack the variables along dimension variable
        da = dataset.to_dataarray(dim=StandardDim.variable, name=source)

        # Set the configured data type as attribute
        da.attrs["data_type"] = data_type  # type:ignore[misc]

        # Set the station_id as index on station dim
        # to ensure automatic alignment based on this coord later on.
        da = da.assign_coords(
            {
                StandardDim.station: da[StandardCoord.station.name].to_numpy(),  # type:ignore[misc]
            },
        )

        # Set the units as auxiliary coordinate on new dimension variable
        da = da.assign_coords(
            {StandardCoord.units.name: (StandardDim.variable, units)},  # type:ignore[misc]
        )

        if data_type in FORECAST_DATA_TYPES:
            return da.transpose(
                StandardDim.variable,
                StandardDim.station,
                StandardDim.forecast_reference_time,
                StandardDim.forecast_period,
                ...,
            )
        # Historical simulations or observations
        return da.transpose(StandardDim.variable, StandardDim.station, StandardDim.time, ...)

    def fetch_data(self) -> Self:
        """Retrieve fewsnetcdf content as an xarray DataArray.

        The data is stored on ``self.data_array`` and ``self`` is returned.

        Raises
        ------
        ValueError
            If the configured data type / NetCDF kind combination is not supported.
        NotImplementedError
            If probabilistic forecasts contain more than one variable.
        """
        # Configure pre-processing
        preprocessor = Preprocessor(
            fews_netcdf_kind=self.config.netcdf_kind,
            filter_stations=self.config.station_ids,
            filter_forecast_periods=self.config.forecast_periods.timedelta64,
        )

        # The branches are mutually exclusive; without an explicit else an unsupported
        # configuration would leave ``dataset`` unbound (UnboundLocalError).
        if self.config.data_type == DataType.observed_historical:
            # Observations
            dataset = xr.open_mfdataset(
                self.config.paths,  # type:ignore[arg-type] # generator is acceptable argument
                preprocess=preprocessor,
            )
        elif (
            self.config.netcdf_kind
            == FewsNetCDFKind.simulated_forecast_per_forecast_reference_time
        ):
            # Simulations - per forecast reference time
            dataset = xr.open_mfdataset(
                self.config.paths,  # type:ignore[arg-type] # generator is acceptable argument
                combine="by_coords",
                preprocess=preprocessor,
                coords="minimal",
                compat="override",
            )
        elif self.config.netcdf_kind == FewsNetCDFKind.simulated_forecast_per_forecast_period:
            # Simulations - per forecast period
            # NOTE(review): the preprocessor's station / forecast period filters are not
            # applied on this path — confirm the files only contain the configured data.
            dataset = parse_forecast_period_netcdf_files(
                self.config.paths,
            )
        else:
            msg = (
                f"Unsupported combination of data_type '{self.config.data_type}' and "
                f"netcdf_kind '{self.config.netcdf_kind}'."
            )
            raise ValueError(msg)

        if self.config.data_type in FORECAST_DATA_TYPES:
            # After loading data into xr.Dataset, apply a filter on forecast reference time, based
            # on the configured verification period
            dataset = dataset.sel(
                {
                    StandardDim.forecast_reference_time: slice(  # type:ignore[misc]
                        self.config.verification_period_on_frt.start,
                        self.config.verification_period_on_frt.end,
                    ),
                },
            )
        else:
            # Filter historical data on time dim
            dataset = dataset.sel(
                {
                    StandardDim.time: slice(  # type:ignore[misc]
                        self.config.verification_period_on_time.start,
                        self.config.verification_period_on_time.end,
                    ),
                },
            )

        # Load into memory, in the future support dask
        dataset.load()

        # Convert datasets to data_array
        data_array = self.convert_dataset_to_dataarray(
            dataset,
            self.config.source,
            self.config.data_type,
        )

        # For probabilistic data types, transform the data array so that
        # all cdf's share the same threshold dim
        if self.config.data_type == DataType.simulated_forecast_probabilistic:
            if len(data_array[StandardDim.variable]) > 1:
                msg = "Multiple variables for simulated_forecast_probabilistic not yet supported"
                raise NotImplementedError(msg)
            data_array = quantiles_to_cdf_data_array(data_array)

        # Assign to self
        self.data_array = data_array
        return self