Source code for hydromt.models.model_grid

# -*- coding: utf-8 -*-
"""HydroMT GridModel class definition"""

from typing import Dict, List, Tuple, Union, Optional
import logging
from os.path import join, isfile
import xarray as xr
import numpy as np
import geopandas as gpd
from shapely.geometry import box
from pyproj import CRS

from .model_api import Model
from .. import workflows

# Public API of this module: only GridModel is exported (GridMixin is not listed).
__all__ = ["GridModel"]
# Module-level logger; also used as the default ``logger`` argument of GridModel.__init__.
logger = logging.getLogger(__name__)


class GridMixin(object):
    """Mixin that adds a ``grid`` component (an :py:class:`xarray.Dataset`) to a model.

    The grid holds static gridded parameter maps that share the same spatial
    coordinates (resolution and bounds); it was previously called ``staticmaps``.

    NOTE(review): this mixin relies on attributes/helpers that are not defined
    here and are presumably provided by the model class it is combined with
    (``self._read``, ``self.logger``, ``self._read_nc``, ``self._write_nc``,
    ``self._assert_read_mode`` and ``self._assert_write_mode``).
    """

    # placeholders
    # xr.Dataset representation of all static parameter maps at the same resolution
    # and bounds - renamed from staticmaps
    _API = {"grid": xr.Dataset}

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # start with an empty dataset; filled by set_grid / read_grid
        self._grid = xr.Dataset()

    @property
    def grid(self) -> xr.Dataset:
        """Model static gridded data. Returns xarray.Dataset.

        Previously called staticmaps. If the grid is still empty and the model
        is in read mode, the grid is read from file first (lazy loading).
        """
        if len(self._grid) == 0 and self._read:
            self.read_grid()
        return self._grid

    def set_grid(
        self,
        data: Union[xr.DataArray, xr.Dataset, np.ndarray],
        name: Optional[str] = None,
    ) -> None:
        """Add data to grid.

        All layers of grid must have identical spatial coordinates.

        Parameters
        ----------
        data: xarray.DataArray, xarray.Dataset or numpy.ndarray
            New map layer(s) to add to grid. A numpy array must have the same
            shape as the current grid and is wrapped in a DataArray using the
            grid's raster dimensions.
        name: str, optional
            Name of new map layer, this is used to overwrite the name of a DataArray
            and ignored if data is a Dataset. Required when data is a numpy array
            or an unnamed DataArray. The caller's DataArray is not modified.

        Raises
        ------
        ValueError
            If a required name is missing, if the shape of a numpy array does not
            match the grid, or if data is of an unsupported type.
        """
        # NOTE: variables in a dataset are no longer renamed as used to be the case
        # in set_staticmaps
        name_required = isinstance(data, np.ndarray) or (
            isinstance(data, xr.DataArray) and data.name is None
        )
        if name is None and name_required:
            raise ValueError(f"Unable to set {type(data).__name__} data without a name")
        if isinstance(data, np.ndarray):
            if data.shape != self.grid.raster.shape:
                raise ValueError("Shape of data and grid maps do not match")
            data = xr.DataArray(dims=self.grid.raster.dims, data=data, name=name)
        if isinstance(data, xr.DataArray):
            if name is not None:
                # rename a (shallow) copy instead of assigning ``data.name`` so the
                # DataArray passed in by the caller keeps its original name
                data = data.rename(name)
            data = data.to_dataset()
        elif not isinstance(data, xr.Dataset):
            raise ValueError(f"cannot set data of type {type(data).__name__}")
        for dvar in data.data_vars:
            # existing maps are overwritten; a warning is only logged in read mode
            if dvar in self._grid and self._read:
                self.logger.warning(f"Replacing grid map: {dvar}")
            self._grid[dvar] = data[dvar]

    def read_grid(self, fn: str = "grid/grid.nc", **kwargs) -> None:
        """Read model grid data at <root>/<fn> and add to grid property

        key-word arguments are passed to :py:func:`xarray.open_dataset`

        Parameters
        ----------
        fn : str, optional
            filename relative to model root, by default 'grid/grid.nc'
        """
        # NOTE(review): accessed without a call, so presumably a property on the
        # model that raises when the model is not in read mode - confirm
        self._assert_read_mode
        # _read_nc returns a mapping; every dataset found is merged into the grid
        for ds in self._read_nc(fn, **kwargs).values():
            self.set_grid(ds)

    def write_grid(self, fn: str = "grid/grid.nc", **kwargs) -> None:
        """Write model grid data to netcdf file at <root>/<fn>

        key-word arguments are passed to :py:meth:`xarray.Dataset.to_netcdf`

        Parameters
        ----------
        fn : str, optional
            filename relative to model root, by default 'grid/grid.nc'
        """
        if len(self._grid) == 0:
            self.logger.debug("No grid data found, skip writing.")
        else:
            # NOTE(review): presumably a property that raises when the model is
            # not in write mode - confirm
            self._assert_write_mode
            # _write_nc requires dict - use dummy 'grid' key
            self._write_nc({"grid": self._grid}, fn, **kwargs)


