# Source code for hydromt.data_adapter.dataset

"""Implementation for the dataset DataAdapter."""
from datetime import datetime
from logging import Logger, getLogger
from os.path import basename, splitext
from typing import List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import xarray as xr
from pystac import Asset as StacAsset
from pystac import Catalog as StacCatalog
from pystac import Item as StacItem
from pystac import MediaType

from hydromt.exceptions import NoDataException
from hydromt.nodata import NoDataStrategy, _exec_nodata_strat
from hydromt.typing import Data, ErrorHandleMethod, StrPath, TimeRange, Variables
from hydromt.utils import has_no_data

from .data_adapter import DataAdapter
from .utils import netcdf_writer, shift_dataset_time, zarr_writer

logger = getLogger(__name__)


class DatasetAdapter(DataAdapter):
    """DatasetAdapter for non-spatial n-dimensional data."""

    _DEFAULT_DRIVER = ""
    _DRIVERS = {
        "nc": "netcdf",
    }

    def __init__(
        self,
        path: StrPath,
        driver: Optional[str] = None,
        filesystem: Optional[str] = None,
        nodata: Optional[Union[dict, float, int]] = None,
        rename: Optional[dict] = None,
        unit_mult: Optional[dict] = None,
        unit_add: Optional[dict] = None,
        meta: Optional[dict] = None,
        attrs: Optional[dict] = None,
        driver_kwargs: Optional[dict] = None,
        storage_options: Optional[dict] = None,
        name: str = "",
        catalog_name: str = "",
        provider: Optional[str] = None,
        version: Optional[str] = None,
    ):
        """Initiate data adapter for n-dimensional timeseries data.

        This object contains all properties required to read supported files
        (netcdf, zarr) into a single unified Dataset, i.e.
        :py:class:`xarray.Dataset`. In addition it keeps meta data to be able
        to reproduce which data is used.

        Parameters
        ----------
        path: str, Path
            Path to data source. If the dataset consists of multiple files, the
            path may contain {variable}, {year}, {month} placeholders as well
            as path search pattern using a ``*`` wildcard.
        driver: {'netcdf', 'zarr'}, optional
            Driver to read files with, for 'netcdf'
            :py:func:`xarray.open_mfdataset` and for 'zarr'
            :py:func:`xarray.open_zarr`. By default the driver is inferred
            from the file extension.
        filesystem: str, optional
            Filesystem where the data is stored (local, cloud, http etc.).
            If None (default) the filesystem is inferred from the path.
            See :py:func:`fsspec.registry.known_implementations` for all
            options.
        nodata: float, int, optional
            Missing value number. Only used if the data has no native missing
            value. Nodata values can be differentiated between variables using
            a dictionary.
        rename: dict, optional
            Mapping of native data source variable to output source variable
            name as required by hydroMT.
        unit_mult, unit_add: dict, optional
            Scaling multiplication and addition to change to map from the
            native data unit to the output data unit as required by hydroMT.
            A "time" entry in ``unit_add`` shifts the time coordinate.
        meta: dict, optional
            Metadata information of dataset, preferably containing the
            following keys: 'source_version', 'source_url', 'source_license',
            'paper_ref', 'paper_doi', 'category'.
        attrs: dict, optional
            Additional attributes relating to data variables. For instance
            unit or long name of the variable.
        driver_kwargs: dict, optional
            Additional key-word arguments passed to the driver.
        storage_options: dict, optional
            Additional key-word arguments passed to the fsspec FileSystem
            object.
        name, catalog_name: str, optional
            Name of the dataset and catalog, optional.
        provider: str, optional
            A name to identify the specific provider of the dataset requested.
            If None is provided, the last added source will be used.
        version: str, optional
            A name to identify the specific version of the dataset requested.
            If None is provided, the last added source will be used.
        """
        super().__init__(
            path=path,
            driver=driver,
            filesystem=filesystem,
            nodata=nodata,
            rename=rename,
            unit_mult=unit_mult,
            unit_add=unit_add,
            meta=meta,
            attrs=attrs,
            driver_kwargs=driver_kwargs,
            storage_options=storage_options,
            name=name,
            catalog_name=catalog_name,
            provider=provider,
            version=version,
        )

    def to_file(
        self,
        data_root: StrPath,
        data_name: str,
        time_tuple: Optional[TimeRange] = None,
        variables: Optional[List[str]] = None,
        driver: Optional[str] = None,
        handle_nodata: NoDataStrategy = NoDataStrategy.RAISE,
        **kwargs,
    ) -> Optional[Tuple[str, str]]:
        """Save a dataset slice to file.

        By default the data is saved as a NetCDF file.

        Parameters
        ----------
        data_root : str, Path
            Path to output folder.
        data_name : str
            Name of output file without extension.
        time_tuple : tuple of str, datetime, optional
            Start and end date of period of interest. By default the entire
            time period of the dataset is returned.
        variables : list of str, optional
            Names of Dataset variables to return. By default all dataset
            variables are returned.
        driver : {'netcdf', 'zarr'}, optional
            Driver to write file with. By default None, which writes NetCDF.
        handle_nodata : NoDataStrategy, optional
            Strategy applied when no data is found, by default raise.
        **kwargs
            Additional keyword arguments passed to ``zarr_writer``; ignored
            for NetCDF output.

        Returns
        -------
        fn_out: str
            Absolute path to output file.
        driver: str
            Name of driver to read data with ('netcdf' or 'zarr'), see
            :py:func:`~hydromt.data_catalog.DataCatalog.get_dataset`.
        None
            Returned instead of the tuple when no data was found and
            ``handle_nodata`` does not raise.

        Raises
        ------
        ValueError
            If ``driver`` is not 'netcdf' or 'zarr'.
        """
        # NetCDF is the default output format; report it explicitly so the
        # returned driver can be used to read the written file back.
        if driver is None:
            driver = "netcdf"
        # Validate before reading so an invalid driver does not load any data.
        if driver not in ("netcdf", "zarr"):
            raise ValueError(f"Dataset: Driver {driver} unknown.")

        obj = self.get_data(
            time_tuple=time_tuple,
            variables=variables,
            single_var_as_array=variables is None,
            handle_nodata=handle_nodata,
            logger=logger,
        )
        if obj is None:
            return None

        if driver == "netcdf":
            fn_out = netcdf_writer(
                obj=obj,
                data_root=data_root,
                data_name=data_name,
                variables=variables,
            )
        else:  # zarr
            fn_out = zarr_writer(
                obj=obj, data_root=data_root, data_name=data_name, **kwargs
            )
        return fn_out, driver

    def get_data(
        self,
        variables: Optional[Variables] = None,
        time_tuple: Optional[TimeRange] = None,
        single_var_as_array: Optional[bool] = True,
        handle_nodata: NoDataStrategy = NoDataStrategy.RAISE,
        logger: Logger = logger,
    ):
        """Return a clipped, sliced and unified Dataset.

        For a detailed description see:
        :py:func:`~hydromt.data_catalog.DataCatalog.get_dataset`

        Parameters
        ----------
        variables : str or list of str, optional
            Names of variables to return. By default all variables.
        time_tuple : tuple of str, datetime, optional
            Start and end date of period of interest.
        single_var_as_array : bool, optional
            Passed to ``_single_var_as_array`` to decide whether a
            single-variable result is returned as a DataArray.
        handle_nodata : NoDataStrategy, optional
            Strategy applied (via ``_exec_nodata_strat``) when reading or
            slicing yields no data.
        logger : Logger, optional
            Logger used for progress messages.

        Returns
        -------
        xarray.Dataset, xarray.DataArray or None
            The processed data, or None when no data was found and the
            nodata strategy did not raise.
        """
        try:
            # load data
            fns = self._resolve_paths(variables, time_tuple)
            self.mark_as_used()
            ds = self._read_data(fns, logger=logger)
            if ds is None:
                raise NoDataException()
            # rename variables and parse data and attrs
            ds = self._rename_vars(ds)
            ds = self._set_nodata(ds)
            ds = self._shift_time(ds, logger=logger)
            # slice
            ds = DatasetAdapter._slice_data(ds, variables, time_tuple, logger=logger)
            if ds is None:
                raise NoDataException()
            # uniformize
            ds = self._apply_unit_conversion(ds, logger=logger)
            ds = self._set_metadata(ds)
            # return array if single var and single_var_as_array
            return self._single_var_as_array(ds, single_var_as_array, variables)
        except NoDataException:
            _exec_nodata_strat(
                "No data to export", strategy=handle_nodata, logger=logger
            )
            return None

    def _read_data(
        self,
        fns: List[StrPath],
        logger: Logger = logger,
    ) -> Optional[Data]:
        """Open ``fns`` with the configured driver.

        Returns None when the opened data has no data (per ``has_no_data``).
        Raises ValueError for multiple zarr files or an unknown driver, and
        NotImplementedError for the 'mfcsv' driver.
        """
        kwargs = self.driver_kwargs.copy()
        if len(fns) > 1 and self.driver in ["zarr"]:
            raise ValueError(
                f"Dataset: Reading multiple {self.driver} files is not supported."
            )
        logger.info(f"Reading {self.name} {self.driver} data from {self.path}")
        if self.driver in ["netcdf"]:
            ds = xr.open_mfdataset(fns, **kwargs)
        elif self.driver == "zarr":
            ds = xr.open_zarr(fns[0], **kwargs)
        elif self.driver == "mfcsv":
            raise NotImplementedError
        else:
            raise ValueError(f"Dataset: Driver {self.driver} unknown")
        if has_no_data(ds):
            return None
        return ds

    def _rename_vars(self, ds: Data) -> Data:
        """Rename variables per ``self.rename``, skipping names not in ``ds``."""
        rm = {k: v for k, v in self.rename.items() if k in ds}
        return ds.rename(rm)

    def _set_metadata(self, ds: Data) -> Data:
        """Attach per-variable ``self.attrs`` and global ``self.meta`` to ``ds``."""
        if self.attrs:
            if isinstance(ds, xr.DataArray):
                ds.attrs.update(self.attrs.get(ds.name, {}))
            else:
                for k, var_attrs in self.attrs.items():
                    # attrs may describe variables that were dropped by the
                    # variable selection; skip those instead of failing.
                    if k in ds:
                        ds[k].attrs.update(var_attrs)
        ds.attrs.update(self.meta)
        return ds

    def _set_nodata(self, ds: Data) -> Data:
        """Set ``_FillValue`` from ``self.nodata`` on variables lacking one.

        A scalar ``self.nodata`` applies to every data variable; a dict maps
        variable names to their nodata value. Existing ``_FillValue``
        attributes are never overwritten.
        """
        if self.nodata is not None:
            if not isinstance(self.nodata, dict):
                nodata = {k: self.nodata for k in ds.data_vars.keys()}
            else:
                nodata = self.nodata
            for k in ds.data_vars:
                mv = nodata.get(k, None)
                if mv is not None and ds[k].attrs.get("_FillValue", None) is None:
                    ds[k].attrs["_FillValue"] = mv
        return ds

    def _apply_unit_conversion(self, ds: Data, logger: Logger = logger) -> Data:
        """Apply ``value * unit_mult + unit_add`` to listed data variables.

        Nodata cells (NaN, or equal to ``_FillValue`` when it is set and not
        NaN) are left at their nodata value. Original attributes are kept.
        """
        unit_names = list(self.unit_mult.keys()) + list(self.unit_add.keys())
        unit_names = [k for k in unit_names if k in ds.data_vars]
        if len(unit_names) > 0:
            logger.debug(f"Convert units for {len(unit_names)} variables.")
        for name in list(set(unit_names)):  # unique
            m = self.unit_mult.get(name, 1)
            a = self.unit_add.get(name, 0)
            da = ds[name]
            attrs = da.attrs.copy()
            fill_value = da.attrs.get("_FillValue", None)
            nodata_isnan = fill_value is None or np.isnan(fill_value)
            # nodata value is explicitly set to NaN in case no nodata value is provided
            nodata = np.nan if nodata_isnan else fill_value
            data_bool = ~np.isnan(da) if nodata_isnan else da != nodata
            ds[name] = xr.where(data_bool, da * m + a, nodata)
            ds[name].attrs.update(attrs)  # set original attributes
        return ds

    def _shift_time(self, ds: Data, logger: Logger = logger) -> Data:
        """Shift the time coordinate by ``unit_add['time']`` (default 0)."""
        dt = self.unit_add.get("time", 0)
        return shift_dataset_time(dt=dt, ds=ds, logger=logger)

    @staticmethod
    def _slice_data(
        ds: Data,
        variables: Optional[Variables] = None,
        time_tuple: Optional[TimeRange] = None,
        logger: Logger = logger,
    ) -> Optional[Data]:
        """Select variables and slice the dataset in time.

        Arguments
        ---------
        ds : xarray.Dataset or xarray.DataArray
            The Dataset to slice. A DataArray is converted to a Dataset
            (named "data" if unnamed); ``variables`` is then not applied.
        variables : str or list of str, optional.
            Names of variables to return.
        time_tuple : tuple of str, datetime, optional
            Start and end date of period of interest. By default the entire
            time period of the dataset is returned.

        Returns
        -------
        ds : xarray.Dataset or None
            The sliced Dataset, or None if the result has no data.

        Raises
        ------
        ValueError
            If requested variables are missing from the dataset.
        IndexError
            If the time slice selects no time steps.
        """
        if isinstance(ds, xr.DataArray):
            if ds.name is None:
                ds.name = "data"
            ds = ds.to_dataset()
        elif variables is not None:
            variables = np.atleast_1d(variables).tolist()
            # A single requested variable on a single-variable dataset is not
            # checked by name, so the selection is skipped in that case.
            if len(variables) > 1 or len(ds.data_vars) > 1:
                missing = [var for var in variables if var not in ds.data_vars]
                if missing:
                    raise ValueError(f"Dataset: variables not found {missing}")
                ds = ds[variables]
        if time_tuple is not None:
            ds = DatasetAdapter._slice_temporal_dimension(ds, time_tuple, logger=logger)
        if has_no_data(ds):
            return None
        return ds

    @staticmethod
    def _slice_temporal_dimension(
        ds: Data,
        time_tuple: TimeRange,
        logger: Logger = logger,
    ) -> Optional[Data]:
        """Select ``time_tuple`` on a datetime ``time`` dim with more than one step.

        Raises IndexError when the selection is empty; returns None when the
        result has no data.
        """
        if (
            "time" in ds.dims
            and ds["time"].size > 1
            and np.issubdtype(ds["time"].dtype, np.datetime64)
        ):
            logger.debug(f"Slicing time dim {time_tuple}")
            ds = ds.sel(time=slice(*time_tuple))
            if ds.time.size == 0:
                raise IndexError("Dataset: Time slice out of range.")
        if has_no_data(ds):
            return None
        return ds

    def get_time_range(
        self,
        ds: Optional[Data] = None,
    ) -> TimeRange:
        """Get the temporal range of a dataset.

        Parameters
        ----------
        ds : Optional[xr.DataArray | xr.Dataset]
            The dataset to detect the time range of. It must have a time
            dimension set. If none is provided,
            :py:meth:`hydromt.DatasetAdapter.get_data` will be used to fetch
            it before detecting.

        Returns
        -------
        range: Tuple[np.datetime64, np.datetime64]
            A tuple containing the first and last value of the time
            coordinate. Range is inclusive on both sides.

        Raises
        ------
        AttributeError
            If the data has no ``time`` attribute.
        """
        if ds is None:
            ds = self.get_data()
        try:
            return (ds.time[0].values, ds.time[-1].values)
        except AttributeError as err:
            raise AttributeError("Dataset has no dimension called 'time'") from err

    def to_stac_catalog(self, on_error: ErrorHandleMethod = ErrorHandleMethod.COERCE):
        """
        Convert a dataset into a STAC Catalog representation.

        The catalog contains a single item with one asset for the data path.

        Parameters
        ----------
        - on_error (str, optional): The error handling strategy when detecting
          the temporal extent fails (IndexError/KeyError). Options are:
          "raise" to raise the error, "skip" to skip the dataset, and
          "coerce" (default) to set default values.

        Returns
        -------
        - Optional[StacCatalog]: The STAC Catalog representation of the dataset,
          or None if the dataset was skipped.

        Raises
        ------
        RuntimeError
            If the path is a zarr store or has an unrecognized extension,
            regardless of ``on_error``.
        """
        try:
            start_dt, end_dt = self.get_time_range()
            start_dt = pd.to_datetime(start_dt)
            end_dt = pd.to_datetime(end_dt)
            props = {**self.meta}
            ext = splitext(self.path)[-1]
            bbox = [0.0, 0.0, 0.0, 0.0]
            if ext == ".nc":
                media_type = MediaType.HDF5
            elif ext == ".zarr":
                raise RuntimeError("STAC does not support zarr datasets")
            else:
                raise RuntimeError(
                    f"Unknown extention: {ext} cannot determine media type"
                )
        except (IndexError, KeyError):
            if on_error == ErrorHandleMethod.SKIP:
                logger.warning(
                    f"Skipping {self.name} during stac conversion because "
                    "detecting temporal extent failed."
                )
                return None
            elif on_error == ErrorHandleMethod.COERCE:
                bbox = [0.0, 0.0, 0.0, 0.0]
                # copy, as in the success path, so STAC edits do not leak
                # back into self.meta
                props = {**self.meta}
                start_dt = datetime(1, 1, 1)
                end_dt = datetime(1, 1, 1)
                media_type = MediaType.JSON
            else:
                raise

        stac_catalog = StacCatalog(self.name, description=self.name)
        stac_item = StacItem(
            self.name,
            geometry=None,
            bbox=bbox,
            properties=props,
            datetime=None,
            start_datetime=start_dt,
            end_datetime=end_dt,
        )
        stac_asset = StacAsset(str(self.path), media_type=media_type)
        base_name = basename(self.path)
        stac_item.add_asset(base_name, stac_asset)

        stac_catalog.add_item(stac_item)
        return stac_catalog