Source code for hydromt.data_catalog.drivers.raster.rasterio_driver

"""Driver using rasterio for RasterDataset."""

import logging
from pathlib import Path
from typing import Any, ClassVar

import numpy as np
import rasterio
import rasterio.errors
import xarray as xr
from pydantic import Field, field_serializer, model_validator
from pyproj import CRS

from hydromt._utils import _strip_scheme, cache_vrt_tiles, temp_env
from hydromt.config import SETTINGS
from hydromt.data_catalog.drivers.base_driver import DriverOptions
from hydromt.data_catalog.drivers.raster import RasterDatasetDriver
from hydromt.error import NoDataException, NoDataStrategy, exec_nodata_strat
from hydromt.gis._gdal_drivers import GDAL_DRIVER_CODE_MAP
from hydromt.gis.gis_utils import zoom_to_overview_level
from hydromt.readers import open_mfraster
from hydromt.typing import (
    Geom,
    SourceMetadata,
    Variables,
    Zoom,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Suffix substituted for ".vrt" when a write to VRT format is requested
# (see RasterioDriver.write).
_TIFF_EXT = ".tif"


class RasterioOptions(DriverOptions):
    """Options for RasterioDriver.

    Groups the mosaic settings and the VRT tile-cache settings used by
    ``RasterioDriver.read``. Only the fields named in ``KWARGS_FOR_OPEN`` are
    meant to be forwarded as keyword arguments when opening the data.
    """

    KWARGS_FOR_OPEN: ClassVar[set[str]] = {"mosaic_kwargs"}

    mosaic: bool = Field(
        default=False,
        description="If True and multiple uris are given, will mosaic the datasets together using `rasterio.merge.merge`. Default is False.",
    )
    mosaic_kwargs: dict[str, Any] = Field(
        default={},
        description="Additional keyword arguments to pass to `rasterio.merge.merge`.",
    )
    cache: bool = Field(
        default=False,
        description="If True and reading from VRT files, will cache the tiles locally to speed up reading. Default is False.",
    )
    cache_root: str = Field(
        default=str(SETTINGS.cache_root),
        description="Root directory for caching. Default is taken from `hydromt.config.SETTINGS.cache_root`.",
    )
    cache_dir: str | None = Field(
        default=None,
        description="Subdirectory for caching. Default is the stem of the first uri without extension.",
    )

    def get_cache_path(self, uris: list[str]) -> Path:
        """Get the cache path based on the options and uris.

        Parameters
        ----------
        uris : list[str]
            URIs that are about to be read. Only the first one is inspected,
            and only when ``cache_dir`` is not set.

        Returns
        -------
        Path
            ``cache_root / cache_dir`` when ``cache_dir`` is set, otherwise
            ``cache_root`` joined with the stem of the first uri (scheme
            stripped, extension removed).

        Raises
        ------
        IndexError
            If ``cache_dir`` is not set and ``uris`` is empty.
        """
        root = Path(self.cache_root)
        if self.cache_dir is not None:
            return root / self.cache_dir
        # default to first uri without extension
        return root / Path(_strip_scheme(uris[0])[1]).stem

    @model_validator(mode="before")
    @classmethod
    def _convert_path_to_str(cls, data: Any) -> Any:
        """Convert Path inputs to posix strings before field validation.

        ``cache_root`` and ``cache_dir`` are declared as ``str``. A
        ``pathlib.Path`` is not accepted for a ``str`` field, so the
        conversion has to happen before the individual fields are validated;
        doing it afterwards would never see a ``Path``.

        Inputs that are not a plain ``dict`` are passed through untouched.
        """
        if not isinstance(data, dict):
            return data
        # Work on a shallow copy so the caller's mapping is not modified.
        converted = dict(data)
        for key in ("cache_root", "cache_dir"):
            value = converted.get(key)
            if isinstance(value, Path):
                converted[key] = value.as_posix()
        return converted

    @field_serializer("cache_root", "cache_dir")
    def serialize_paths(self, value: str | Path | None) -> str | None:
        """Serialize path-like values as posix strings; ``None`` stays ``None``."""
        if value is None:
            return None
        return Path(value).as_posix()
class RasterioDriver(RasterDatasetDriver):
    """
    Driver for RasterDataset using the rasterio library: ``rasterio``.

    Supports reading and writing raster files using rasterio.
    """

    name = "rasterio"
    supports_writing = True
    SUPPORTED_EXTENSIONS: ClassVar[set[str]] = {
        f".{extension}" for extension in GDAL_DRIVER_CODE_MAP
    }
    options: RasterioOptions = Field(default_factory=RasterioOptions)

    def read(
        self,
        uris: list[str],
        *,
        handle_nodata: NoDataStrategy = NoDataStrategy.RAISE,
        mask: Geom | None = None,
        variables: Variables | None = None,
        zoom: Zoom | None = None,
        chunks: dict[str, Any] | None = None,
        metadata: SourceMetadata | None = None,
    ) -> xr.Dataset:
        """
        Read raster data using the rasterio library.

        Supports reading single or multiple raster files (optionally
        mosaicked), applying spatial masks, caching VRT tiles, and reading
        overviews at different zoom levels. Returns an xarray Dataset
        constructed from raster bands.

        Parameters
        ----------
        uris : list[str]
            List of raster file URIs to read.
        handle_nodata : NoDataStrategy, optional
            Strategy for handling missing or empty data. Default is
            NoDataStrategy.RAISE.
        mask : Geom | None, optional
            Geometry used to mask or clip the raster data. It is forwarded as
            the ``mask`` entry of the mosaic keyword arguments for this call
            only. Default is None.
        variables : Variables | None, optional
            List of variables or band names to read. Only used to rename the
            result when exactly one variable is requested and exactly one
            data variable was read. Default is None.
        zoom : Zoom | None, optional
            Requested zoom level or resolution. Used to determine the
            appropriate overview level. Default is None.
        chunks : dict[str, Any] | None, optional
            Dask chunking configuration for lazy loading. Default is None.
        metadata : SourceMetadata | None, optional
            Optional metadata describing CRS, nodata, and overview levels.
            Default is None.

        Returns
        -------
        xr.Dataset
            The loaded raster dataset as an xarray Dataset. ``None`` is
            returned when ``uris`` is empty, or when a variable holds no data
            and ``handle_nodata`` does not raise.

        Raises
        ------
        rasterio.errors.RasterioIOError
            If an I/O error occurs during reading (other than a missing
            overview level, which is retried without the overview).
        """
        if not uris:
            return None  # handle_nodata == ignore

        if metadata is None:
            metadata = SourceMetadata()

        # Caching portion, only when the flag is True and the file format is vrt
        if self.options.cache and all(uri.endswith(".vrt") for uri in uris):
            cache_dir: Path = self.options.get_cache_path(uris)
            uris = [
                cache_vrt_tiles(
                    uri, geom=mask, fs=self.filesystem.get_fs(), cache_dir=cache_dir
                )
                for uri in uris
            ]

        open_kwargs = self.options.get_kwargs()

        # Build the mosaic kwargs for THIS call on a copy. Writing the mask
        # into ``self.options.mosaic_kwargs`` would leak it into later reads
        # on the same driver and make the "mosaic is False" warning below
        # fire for every masked read.
        user_mosaic_kwargs = open_kwargs.pop("mosaic_kwargs", None) or {}
        mosaic_kwargs = dict(user_mosaic_kwargs)
        if mask is not None:
            mosaic_kwargs["mask"] = mask

        if np.issubdtype(type(metadata.nodata), np.number):
            open_kwargs["nodata"] = metadata.nodata

        # Fix overview level
        if zoom:
            try:
                zls_dict: dict[int, float] = metadata.zls_dict
                crs: CRS | None = metadata.crs
            except AttributeError:  # pydantic extra=allow on SourceMetadata
                zls_dict, crs = self._get_zoom_levels_and_crs(uris[0])

            overview_level: int | None = zoom_to_overview_level(
                zoom, mask, zls_dict, crs
            )
            if overview_level:
                # NOTE: overview levels start at zoom_level 1, see _get_zoom_levels_and_crs
                open_kwargs["overview_level"] = overview_level - 1

        if chunks is not None:
            open_kwargs["chunks"] = chunks

        mosaic: bool = self.options.mosaic and len(uris) > 1
        if user_mosaic_kwargs and not mosaic:
            logger.warning(
                "mosaic_kwargs provided but mosaic is False. Ignoring mosaic_kwargs. To use mosaic_kwargs, set mosaic=True in driver options."
            )

        # If the metadata resolver has already resolved the overview level,
        # trying to open zoom levels here will result in an error.
        # Better would be to separate uriresolver and driver: https://github.com/Deltares/hydromt/issues/1023
        # Then we can implement looking for a overview level in the driver.
        def _open() -> xr.Dataset:
            try:
                return open_mfraster(
                    uris, mosaic=mosaic, mosaic_kwargs=mosaic_kwargs, **open_kwargs
                )
            except rasterio.errors.RasterioIOError as e:
                if "Cannot open overview level" not in str(e):
                    raise
                # Retry once at full resolution.
                open_kwargs.pop("overview_level", None)
                return open_mfraster(
                    uris, mosaic=mosaic, mosaic_kwargs=mosaic_kwargs, **open_kwargs
                )

        # rasterio uses specific environment variable for s3 access.
        try:
            anon = self.filesystem.get_fs().anon
        except AttributeError:
            anon = ""

        if anon:
            with temp_env(**{"AWS_NO_SIGN_REQUEST": "true"}):
                ds = _open()
        else:
            ds = _open()

        # Mosaic's can mess up the chunking, which can error during writing,
        # so re-apply the requested chunking on the result.
        chunks = open_kwargs.get("chunks")
        if chunks is not None:
            ds = ds.chunk(chunks=chunks)

        # rename ds with single band if single variable is requested
        if variables is not None and len(variables) == 1 and len(ds.data_vars) == 1:
            ds = ds.rename({next(iter(ds.data_vars)): list(variables)[0]})

        for variable in ds.data_vars:
            if ds[variable].size == 0:
                exec_nodata_strat(
                    f"No data from driver: '{self.name}' for variable: '{variable}'",
                    strategy=handle_nodata,
                )
                return None  # handle_nodata == ignore
        return ds

    def write(
        self,
        path: Path | str,
        data: xr.Dataset | xr.DataArray,
        *,
        write_kwargs: dict[str, Any] | None = None,
    ) -> Path:
        """
        Write a RasterDataset to disk using the rasterio library.

        Depending on the input, one or several raster files are written:

        * A DataArray and a path containing ``*``: one file per label of the
          first dimension, with ``*`` replaced by the label value.
        * A Dataset with more than one variable: one file per variable. The
          file name is the path with ``*`` replaced by the variable name, or
          ``<variable><suffix>`` next to ``path`` when there is no wildcard.
        * Anything else (a DataArray without wildcard, or a Dataset with a
          single variable): a single file at ``path``.

        A ``.vrt`` destination is not supported; the data is written as
        GeoTIFF instead and a warning is logged.

        Parameters
        ----------
        path : Path | str
            Destination path for the raster dataset.
        data : xr.DataArray | xr.Dataset
            The xarray DataArray or Dataset to write.
        write_kwargs : dict[str, Any] | None, optional
            Additional keyword arguments for writing. Default is None.

        Returns
        -------
        Path
            The path to the written raster dataset. For multi-file output this
            is a path that still contains the ``*`` wildcard.

        Raises
        ------
        ValueError
            If the extension is not supported, if a wildcard DataArray write
            is requested with fewer than 3 dimensions, or if the file name
            does not contain exactly one ``*`` in that case.
        """
        path = Path(path)
        write_kwargs = write_kwargs or {}

        # The GDAL driver lookup below lower-cases the suffix, so the
        # extension check accepts the lower-cased suffix as well.
        suffix = path.suffix
        if (
            suffix not in self.SUPPORTED_EXTENSIONS
            and suffix.lower() not in self.SUPPORTED_EXTENSIONS
        ):
            raise ValueError(f"Unknown extension for RasterioDriver: {path.suffix}")

        if suffix.lower() == ".vrt":
            logger.warning(
                "Writing to VRT format is not supported by RasterioDriver, will attempt to write as GeoTIFF instead."
            )
            path = path.with_suffix(_TIFF_EXT)

        gdal_driver = GDAL_DRIVER_CODE_MAP.get(path.suffix.lstrip(".").lower())

        if "*" in str(path) and isinstance(data, xr.DataArray):
            # Validate everything before touching the file system.
            if len(data.dims) < 3:
                raise ValueError(
                    "Writing multiple files with wildcard requires at least 3 dimensions in data array"
                )
            if path.name.count("*") != 1:
                raise ValueError(
                    "There must be exactly one wildcard `*` in the filename when multiple outputs required"
                )
            path.parent.mkdir(parents=True, exist_ok=True)
            dim0 = data.dims[0]
            for label in data[dim0]:
                ds_sel = data.sel({dim0: label})
                file_name = path.name.replace("*", f"{label.values}")
                self._write_raster(
                    ds_sel, gdal_driver, path.with_name(file_name), **write_kwargs
                )
            return path

        if isinstance(data, xr.Dataset):
            if len(data.data_vars) == 1:
                # NOTE(review): a single-variable Dataset falls through to the
                # single-file write below even if the path contains `*` -
                # confirm this is intended.
                data = data[next(iter(data.data_vars))]
            else:
                path.parent.mkdir(parents=True, exist_ok=True)
                has_wildcard = "*" in path.name
                for var in data.data_vars:
                    if has_wildcard:
                        file_path = path.with_name(path.name.replace("*", var))
                    else:
                        file_path = path.parent / f"{var}{path.suffix}"
                    self._write_raster(
                        data[var], gdal_driver, file_path, **write_kwargs
                    )
                return path if has_wildcard else path.parent / f"*{path.suffix}"

        self._write_raster(data, gdal_driver, path, **write_kwargs)
        return path

    @staticmethod
    def _get_zoom_levels_and_crs(
        uri: str,
    ) -> tuple[dict[int, float], "rasterio.crs.CRS | None"]:
        """Get zoom levels and crs from adapter or detect from tif file if missing.

        Parameters
        ----------
        uri : str
            Raster file to open with rasterio.

        Returns
        -------
        tuple
            ``(zoom_levels, crs)``. ``zoom_levels`` maps the overview index
            (0 is the native resolution) to the pixel size at that level; it
            is empty when the file has no overviews or cannot be opened.
            ``crs`` is the dataset CRS, or None when the file cannot be
            opened.

        Raises
        ------
        ValueError
            If the bands of the file do not all have identical overviews.
        """
        zoom_levels: dict[int, float] = {}
        crs = None
        try:
            with rasterio.open(uri) as src:
                res = abs(src.res[0])
                crs = src.crs
                overviews = [src.overviews(i) for i in src.indexes]
                # Only the first band decides whether overviews are present;
                # also guards against a dataset without any bands.
                if overviews and overviews[0]:
                    if any(o != overviews[0] for o in overviews):
                        raise ValueError("Overviews are not identical across bands")
                    # dict with overview level and corresponding resolution
                    zls = [1] + overviews[0]
                    zoom_levels = {i: res * zl for i, zl in enumerate(zls)}
        except rasterio.errors.RasterioIOError as e:
            logger.warning("IO error while detecting zoom levels: %s", e)
        return zoom_levels, crs

    def _write_raster(
        self, data: xr.DataArray, driver: str, path: Path, **write_kwargs: Any
    ) -> None:
        """Write raster data to file using rasterio.

        Raises
        ------
        NoDataException
            If either spatial coordinate has fewer than two entries (for 2D
            coordinates: fewer than two rows for y or two columns for x).
        """
        y_coords = data[data.raster.y_dim]
        x_coords = data[data.raster.x_dim]
        if (
            y_coords.size < 2
            or (y_coords.ndim == 2 and y_coords.shape[0] < 2)
            or x_coords.size < 2
            or (x_coords.ndim == 2 and x_coords.shape[1] < 2)
        ):
            raise NoDataException(
                f"Cannot write raster data with insufficient spatial dimensions: {data.raster.y_dim} size {y_coords.size}, {data.raster.x_dim} size {x_coords.size}",
            )
        data.raster.to_raster(path, driver=driver, **write_kwargs)