"""Driver using rasterio for RasterDataset."""
import copy
from logging import Logger, getLogger
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import rasterio
import rasterio.errors
import xarray as xr
from pyproj import CRS
from hydromt._io.readers import _open_mfraster
from hydromt._typing import (
    Geom,
    SourceMetadata,
    StrPath,
    TimeRange,
    Variables,
    Zoom,
)
from hydromt._typing.error import NoDataStrategy, exec_nodata_strat
from hydromt._utils.caching import _cache_vrt_tiles
from hydromt._utils.temp_env import temp_env
from hydromt._utils.uris import _strip_scheme
from hydromt.config import SETTINGS
from hydromt.data_catalog.drivers import RasterDatasetDriver
from hydromt.gis._gis_utils import _zoom_to_overview_level
logger: Logger = getLogger(__name__)
[docs]
class RasterioDriver(RasterDatasetDriver):
    """Driver using rasterio for RasterDataset.

    Reading opens one or more raster uris and returns them as an xarray object;
    writing is not implemented by this driver.
    """
    # Driver identifier; ``read`` includes it in its "No data from driver" message.
    name = "rasterio"
[docs]
    def read(
        self,
        uris: List[str],
        *,
        mask: Optional[Geom] = None,
        time_range: Optional[TimeRange] = None,
        variables: Optional[Variables] = None,
        zoom: Optional[Zoom] = None,
        chunks: Optional[dict] = None,
        metadata: Optional[SourceMetadata] = None,
        handle_nodata: NoDataStrategy = NoDataStrategy.RAISE,
    ) -> xr.Dataset:
        """Read raster data from one or more uris using rasterio.

        Parameters
        ----------
        uris : List[str]
            Locations of the raster files to open.
        mask : Geom, optional
            Geometry forwarded to the VRT tile cache and, under the key
            ``"mask"``, to the mosaic keyword arguments.
        time_range : TimeRange, optional
            Not used by this method.
        variables : Variables, optional
            When exactly one variable is requested and the opened dataset holds
            exactly one data variable, that variable is renamed to the request.
        zoom : Zoom, optional
            Requested zoom; translated into a rasterio overview level.
        chunks : dict, optional
            Chunk sizes passed to the reader and re-applied after opening.
        metadata : SourceMetadata, optional
            Source metadata. A numeric ``nodata`` is forwarded to the reader;
            ``zls_dict`` and ``crs`` are used for the zoom lookup when present.
            Defaults to an empty ``SourceMetadata``.
        handle_nodata : NoDataStrategy, optional
            Strategy handed to ``exec_nodata_strat`` when a data variable of the
            result is empty.

        Returns
        -------
        xr.Dataset
            The opened dataset.

        Raises
        ------
        rasterio.errors.RasterioIOError
            If opening fails for a reason other than an unavailable overview
            level (that case is retried once without ``overview_level``).

        Notes
        -----
        The driver options ``mosaic``, ``cache``, ``cache_root`` and (only when
        caching is active) ``cache_dir`` are consumed here; all remaining options
        are passed on to ``_open_mfraster``. The driver's own ``options`` are
        never modified by a read.
        """
        if metadata is None:
            metadata = SourceMetadata()

        # Work on a private copy so that nothing below leaks back into the
        # driver's own configuration between successive reads.
        options = copy.deepcopy(self.options)
        # Deliberately taken from the copy: the mask is added to this dict
        # further down and must not end up in ``self.options``.
        mosaic_kwargs: Dict[str, Any] = options.get("mosaic_kwargs", {})
        mosaic: bool = options.pop("mosaic", False) and len(uris) > 1
        cache_root: str = str(options.pop("cache_root", SETTINGS.cache_root))
        cache_flag = options.pop("cache", False)  # caching is off by default

        # Tiles are only cached when requested and when every uri is a vrt file.
        if cache_flag and all(uri.endswith(".vrt") for uri in uris):
            cache_dir = Path(cache_root) / options.pop(
                "cache_dir",
                # default to the name of the first uri without its extension
                Path(_strip_scheme(uris[0])[1]).stem,
            )
            uris = [
                _cache_vrt_tiles(
                    uri, geom=mask, fs=self.filesystem, cache_dir=cache_dir
                )
                for uri in uris
            ]

        if mask is not None:
            mosaic_kwargs.update({"mask": mask})
        if mosaic_kwargs:
            options.update({"mosaic_kwargs": mosaic_kwargs})

        if np.issubdtype(type(metadata.nodata), np.number):
            options.update(nodata=metadata.nodata)

        # Translate the requested zoom into an overview level
        if zoom:
            try:
                zls_dict: Dict[int, float] = metadata.zls_dict
                crs: Optional[CRS] = metadata.crs
            except AttributeError:  # pydantic extra=allow on SourceMetadata
                zls_dict, crs = self._get_zoom_levels_and_crs(uris[0])
            overview_level: Optional[int] = _zoom_to_overview_level(
                zoom, mask, zls_dict, crs
            )
            if overview_level:
                # NOTE: overview levels start at zoom_level 1, see _get_zoom_levels_and_crs
                options.update(overview_level=overview_level - 1)

        if chunks is not None:
            options.update({"chunks": chunks})

        # If the metadata resolver has already resolved the overview level,
        # trying to open zoom levels here will result in an error.
        # Better would be to separate uriresolver and driver: https://github.com/Deltares/hydromt/issues/1023
        # Then we can implement looking for an overview level in the driver.
        def _open() -> Union[xr.DataArray, xr.Dataset]:
            try:
                return _open_mfraster(uris, mosaic=mosaic, **options)
            except rasterio.errors.RasterioIOError as e:
                # Retry only when there is an overview level to drop; otherwise
                # the retry would be identical, so surface the original error.
                if "Cannot open overview level" in str(e) and (
                    "overview_level" in options
                ):
                    options.pop("overview_level")
                    return _open_mfraster(uris, mosaic=mosaic, **options)
                raise

        # rasterio uses a specific environment variable for s3 access.
        try:
            anon = self.filesystem.anon
        except AttributeError:
            anon = False
        if anon:
            with temp_env(**{"AWS_NO_SIGN_REQUEST": "true"}):
                ds = _open()
        else:
            ds = _open()

        # Mosaics can mess up the chunking, which can error during writing,
        # so the requested chunking is re-applied to the opened data.
        chunks = options.get("chunks")
        if chunks is not None:
            ds = ds.chunk(chunks=chunks)

        # rename ds with single band if single variable is requested
        if variables is not None and len(variables) == 1 and len(ds.data_vars) == 1:
            ds = ds.rename({next(iter(ds.data_vars)): list(variables)[0]})

        for variable in ds.data_vars:
            if ds[variable].size == 0:
                exec_nodata_strat(
                    f"No data from driver: '{self.name}' for variable: '{variable}'",
                    strategy=handle_nodata,
                )
        return ds