class GridModel(GridMixin, Model):
    """Model class Grid Model for gridded models in HydroMT"""

    # TODO: add here "res": "setup_region" or "res": "setup_grid" when generic method is available
    _CLI_ARGS = {"region": "setup_region"}
    _NAME = "grid_model"

    def __init__(
        self,
        root: Optional[str] = None,
        mode: str = "w",
        config_fn: Optional[str] = None,
        data_libs: Optional[List[str]] = None,
        logger=logger,
    ):
        """Initialize a GridModel for distributed models with a regular grid.

        All arguments are passed on unchanged to the parent model class.

        Parameters
        ----------
        root : str, optional
            Model root folder.
        mode : str, optional
            Model mode, by default "w".
        config_fn : str, optional
            Model configuration file name.
        data_libs : list of str, optional
            Data catalog (library) files.
        logger : logging.Logger, optional
            Logger to use, by default the module-level logger.
        """
        super().__init__(
            root=root,
            mode=mode,
            config_fn=config_fn,
            data_libs=data_libs,
            logger=logger,
        )

    def read(
        self,
        components: Optional[List[str]] = None,
    ) -> None:
        """Read the complete model schematization and configuration from model files.

        Parameters
        ----------
        components : List, optional
            List of model components to read, each should have an associated
            read_<component> method.
            By default ['config', 'grid', 'geoms', 'forcing', 'states', 'results']
        """
        # a fresh list per call avoids a shared mutable default argument
        if components is None:
            components = ["config", "grid", "geoms", "forcing", "states", "results"]
        super().read(components=components)

    def write(
        self,
        components: Optional[List[str]] = None,
    ) -> None:
        """Write the complete model schematization and configuration to model files.

        Parameters
        ----------
        components : List, optional
            List of model components to write, each should have an associated
            write_<component> method.
            By default ['config', 'grid', 'geoms', 'forcing', 'states']
        """
        # a fresh list per call avoids a shared mutable default argument
        if components is None:
            components = ["config", "grid", "geoms", "forcing", "states"]
        super().write(components=components)

    # Properties for subclass GridModel
    # NOTE: res/transform/crs/bounds inspect ``self._grid`` directly, so they do
    # not trigger a lazy read of the grid and return None while it is empty.
    @property
    def res(self) -> Optional[Tuple[float, float]]:
        """Returns the resolution of the model grid, or None if the grid is empty."""
        if len(self._grid) > 0:
            return self.grid.raster.res
        return None

    @property
    def transform(self):
        """Returns spatial transform of the model grid, or None if the grid is empty."""
        if len(self._grid) > 0:
            return self.grid.raster.transform
        return None

    @property
    def crs(self) -> Union[CRS, None]:
        """Returns coordinate reference system embedded in the model grid.

        Returns None if the grid is empty or has no CRS.
        """
        if len(self._grid) > 0:
            grid_crs = self._grid.raster.crs
            # CRS(None) raises, so a grid without CRS yields None instead
            if grid_crs is not None:
                return CRS(grid_crs)
        return None

    @property
    def bounds(self) -> Optional[List[float]]:
        """Returns the bounding box of the model grid, or None if the grid is empty."""
        if len(self._grid) > 0:
            return self.grid.raster.bounds
        return None

    @property
    def region(self) -> gpd.GeoDataFrame:
        """Returns the geometry of the model area of interest.

        Uses the 'region' geometry if present in ``self.geoms``; otherwise the
        bounding box of the grid. Returns an empty GeoDataFrame if neither exists.
        """
        region = gpd.GeoDataFrame()
        if "region" in self.geoms:
            region = self.geoms["region"]
        elif len(self.grid) > 0:
            crs = self.grid.raster.crs
            if crs is not None and hasattr(crs, "to_epsg"):
                # prefer the EPSG code, but not all CRS have one: to_epsg() then
                # returns None and the original CRS object is kept
                epsg = crs.to_epsg()
                if epsg is not None:
                    crs = epsg
            region = gpd.GeoDataFrame(geometry=[box(*self.bounds)], crs=crs)
        return region