# Source code for hydromt.data_catalog.drivers.raster.rasterio_driver

"""Driver using rasterio for RasterDataset."""

import copy
from logging import Logger, getLogger
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import rasterio
import rasterio.errors
import xarray as xr
from pyproj import CRS

from hydromt._io.readers import _open_mfraster
from hydromt._typing import (
    Geom,
    SourceMetadata,
    StrPath,
    TimeRange,
    Variables,
    Zoom,
)
from hydromt._typing.error import NoDataStrategy, exec_nodata_strat
from hydromt._utils.caching import _cache_vrt_tiles
from hydromt._utils.temp_env import temp_env
from hydromt._utils.uris import _strip_scheme
from hydromt.config import SETTINGS
from hydromt.data_catalog.drivers import RasterDatasetDriver
from hydromt.gis._gis_utils import _zoom_to_overview_level

# Module-level logger, named after this module.
logger: Logger = getLogger(name=__name__)


class RasterioDriver(RasterDatasetDriver):
    """Driver using rasterio for RasterDataset."""

    name = "rasterio"

    def read(
        self,
        uris: List[str],
        *,
        mask: Optional[Geom] = None,
        time_range: Optional[TimeRange] = None,
        variables: Optional[Variables] = None,
        zoom: Optional[Zoom] = None,
        chunks: Optional[dict] = None,
        metadata: Optional[SourceMetadata] = None,
        handle_nodata: NoDataStrategy = NoDataStrategy.RAISE,
    ) -> xr.Dataset:
        """Read data using rasterio.

        Parameters
        ----------
        uris : List[str]
            Raster uris to open. The ``mosaic`` option is only forwarded to
            the reader as True when more than one uri is given.
        mask : Geom, optional
            Geometry added to the mosaic kwargs, passed to the VRT tile cache
            and used when translating ``zoom`` into an overview level.
        time_range : TimeRange, optional
            Not used by this driver.
        variables : Variables, optional
            If exactly one variable is requested and the opened dataset has a
            single data variable, that variable is renamed to the requested
            name. A single string is treated as a one-element list.
        zoom : Zoom, optional
            Requested zoom, translated into a rasterio overview level.
        chunks : dict, optional
            Chunks passed to the reader and re-applied after opening.
        metadata : SourceMetadata, optional
            Source metadata; ``nodata``, ``zls_dict`` and ``crs`` are read
            from it when available.
        handle_nodata : NoDataStrategy
            Strategy applied when a variable in the result is empty.

        Returns
        -------
        xr.Dataset
            The opened dataset.
        """
        if metadata is None:
            metadata = SourceMetadata()

        # Work on a deep copy so per-call state (mask, overview level, chunks)
        # never leaks back into the driver's configured options.
        options: Dict[str, Any] = copy.deepcopy(self.options)
        mosaic_kwargs: Dict[str, Any] = options.get("mosaic_kwargs", {})
        mosaic: bool = options.pop("mosaic", False) and len(uris) > 1

        # Driver-specific caching options: these are not reader kwargs, so
        # they are always removed from the options passed to the reader.
        cache_root: str = str(options.pop("cache_root", SETTINGS.cache_root))
        cache_flag = options.pop("cache", False)  # caching defaults to off
        cache_dir_option = options.pop("cache_dir", None)

        # Cache tiles locally, only when requested and all uris are VRT files.
        if cache_flag and all(uri.endswith(".vrt") for uri in uris):
            if cache_dir_option is None:
                # default to the first uri without extension
                cache_dir_option = Path(_strip_scheme(uris[0])[1]).stem
            cache_dir = Path(cache_root) / cache_dir_option
            uris = [
                _cache_vrt_tiles(
                    uri, geom=mask, fs=self.filesystem, cache_dir=cache_dir
                )
                for uri in uris
            ]

        if mask is not None:
            mosaic_kwargs.update({"mask": mask})
        if mosaic_kwargs:
            options.update({"mosaic_kwargs": mosaic_kwargs})

        if np.issubdtype(type(metadata.nodata), np.number):
            options.update(nodata=metadata.nodata)

        # Translate the requested zoom into an overview level.
        if zoom:
            try:
                zls_dict: Dict[int, float] = metadata.zls_dict
                crs: Optional[CRS] = metadata.crs
            except AttributeError:  # pydantic extra=allow on SourceMetadata
                zls_dict, crs = self._get_zoom_levels_and_crs(uris[0])
            overview_level: Optional[int] = _zoom_to_overview_level(
                zoom, mask, zls_dict, crs
            )
            if overview_level:
                # NOTE: overview levels start at zoom_level 1,
                # see _get_zoom_levels_and_crs
                options.update(overview_level=overview_level - 1)

        if chunks is not None:
            options.update({"chunks": chunks})

        def _open() -> Union[xr.DataArray, xr.Dataset]:
            """Open the uris, retrying without the overview level if it fails."""
            try:
                return _open_mfraster(uris, mosaic=mosaic, **options)
            except rasterio.errors.RasterioIOError as e:
                # If the metadata resolver has already resolved the overview
                # level, trying to open zoom levels here results in an error.
                # Better would be to separate uriresolver and driver:
                # https://github.com/Deltares/hydromt/issues/1023
                # Then we can implement looking for an overview level here.
                if (
                    "Cannot open overview level" in str(e)
                    and "overview_level" in options
                ):
                    options.pop("overview_level")
                    return _open_mfraster(uris, mosaic=mosaic, **options)
                raise

        # rasterio uses a specific environment variable for anonymous s3 access.
        try:
            anon = self.filesystem.anon
        except AttributeError:
            anon = ""
        if anon:
            with temp_env(**{"AWS_NO_SIGN_REQUEST": "true"}):
                ds = _open()
        else:
            ds = _open()

        # Mosaics can mess up the chunking, which can error during writing.
        if chunks is not None:
            ds = ds.chunk(chunks=chunks)

        # Rename a single-band dataset if a single variable is requested.
        if isinstance(variables, str):
            variables = [variables]
        if variables is not None and len(variables) == 1 and len(ds.data_vars) == 1:
            ds = ds.rename({list(ds.data_vars.keys())[0]: list(variables)[0]})

        for variable in ds.data_vars:
            if ds[variable].size == 0:
                exec_nodata_strat(
                    f"No data from driver: '{self.name}' for variable: '{variable}'",
                    strategy=handle_nodata,
                )
        return ds

    def write(self, path: StrPath, ds: xr.Dataset, **kwargs) -> str:
        """Write out a RasterDataset using rasterio.

        Raises
        ------
        NotImplementedError
            Writing is not supported by this driver.
        """
        raise NotImplementedError()

    @staticmethod
    def _get_zoom_levels_and_crs(
        uri: str,
    ) -> Tuple[Dict[int, float], Optional["rasterio.crs.CRS"]]:
        """Detect zoom levels and crs from the overviews of a raster file.

        Zoom level 0 is the native resolution. Zoom level ``i > 0`` is the
        ``i``-th overview of the file (rasterio overview level ``i - 1``),
        with resolution ``native_resolution * overview_factor``.

        Parameters
        ----------
        uri : str
            Raster file to inspect.

        Returns
        -------
        Tuple[Dict[int, float], Optional[rasterio.crs.CRS]]
            A mapping of zoom level to resolution, which is empty if the file
            has no overviews or cannot be opened. Also the crs of the file,
            which is None if the file cannot be opened.

        Raises
        ------
        ValueError
            If the overviews differ between bands.
        """
        zoom_levels: Dict[int, float] = {}
        crs = None
        try:
            with rasterio.open(uri) as src:
                res = abs(src.res[0])
                crs = src.crs
                overviews = [src.overviews(i) for i in src.indexes]
                # check the overviews of the first band
                if overviews and len(overviews[0]) > 0:
                    if any(o != overviews[0] for o in overviews):
                        raise ValueError("Overviews are not identical across bands")
                    # factor 1 is the native resolution (zoom level 0)
                    factors = [1] + overviews[0]
                    zoom_levels = {i: res * f for i, f in enumerate(factors)}
        except rasterio.errors.RasterioIOError as e:
            logger.warning(f"IO error while detecting zoom levels: {e}")
        return zoom_levels, crs