Source code for hydromt.data_catalog.drivers.raster.raster_xarray_driver

"""RasterDatasetDriver for zarr and netCDF data, read and written with xarray."""

import logging
from functools import partial
from os.path import splitext
from pathlib import Path
from typing import Any, Callable, ClassVar

import xarray as xr
from pydantic import Field

from hydromt._utils.unused_kwargs import _warn_on_unused_kwargs
from hydromt.data_catalog.drivers.base_driver import (
    DRIVER_OPTIONS_DESCRIPTION,
    DriverOptions,
)
from hydromt.data_catalog.drivers.preprocessing import Preprocessor, get_preprocessor
from hydromt.data_catalog.drivers.raster.raster_dataset_driver import (
    RasterDatasetDriver,
)
from hydromt.error import NoDataStrategy, exec_nodata_strat
from hydromt.typing import (
    Geom,
    SourceMetadata,
    Variables,
    Zoom,
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Extensions that select the zarr or netCDF code path in the driver below.
_ZARR_EXT = ".zarr"
_NETCDF_EXT = [".nc", ".netcdf"]


class RasterXarrayOptions(DriverOptions):
    """Options for RasterXarrayDriver."""

    preprocess: str | None = None
    """Name of preprocessor to apply before merging datasets. Available
    preprocessors include: round_latlon, to_datetimeindex, remove_duplicates,
    harmonise_dims. See their docstrings for details."""

    ext_override: str | None = None
    """Override the file extension check and try to read all files as the given
    extension. Useful when reading zarr files without the .zarr extension."""

    def get_preprocessor(self) -> Preprocessor:
        """Look up the preprocessor function named by ``preprocess``."""
        preprocessor_name = self.preprocess
        return get_preprocessor(preprocessor_name)

    def get_ext_override(self, uris: list[str]) -> str | None:
        """
        Return the extension that ``uris`` should be read as.

        A configured (non-empty) ``ext_override`` takes precedence; otherwise
        the extension of the first uri is returned.

        Parameters
        ----------
        uris : list[str]
            URIs to read; only the first one is inspected, and only when no
            override is configured.

        Returns
        -------
        str | None
            The override, or the extension of ``uris[0]`` including the dot.
        """
        if self.ext_override:
            return self.ext_override
        _, first_ext = splitext(uris[0])
        return first_ext
class RasterDatasetXarrayDriver(RasterDatasetDriver):
    """
    Driver for RasterDataset using the xarray library: ``raster_xarray``.

    Supports reading and writing zarr and netcdf files using xarray.
    zarr files will be read using `xr.open_zarr` and netcdf files using
    `xr.open_mfdataset`.
    """

    name = "raster_xarray"
    supports_writing = True
    SUPPORTED_EXTENSIONS: ClassVar[set[str]] = {_ZARR_EXT, *_NETCDF_EXT}

    options: RasterXarrayOptions = Field(
        default_factory=RasterXarrayOptions, description=DRIVER_OPTIONS_DESCRIPTION
    )

    def read(
        self,
        uris: list[str],
        *,
        handle_nodata: NoDataStrategy = NoDataStrategy.RAISE,
        mask: Geom | None = None,
        variables: Variables | None = None,
        zoom: Zoom | None = None,
        chunks: dict[str, Any] | None = None,
        metadata: SourceMetadata | None = None,
    ) -> xr.Dataset:
        """
        Read zarr or netCDF raster data into an xarray Dataset.

        Zarr stores are opened one by one with `xr.open_zarr`, preprocessed and
        merged with `xr.merge`; netCDF files are opened together with
        `xr.open_mfdataset`. The format is chosen from the ``ext_override``
        option when it is set, otherwise from the extension of the first uri.
        Without an override, uris whose extension does not match the chosen
        format are skipped with a warning; with an override every uri is read.

        Parameters
        ----------
        uris : list[str]
            List of URIs pointing to zarr or netCDF files.
        handle_nodata : NoDataStrategy, optional
            Strategy applied when a variable of the result has size zero.
            Default is NoDataStrategy.RAISE.
        mask : Geom | None, optional
            Spatial mask or geometry (currently unused). Default is None.
        variables : Variables | None, optional
            List of variables to select from the dataset (currently unused).
            Default is None.
        zoom : Zoom | None, optional
            Zoom level or resolution (currently unused). Default is None.
        chunks : dict[str, Any] | None, optional
            Chunking configuration (currently unused). Default is None.
        metadata : SourceMetadata | None, optional
            Optional metadata about the dataset source (currently unused).
            Default is None.

        Returns
        -------
        xr.Dataset
            The merged xarray Dataset. ``None`` is returned instead when
            ``uris`` is empty, or when a variable is empty and the nodata
            strategy does not raise.

        Raises
        ------
        ValueError
            If the file extension is unsupported.

        Warning
        -------
        The `mask`, `variables`, `zoom`, `chunks` and `metadata` parameters are
        not used directly in this driver, but are included for consistency
        with the RasterDatasetDriver interface. They are passed to
        `_warn_on_unused_kwargs`.
        """
        _warn_on_unused_kwargs(
            self.__class__.__name__,
            {
                "mask": mask,
                "variables": variables,
                "zoom": zoom,
                "chunks": chunks,
                "metadata": metadata,
            },
        )
        if not uris:
            return None  # handle_nodata == ignore

        preprocessor = self.options.get_preprocessor()
        first_ext = self.options.get_ext_override(uris)
        # An explicit override means "read everything as first_ext", so the
        # per-file extension check must be bypassed in BOTH branches below.
        forced = bool(self.options.ext_override)

        if first_ext == _ZARR_EXT:
            opn: Callable = partial(xr.open_zarr, **self.options.get_kwargs())
            datasets = []
            for _uri in uris:
                ext = splitext(_uri)[-1]
                if ext != first_ext and not forced:
                    logger.warning(f"Reading zarr and {_uri} was not, skipping...")
                else:
                    datasets.append(preprocessor(opn(_uri)))
            ds: xr.Dataset = xr.merge(datasets)
        # Normal netcdf file(s)
        elif first_ext in _NETCDF_EXT:
            filtered_uris = []
            for _uri in uris:
                ext = splitext(_uri)[-1]
                # Any netCDF extension is acceptable here: ".nc" and ".netcdf"
                # files may be mixed in one source.
                if ext not in _NETCDF_EXT and not forced:
                    logger.warning(f"Reading netcdf and {_uri} was not, skipping...")
                else:
                    filtered_uris.append(_uri)
            ds = xr.open_mfdataset(
                filtered_uris,
                decode_coords="all",
                preprocess=preprocessor,
                **self.options.get_kwargs(),
                decode_timedelta=True,
            )
        else:
            raise ValueError(
                f"Unknown extension for RasterDatasetXarrayDriver: {first_ext} "
            )

        for variable in ds.data_vars:
            if ds[variable].size == 0:
                exec_nodata_strat(
                    f"No data from driver: '{self.name}' for variable: '{variable}'",
                    strategy=handle_nodata,
                )
                return None  # handle_nodata == ignore
        return ds

    def write(
        self,
        path: Path | str,
        data: xr.Dataset,
        *,
        write_kwargs: dict[str, Any] | None = None,
    ) -> Path:
        """
        Write a RasterDataset to disk using Zarr or NetCDF format.

        Supports writing datasets to `.zarr`, `.nc`, or `.netcdf` formats
        depending on the file extension. If an unsupported extension is
        provided, a warning is logged, the extension is replaced by `.zarr`
        and the data is written as Zarr.

        Parameters
        ----------
        path : Path | str
            Destination path for the dataset.
        data : xr.Dataset
            The xarray Dataset to write.
        write_kwargs : dict[str, Any] | None, optional
            Additional keyword arguments passed to `to_zarr` or `to_netcdf`.
            The dict is not modified. For Zarr output ``zarr_format`` defaults
            to 2 unless given. Default is None.

        Returns
        -------
        Path
            The path to the written dataset; this differs from ``path`` when
            the extension was switched to `.zarr`.
        """
        no_ext, ext = splitext(path)
        # Work on a copy so the zarr default below never leaks into the
        # caller's dictionary.
        write_kwargs = dict(write_kwargs or {})

        # set filepath if incompat
        if ext not in self.SUPPORTED_EXTENSIONS:
            logger.warning(
                f"Unknown extension for RasterDatasetXarrayDriver: {ext}, "
                "switching to zarr"
            )
            path = no_ext + _ZARR_EXT
            ext = _ZARR_EXT

        if ext == _ZARR_EXT:
            write_kwargs.setdefault("zarr_format", 2)
            data.to_zarr(path, mode="w", **write_kwargs)
        else:
            data.to_netcdf(path, **write_kwargs)

        return Path(path)