Source code for hydromt.data_catalog.drivers.raster.raster_xarray_driver

"""RasterDatasetDriver for zarr data."""

from copy import copy
from functools import partial
from logging import Logger, getLogger
from os.path import splitext
from typing import Callable, List, Optional

import xarray as xr

from hydromt._typing import (
    Geom,
    SourceMetadata,
    StrPath,
    TimeRange,
    Variables,
    Zoom,
)
from hydromt._typing.error import NoDataStrategy, exec_nodata_strat
from hydromt._utils.unused_kwargs import _warn_on_unused_kwargs
from hydromt.data_catalog.drivers._preprocessing import PREPROCESSORS
from hydromt.data_catalog.drivers.raster.raster_dataset_driver import (
    RasterDatasetDriver,
)

logger: Logger = getLogger(__name__)


[docs] class RasterDatasetXarrayDriver(RasterDatasetDriver): """RasterDatasetXarrayDriver.""" name = "raster_xarray" supports_writing = True
[docs] def read( self, uris: List[str], *, mask: Optional[Geom] = None, variables: Optional[Variables] = None, time_range: Optional[TimeRange] = None, zoom: Optional[Zoom] = None, metadata: Optional[SourceMetadata] = None, handle_nodata: NoDataStrategy = NoDataStrategy.RAISE, # TODO: https://github.com/Deltares/hydromt/issues/802 ) -> xr.Dataset: """ Read zarr data to an xarray DataSet. Args: """ _warn_on_unused_kwargs( self.__class__.__name__, { "mask": mask, "time_range": time_range, "variables": variables, "zoom": zoom, "metadata": metadata, }, ) options = copy(self.options) preprocessor: Optional[Callable] = None preprocessor_name: Optional[str] = options.pop("preprocess", None) if preprocessor_name: preprocessor = PREPROCESSORS.get(preprocessor_name) if not preprocessor: raise ValueError(f"unknown preprocessor: '{preprocessor_name}'") ext_override: Optional[str] = options.pop("ext_override", None) if ext_override is not None: first_ext: str = ext_override else: first_ext: str = splitext(uris[0])[-1] if first_ext == ".zarr": opn: Callable = partial(xr.open_zarr, **options) datasets = [] for _uri in uris: ext = splitext(_uri)[-1] if ext != first_ext and not ext_override: logger.warning(f"Reading zarr and {_uri} was not, skipping...") continue else: datasets.append( preprocessor(opn(_uri)) if preprocessor else opn(_uri) ) ds: xr.Dataset = xr.merge(datasets) elif first_ext in [".nc", ".netcdf"]: filtered_uris = [] for _uri in uris: ext = splitext(_uri)[-1] if ext != first_ext: logger.warning(f"Reading netcdf and {_uri} was not, skipping...") continue else: filtered_uris.append(_uri) ds: xr.Dataset = xr.open_mfdataset( filtered_uris, decode_coords="all", preprocess=preprocessor, **options ) else: raise ValueError( f"Unknown extension for RasterDatasetXarrayDriver: {first_ext} " ) for variable in ds.data_vars: if ds[variable].size == 0: exec_nodata_strat( f"No data from driver: '{self.name}' for variable: '{variable}'", strategy=handle_nodata, ) return ds
[docs] def write(self, path: StrPath, ds: xr.Dataset, **kwargs) -> str: """ Write the RasterDataset to a local file using zarr. args: returns: str with written uri """ no_ext, ext = splitext(path) # set filepath if incompat if ext not in {".zarr", ".nc", ".netcdf"}: logger.warning( f"Unknown extension for RasterDatasetXarrayDriver: {ext}," "switching to zarr" ) path = no_ext + ".zarr" ext = ".zarr" if ext == ".zarr": ds.to_zarr(path, mode="w", **kwargs) else: ds.to_netcdf(path, **kwargs) return str(path)