[docs]
    def write(self, path: StrPath, ds: xr.Dataset, **kwargs) -> str:
        """Write out a RasterDataset using rasterio.

        Writing is not implemented for this driver, so every call raises.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError
    @staticmethod
    def _get_zoom_levels_and_crs(uri: str) -> Tuple[Dict[int, float], Optional[Any]]:
        """Detect the zoom levels and crs of a raster file.

        The file is opened with rasterio. Zoom level 0 corresponds to the
        resolution of the file itself; each following level corresponds to one
        of the overviews of the first band, with its resolution obtained by
        multiplying the file resolution by the overview factor.

        Parameters
        ----------
        uri : str
            Location of the raster file to inspect.

        Returns
        -------
        Tuple[Dict[int, float], Optional[Any]]
            The mapping of zoom level to resolution (empty when the file has no
            overviews or could not be opened) and the ``crs`` attribute of the
            opened rasterio dataset (``None`` when the file could not be opened).

        Raises
        ------
        ValueError
            If the bands of the file do not all have the same overviews.
        """
        zoom_levels: Dict[int, float] = {}
        crs = None
        try:
            with rasterio.open(uri) as src:
                res = abs(src.res[0])
                crs = src.crs
                overviews = [src.overviews(i) for i in src.indexes]
                # ``overviews`` is empty for a dataset without bands; only look at
                # the first band's overviews when there is one.
                if overviews and overviews[0]:
                    if any(o != overviews[0] for o in overviews):
                        raise ValueError("Overviews are not identical across bands")
                    # dict with overview level and corresponding resolution;
                    # factor 1 is prepended so that level 0 is the file resolution
                    zls = [1] + overviews[0]
                    zoom_levels = {i: res * zl for i, zl in enumerate(zls)}
        except rasterio.errors.RasterioIOError as e:
            # Best effort: an unreadable file yields no zoom levels instead of failing.
            logger.warning("IO error while detecting zoom levels: %s", e)
        return zoom_levels, crs