Source code for hydromt.data_catalog.drivers.raster.rasterio_driver

"""Driver using rasterio for RasterDataset."""

from logging import Logger, getLogger
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import rasterio
import rasterio.errors
import xarray as xr
from pyproj import CRS

from hydromt._io.readers import _open_mfraster
from hydromt._typing import (
    Geom,
    SourceMetadata,
    StrPath,
    TimeRange,
    Variables,
    Zoom,
)
from hydromt._typing.error import NoDataStrategy, exec_nodata_strat
from hydromt._utils.caching import _cache_vrt_tiles
from hydromt._utils.temp_env import temp_env
from hydromt._utils.unused_kwargs import _warn_on_unused_kwargs
from hydromt._utils.uris import _strip_scheme
from hydromt.config import SETTINGS
from hydromt.data_catalog.drivers import RasterDatasetDriver
from hydromt.gis._gis_utils import _zoom_to_overview_level

logger: Logger = getLogger(__name__)


class RasterioDriver(RasterDatasetDriver):
    """Driver using rasterio for RasterDataset.

    Reads one or more raster uris through ``_open_mfraster``. The driver
    options consulted by :meth:`read` are ``mosaic``, ``mosaic_kwargs``,
    ``cache_root`` and ``cache_dir``.
    """

    name = "rasterio"

    def read(
        self,
        uris: List[str],
        *,
        mask: Optional[Geom] = None,
        time_range: Optional[TimeRange] = None,
        variables: Optional[Variables] = None,
        zoom: Optional[Zoom] = None,
        metadata: Optional[SourceMetadata] = None,
        handle_nodata: NoDataStrategy = NoDataStrategy.RAISE,
    ) -> xr.Dataset:
        """Read data using rasterio.

        Parameters
        ----------
        uris : List[str]
            Uris of the raster files to open.
        mask : Geom, optional
            Geometry forwarded to the vrt tile cache and, as ``mask``, to the
            mosaic keyword arguments.
        time_range : TimeRange, optional
            Not used by this driver; only reported through
            ``_warn_on_unused_kwargs``.
        variables : Variables, optional
            When exactly one variable is requested and the opened dataset
            holds exactly one data variable, that variable is renamed to the
            requested name.
        zoom : Zoom, optional
            Requested zoom; translated to a rasterio overview level.
        metadata : SourceMetadata, optional
            Source metadata. A numeric ``nodata`` is forwarded to the reader,
            and ``zls_dict`` / ``crs`` are used for the zoom lookup when
            present. Defaults to an empty ``SourceMetadata``.
        handle_nodata : NoDataStrategy
            Strategy applied when a data variable turns out to be empty.

        Returns
        -------
        xr.Dataset
            The opened dataset.

        Raises
        ------
        rasterio.errors.RasterioIOError
            If opening fails for any reason other than an overview level
            that this method requested and could retry without.
        """
        if metadata is None:
            metadata = SourceMetadata()

        # build up kwargs for open_raster
        _warn_on_unused_kwargs(
            self.__class__.__name__,
            {"time_range": time_range},
        )
        kwargs: Dict[str, Any] = {}
        # Work on a copy: the mask is added below, and updating the dict that
        # lives in self.options would leak that mask into every later read.
        mosaic_kwargs: Dict[str, Any] = dict(self.options.get("mosaic_kwargs") or {})
        mosaic: bool = self.options.get("mosaic", False) and len(uris) > 1

        # get source-specific options
        # NOTE: test for None *before* converting to str; str(None) is the
        # truthy, non-None string "None" and would become a directory name.
        cache_root = self.options.get("cache_root", SETTINGS.cache_root)

        if cache_root is not None and all(uri.endswith(".vrt") for uri in uris):
            cache_dir = Path(str(cache_root)) / self.options.get(
                "cache_dir",
                Path(
                    _strip_scheme(uris[0])[1]
                ).stem,  # default to first uri without extension
            )
            uris = [
                _cache_vrt_tiles(
                    uri, geom=mask, fs=self.filesystem, cache_dir=cache_dir
                )
                for uri in uris
            ]

        if mask is not None:
            mosaic_kwargs.update({"mask": mask})

        # get mosaic kwargs
        if mosaic_kwargs:
            kwargs.update({"mosaic_kwargs": mosaic_kwargs})

        if np.issubdtype(type(metadata.nodata), np.number):
            kwargs.update(nodata=metadata.nodata)

        # Fix overview level
        if zoom:
            try:
                zls_dict: Dict[int, float] = metadata.zls_dict
                crs: Optional[CRS] = metadata.crs
            except AttributeError:  # pydantic extra=allow on SourceMetadata
                zls_dict, crs = self._get_zoom_levels_and_crs(uris[0])

            overview_level: Optional[int] = _zoom_to_overview_level(
                zoom, mask, zls_dict, crs
            )
            if overview_level:
                # NOTE: overview levels start at zoom_level 1, see _get_zoom_levels_and_crs
                kwargs.update(overview_level=overview_level - 1)

        # If the metadata resolver has already resolved the overview level,
        # trying to open zoom levels here will result in an error.
        # Better would be to separate uriresolver and driver: https://github.com/Deltares/hydromt/issues/1023
        # Then we can implement looking for a overview level in the driver.
        def _open() -> Union[xr.DataArray, xr.Dataset]:
            try:
                return _open_mfraster(uris, mosaic=mosaic, **kwargs)
            except rasterio.errors.RasterioIOError as e:
                # Only retry when we actually asked for an overview level;
                # otherwise the retry would be identical to the failed call
                # and popping the missing key would hide the real error.
                if (
                    "Cannot open overview level" in str(e)
                    and "overview_level" in kwargs
                ):
                    kwargs.pop("overview_level")
                    return _open_mfraster(uris, mosaic=mosaic, **kwargs)
                raise

        # rasterio uses specific environment variable for s3 access.
        # Filesystems without an ``anon`` attribute are treated as not anonymous.
        anon = getattr(self.filesystem, "anon", False)

        if anon:
            with temp_env(**{"AWS_NO_SIGN_REQUEST": "true"}):
                ds = _open()
        else:
            ds = _open()

        # rename ds with single band if single variable is requested
        if variables is not None and len(variables) == 1 and len(ds.data_vars) == 1:
            ds = ds.rename({list(ds.data_vars.keys())[0]: list(variables)[0]})

        for variable in ds.data_vars:
            if ds[variable].size == 0:
                exec_nodata_strat(
                    f"No data from driver: '{self.name}' for variable: '{variable}'",
                    strategy=handle_nodata,
                )
        return ds

    def write(self, path: StrPath, ds: xr.Dataset, **kwargs) -> str:
        """Write out a RasterDataset using rasterio.

        Writing is not implemented for this driver.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError()

    @staticmethod
    def _get_zoom_levels_and_crs(uri: str) -> Tuple[Dict[int, float], Optional[Any]]:
        """Get zoom levels and crs from adapter or detect from tif file if missing.

        Parameters
        ----------
        uri : str
            Uri of the raster file to inspect with ``rasterio.open``.

        Returns
        -------
        zoom_levels : Dict[int, float]
            Mapping of zoom level index to resolution, where index 0 is the
            resolution of the file itself and the following indices are its
            overviews. Empty when the file has no overviews or could not be
            opened.
        crs : optional
            The ``crs`` attribute of the opened file, or None when the file
            could not be opened.

        Raises
        ------
        ValueError
            If the bands of the file do not all share the same overviews.
        """
        zoom_levels: Dict[int, float] = {}
        crs = None
        try:
            with rasterio.open(uri) as src:
                res = abs(src.res[0])
                crs = src.crs
                overviews = [src.overviews(i) for i in src.indexes]
                # ``overviews`` is empty for a file without bands; the first
                # entry is empty when the first band has no overviews.
                if overviews and overviews[0]:
                    # check if identical
                    if not all(o == overviews[0] for o in overviews):
                        raise ValueError("Overviews are not identical across bands")
                    # dict with overview level and corresponding resolution
                    zls = [1] + overviews[0]
                    zoom_levels = {i: res * zl for i, zl in enumerate(zls)}
        except rasterio.errors.RasterioIOError as e:
            # Best effort: an unreadable file yields no zoom levels instead of
            # failing the read.
            logger.warning(f"IO error while detecting zoom levels: {e}")
        return zoom_levels, crs