Source code for hydromt.models.model_mesh

"""Implementations for model mesh workloads."""
import logging
import os
from os.path import dirname, isdir, join
from pathlib import Path
from typing import Dict, List, Optional, Union

import geopandas as gpd
import pandas as pd
import xarray as xr
import xugrid as xu
from pyproj import CRS
from shapely.geometry import box

from .. import workflows
from ..raster import GEO_MAP_COORD
from .model_api import Model

__all__ = ["MeshModel"]
logger = logging.getLogger(__name__)


class MeshMixin(object):
    """Mixin that adds an unstructured ``mesh`` component to a HydroMT model.

    The mesh is held in ``self._mesh`` as a :py:class:`xugrid.UgridDataset`
    (``None`` until it is set or read) and may contain several grid topologies,
    each identified by its grid name.

    NOTE(review): the methods below use attributes that are not defined in this
    mixin (``logger``, ``data_catalog``, ``crs``, ``geoms``, ``_geoms``,
    ``region``, ``root``, ``_read``, ``read_nc``, ``_assert_read_mode``,
    ``_assert_write_mode``); they are expected from the class this mixin is
    combined with (e.g. ``MeshModel``).
    """

    # placeholders
    # We cannot initialize an empty xu.UgridDataArray
    _API = {
        "mesh": Union[xu.UgridDataArray, xu.UgridDataset],
    }

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Filled lazily by the ``mesh`` property (read mode) or by set_mesh.
        self._mesh = None

    ## general setup methods
    def setup_mesh2d_from_rasterdataset(
        self,
        raster_fn: Union[str, Path, xr.DataArray, xr.Dataset],
        grid_name: Optional[str] = "mesh2d",
        variables: Optional[list] = None,
        fill_method: Optional[str] = None,
        resampling_method: Optional[Union[str, List]] = "centroid",
        rename: Optional[Dict] = None,
    ) -> List[str]:
        """HYDROMT CORE METHOD: Add data variable(s) from ``raster_fn`` to 2D ``grid_name`` in mesh object.

        Raster data is interpolated to the mesh ``grid_name`` using the ``resampling_method``.
        If raster is a dataset, all variables will be added unless ``variables`` list
        is specified.

        Adds model layers:

        * **raster.name** mesh: data from raster_fn

        Parameters
        ----------
        raster_fn: str, Path, xr.DataArray, xr.Dataset
            Data catalog key, path to raster file or raster xarray data object.
        grid_name: str, optional
            Name of the mesh grid to add the data to. By default 'mesh2d'.
        variables: list, optional
            List of variables to add to mesh from raster_fn. By default all.
        fill_method : str, optional
            If specified, fills no data values using fill_nodata method.
            Available methods are {'linear', 'nearest', 'cubic', 'rio_idw'}.
        resampling_method: str, list, optional
            Method to sample from raster data to mesh. By default 'centroid'. Options
            include {"centroid", "barycentric", "mean", "harmonic_mean",
            "geometric_mean", "sum", "minimum", "maximum", "mode", "median",
            "max_overlap"}. If centroid, will use
            :py:meth:`xugrid.CentroidLocatorRegridder` method. If barycentric, will use
            :py:meth:`xugrid.BarycentricInterpolator` method. If any other, will use
            :py:meth:`xugrid.OverlapRegridder` method.
            Can provide a list corresponding to ``variables``.
        rename: dict, optional
            Dictionary to rename variable names in raster_fn before adding to mesh
            {'name_in_raster_fn': 'name_in_mesh'}. By default empty.

        Returns
        -------
        list
            List of variables added to mesh.

        Raises
        ------
        ValueError
            If ``grid_name`` is not a grid of the current mesh.
        """  # noqa: E501
        self.logger.info(f"Preparing mesh data from raster source {raster_fn}")
        # Check if grid name in self.mesh
        if grid_name not in self.mesh_names:
            raise ValueError(f"Grid name {grid_name} not in mesh ({self.mesh_names}).")
        # Only read the raster extent covering the target grid (bbox in EPSG:4326).
        bounds = self.mesh_gdf[grid_name].to_crs(4326).total_bounds
        ds = self.data_catalog.get_rasterdataset(
            raster_fn,
            bbox=bounds,
            buffer=2,
            variables=variables,
            single_var_as_array=False,
        )

        uds_sample = workflows.mesh2d_from_rasterdataset(
            ds=ds,
            mesh2d=self.mesh_grids[grid_name],
            variables=variables,
            fill_method=fill_method,
            resampling_method=resampling_method,
            rename=rename,
            logger=self.logger,
        )

        # Grid topology is unchanged: only data variables are added.
        self.set_mesh(uds_sample, grid_name=grid_name, overwrite_grid=False)

        return list(uds_sample.data_vars.keys())

    def setup_mesh2d_from_raster_reclass(
        self,
        raster_fn: Union[str, Path, xr.DataArray],
        reclass_table_fn: Union[str, Path, pd.DataFrame],
        reclass_variables: list,
        grid_name: Optional[str] = "mesh2d",
        variable: Optional[str] = None,
        fill_method: Optional[str] = None,
        resampling_method: Optional[Union[str, list]] = "centroid",
        rename: Optional[Dict] = None,
        **kwargs,
    ) -> List[str]:
        """HYDROMT CORE METHOD: Add data variable(s) to 2D ``grid_name`` in mesh object by reclassifying the data in ``raster_fn`` based on ``reclass_table_fn``.

        The reclassified raster data
        are subsequently interpolated to the mesh using `resampling_method`.

        Adds model layers:

        * **reclass_variables** mesh: reclassified raster data interpolated to the
            model mesh

        Parameters
        ----------
        raster_fn : str, Path, xr.DataArray
            Data catalog key, path to the raster file, or raster xarray data object.
            Should be a DataArray. If not, use the `variable` argument for selection.
        reclass_table_fn : str, Path, pd.DataFrame
            Data catalog key, path to the tabular data file, or tabular pandas dataframe
            object for the reclassification table of `raster_fn`.
        reclass_variables : list
            List of reclass_variables from the reclass_table_fn table to add to the
            mesh. The index column should match values in raster_fn.
        grid_name : str, optional
            Name of the mesh grid to add the data to. By default 'mesh2d'.
        variable : str, optional
            Name of the raster dataset variable to use. This is only required when
            reading datasets with multiple variables. By default, None.
        fill_method : str, optional
            If specified, fills nodata values in `raster_fn` using the `fill_method`
            method before reclassifying. Available methods are
            {'linear', 'nearest', 'cubic', 'rio_idw'}.
        resampling_method : str or list, optional
            Method to sample from raster data to mesh. By default 'centroid'. Options
            include {"centroid", "barycentric", "mean", "harmonic_mean",
            "geometric_mean", "sum", "minimum", "maximum", "mode", "median",
            "max_overlap"}. If centroid, will use
            :py:meth:`xugrid.CentroidLocatorRegridder` method. If barycentric, will use
            :py:meth:`xugrid.BarycentricInterpolator` method. If any other, will use
            :py:meth:`xugrid.OverlapRegridder` method.
            Can provide a list corresponding to ``reclass_variables``.
        rename : dict, optional
            Dictionary to rename variable names in `reclass_variables` before adding
            them to the mesh. The dictionary should have the form
            {'name_in_reclass_table': 'name_in_mesh'}. By default, an empty dictionary.
        **kwargs : dict
            Additional keyword arguments to be passed to the raster dataset
            retrieval method.

        Returns
        -------
        variable_names : List[str]
            List of added variable names in the mesh.

        Raises
        ------
        ValueError
            If ``grid_name`` is not a grid of the current mesh, or if `raster_fn`
            is not a single variable raster.
        """  # noqa: E501
        self.logger.info(
            f"Preparing mesh data by reclassifying the data in {raster_fn} "
            f"based on {reclass_table_fn}."
        )
        # Check if grid name in self.mesh
        if grid_name not in self.mesh_names:
            raise ValueError(f"Grid name {grid_name} not in mesh ({self.mesh_names}).")
        # Read raster data (clipped to the grid extent, EPSG:4326) and mapping table
        bounds = self.mesh_gdf[grid_name].to_crs(4326).total_bounds
        da = self.data_catalog.get_rasterdataset(
            raster_fn,
            bbox=bounds,
            buffer=2,
            variables=variable,
            **kwargs,
        )
        if not isinstance(da, xr.DataArray):
            raise ValueError(
                f"raster_fn {raster_fn} should be a single variable raster. "
                "Please select one using the 'variable' argument"
            )
        df_vars = self.data_catalog.get_dataframe(
            reclass_table_fn, variables=reclass_variables
        )

        uds_sample = workflows.mesh2d_from_raster_reclass(
            da=da,
            df_vars=df_vars,
            mesh2d=self.mesh_grids[grid_name],
            reclass_variables=reclass_variables,
            fill_method=fill_method,
            resampling_method=resampling_method,
            rename=rename,
            logger=self.logger,
        )

        # Grid topology is unchanged: only data variables are added.
        self.set_mesh(uds_sample, grid_name=grid_name, overwrite_grid=False)

        return list(uds_sample.data_vars.keys())

    @property
    def mesh(self) -> Union[xu.UgridDataArray, xu.UgridDataset]:
        """
        Model static mesh data. It returns a xugrid.UgridDataset.

        Mesh can contain several grids (1D, 2D, 3D) defined according
        to UGRID conventions. To extract a specific grid, use get_mesh
        method.
        """
        # Lazily read the mesh from file on first access when in read mode.
        if self._mesh is None and self._read:
            self.read_mesh()
        return self._mesh

    def set_mesh(
        self,
        data: Union[xu.UgridDataArray, xu.UgridDataset],
        name: Optional[str] = None,
        grid_name: Optional[str] = None,
        overwrite_grid: Optional[bool] = False,
    ) -> None:
        """Add data to mesh.

        All layers of mesh have identical spatial coordinates in Ugrid conventions.
        Also updates self.region if grid_name is new or overwrite_grid is True.

        Parameters
        ----------
        data: xugrid.UgridDataArray or xugrid.UgridDataset
            new layer to add to mesh, should contain only one grid topology.
        name: str, optional
            Name of new object layer, this is used to overwrite the name of
            a UgridDataArray.
        grid_name: str, optional
            Name of the mesh grid to add data to. If None, inferred from data.
            Can be used for renaming the grid.
        overwrite_grid: bool, optional
            If True, overwrite the grid with the same name as the grid in self.mesh.

        Raises
        ------
        ValueError
            If ``data`` has the wrong type, is an unnamed UgridDataArray, holds more
            than one grid, has no CRS (first grid) or a CRS different from the mesh,
            or redefines an existing grid topology while ``overwrite_grid`` is False.
        """
        # Checks on data
        if not isinstance(data, (xu.UgridDataArray, xu.UgridDataset)):
            raise ValueError(
                "New mesh data in set_mesh should be of type xu.UgridDataArray"
                " or xu.UgridDataset"
            )
        if isinstance(data, xu.UgridDataArray):
            if name is not None:
                data = data.rename(name)
            elif data.name is None:
                raise ValueError(
                    f"Cannot set mesh from {str(type(data).__name__)} without a name."
                )
            data = data.to_dataset()

        # Checks on grid topology
        # TODO: check if we support setting multiple grids at once. For now just one
        if len(data.ugrid.grids) > 1:
            raise ValueError(
                "set_mesh methods only supports adding data to one grid at a time."
            )
        if grid_name is None:
            grid_name = data.ugrid.grid.name
        elif grid_name != data.ugrid.grid.name:
            data = workflows.rename_mesh(data, name=grid_name)

        # Remember whether this is a new grid to decide on the region update below.
        new_grid = grid_name not in self.mesh_names

        # Adding to mesh
        if self.mesh is None:  # NOTE: mesh is initialized with None
            # Check on crs
            if not data.ugrid.grid.crs:
                raise ValueError("Data should have CRS.")
            self._mesh = data
        else:
            # Check on crs
            if data.ugrid.grid.crs != self.crs:
                raise ValueError("Data and self.mesh should have the same CRS.")
            # Save crs as it will be lost when converting to xarray
            crs = self.crs
            # Check on new grid topology
            if grid_name in self.mesh_names:
                existing_topology = self.mesh_grids[grid_name].to_dataset()
                if not existing_topology.equals(data.ugrid.grid.to_dataset()):
                    if not overwrite_grid:
                        raise ValueError(
                            f"Grid {grid_name} already exists in mesh"
                            " and has a different topology. "
                            "Use overwrite_grid=True to overwrite the grid"
                            " topology and related data."
                        )
                    # Remove grid and all corresponding data variables from mesh
                    self.logger.warning(
                        f"Overwriting grid {grid_name} and the corresponding"
                        " data variables in mesh."
                    )
                    # mesh_datasets builds a dataset per grid; compute it once
                    # rather than once per remaining grid.
                    datasets = self.mesh_datasets
                    grids = [
                        datasets[g].ugrid.to_dataset(optional_attributes=True)
                        for g in datasets
                        if g != grid_name
                    ]
                    # Re-define _mesh from the remaining grids only
                    self._mesh = xu.UgridDataset(xr.merge(grids))
            # Check again mesh_names, could have changed if overwrite_grid=True
            if grid_name in self.mesh_names:
                for dvar in data.data_vars:
                    if dvar in self._mesh:
                        self.logger.warning(f"Replacing mesh parameter: {dvar}")
                    self._mesh[dvar] = data[dvar]
            else:
                # We are potentially adding a new grid without any data variables
                self._mesh = xu.UgridDataset(
                    xr.merge(
                        [
                            self.mesh.ugrid.to_dataset(optional_attributes=True),
                            data.ugrid.to_dataset(optional_attributes=True),
                        ]
                    )
                )
            # Restore crs
            for grid in self._mesh.ugrid.grids:
                grid.set_crs(crs)

        # update related geoms if necessary: region
        if overwrite_grid or new_grid:
            # Drop the cached region so the ``region`` property recomputes it.
            if "region" in self.geoms:
                self._geoms.pop("region", None)
            _ = self.region

    def get_mesh(
        self, grid_name: str, include_data: bool = False
    ) -> Union[xu.Ugrid1d, xu.Ugrid2d, xu.UgridDataArray, xu.UgridDataset]:
        """
        Return a specific grid topology from mesh based on grid_name.

        If include_data is True, the data variables for that specific
        grid are also included.

        Parameters
        ----------
        grid_name : str
            Name of the grid to return.
        include_data : bool, optional
            If True, also include data variables, by default False.

        Returns
        -------
        uds: Union[xu.Ugrid1d, xu.Ugrid2d, xu.UgridDataset]
            The bare grid topology (``include_data=False``) or a UgridDataset with
            the grid and its data variables (``include_data=True``).

        Raises
        ------
        ValueError
            If the mesh is not set or ``grid_name`` is not one of its grids.
        """
        if self.mesh is None:
            raise ValueError("Mesh is not set, please use set_mesh first.")
        if grid_name not in self.mesh_names:
            raise ValueError(f"Grid {grid_name} not found in mesh.")
        if include_data:
            # Collect the data_vars that should be dropped for grid_name
            variables = []
            for var in self.mesh.data_vars:
                if hasattr(self.mesh[var], "ugrid"):
                    # Variable defined on a grid: drop it if it is another grid.
                    if self.mesh[var].ugrid.grid.name != grid_name:
                        variables.append(var)
                # Non-grid variables (e.g. additional topology properties) are
                # only kept when their name is prefixed with grid_name.
                elif not var.startswith(grid_name):
                    variables.append(var)

            # NOTE(review): when no variable needs dropping, the full mesh is
            # returned, which also includes other grids that carry no data.
            if variables and len(variables) < len(self.mesh.data_vars):
                uds = self.mesh.drop_vars(variables)
                # Drop coords as well
                drop_coords = [c for c in uds.coords if not c.startswith(grid_name)]
                uds = uds.drop_vars(drop_coords)
            elif variables and len(variables) == len(self.mesh.data_vars):
                # Nothing left for this grid: return the bare topology with crs.
                grid = self.mesh_grids[grid_name]
                uds = xu.UgridDataset(grid.to_dataset(optional_attributes=True))
                uds.ugrid.grid.set_crs(grid.crs)
            else:
                uds = self.mesh.copy()

            return uds

        else:
            return self.mesh_grids[grid_name]

    def read_mesh(
        self,
        fn: str = "mesh/mesh.nc",
        crs: Optional[Union[CRS, int]] = None,
        **kwargs,
    ) -> None:
        """Read model mesh data at <root>/<fn> and add to mesh property.

        key-word arguments are passed to :py:meth:`~hydromt.models.Model.read_nc`

        Parameters
        ----------
        fn : str, optional
            filename relative to model root, by default 'mesh/mesh.nc'
        crs : CRS or int, optional
            Coordinate Reference System (CRS) object or EPSG code representing the
            spatial reference system of the mesh file. Only used if the CRS is not
            found when reading the mesh file.
        **kwargs : dict
            Additional keyword arguments to be passed to the `read_nc` method.

        Raises
        ------
        ValueError
            If the file has no CRS and none is passed via ``crs``.
        """
        self._assert_read_mode()
        # read_nc returns a mapping of datasets; merge them into one.
        ds = xr.merge(self.read_nc(fn, **kwargs).values())
        uds = xu.UgridDataset(ds)
        if ds.rio.crs is not None:  # parse crs
            uds.ugrid.set_crs(ds.raster.crs)
            # The CRS now lives on the grids; drop the spatial_ref coordinate.
            uds = uds.drop_vars(GEO_MAP_COORD, errors="ignore")
        else:
            if not crs:
                raise ValueError(
                    "no crs is found in the file nor passed to the reader."
                )
            else:
                uds.ugrid.set_crs(crs)
                self.logger.info(
                    "no crs is found in the file, assigning from user input."
                )
        self._mesh = uds

    def write_mesh(
        self,
        fn: str = "mesh/mesh.nc",
        write_optional_ugrid_attributes: bool = True,
        **kwargs,
    ) -> None:
        """Write model mesh data to a netCDF file at <root>/<fn>.

        Keyword arguments are passed to :py:meth:`xarray.Dataset.to_netcdf`.

        Parameters
        ----------
        fn : str, optional
            Filename relative to the model root directory, by default 'mesh/mesh.nc'.
        write_optional_ugrid_attributes : bool, optional
            If True, write optional ugrid attributes to the netCDF file, by default
            True.
        **kwargs : dict
            Additional keyword arguments to be passed to the
            `xarray.Dataset.to_netcdf` method.
        """
        if self.mesh is None:
            self.logger.debug("No mesh data found, skip writing.")
            return
        self._assert_write_mode()
        # filename
        _fn = join(self.root, fn)
        out_dir = dirname(_fn)
        if not isdir(out_dir):
            # exist_ok guards against the directory appearing after the check.
            os.makedirs(out_dir, exist_ok=True)
        self.logger.debug(f"Writing file {fn}")
        ds_out = self.mesh.ugrid.to_dataset(
            optional_attributes=write_optional_ugrid_attributes,
        )
        if self.crs is not None:
            # save crs to spatial_ref coordinate
            ds_out = ds_out.rio.write_crs(self.crs)
        ds_out.to_netcdf(_fn, **kwargs)

    # Other mesh properties
    @property
    def mesh_grids(self) -> Dict[str, Union[xu.Ugrid1d, xu.Ugrid2d]]:
        """Dictionary of grid names and Ugrid topologies in mesh."""
        if self.mesh is None:
            return {}
        return {grid.name: grid for grid in self.mesh.ugrid.grids}

    @property
    def mesh_datasets(self) -> Dict[str, xu.UgridDataset]:
        """Dictionary of grid names and corresponding UgridDataset topology and data variables in mesh."""  # noqa: E501
        if self.mesh is None:
            return {}
        return {
            grid.name: self.get_mesh(grid_name=grid.name, include_data=True)
            for grid in self.mesh.ugrid.grids
        }

    @property
    def mesh_names(self) -> List[str]:
        """List of grid names in mesh."""
        if self.mesh is not None:
            return [grid.name for grid in self.mesh.ugrid.grids]
        else:
            return []

    @property
    def mesh_gdf(self) -> Dict[str, gpd.GeoDataFrame]:
        """Returns dict of geometry of grids in mesh as a gpd.GeoDataFrame.

        1D grids are represented by their edges, 2D grids by their faces.

        Raises
        ------
        ValueError
            If a grid has a topology dimension other than 1 or 2.
        """
        mesh_gdf = dict()
        if self.mesh is not None:
            for k, grid in self.mesh_grids.items():
                if grid.topology_dimension == 1:
                    dim = grid.edge_dimension
                elif grid.topology_dimension == 2:
                    dim = grid.face_dimension
                else:
                    # Without this, ``dim`` would be unbound or silently reuse
                    # the dimension of the previous grid.
                    raise ValueError(
                        f"Unsupported topology dimension {grid.topology_dimension} "
                        f"for grid {k}."
                    )
                gdf = gpd.GeoDataFrame(
                    index=grid.to_dataset()[dim].values.astype(str),
                    geometry=grid.to_shapely(dim),
                )
                mesh_gdf[k] = gdf.set_crs(grid.crs)

        return mesh_gdf


class MeshModel(MeshMixin, Model):
    """Model class Mesh Model for mesh models in HydroMT.

    Uses xugrid for working with unstructured grids, for data and topology stored
    according to UGRID conventions.

    See also: xugrid
    """

    _CLI_ARGS = {"region": "setup_mesh2d", "res": "setup_mesh2d"}
    _NAME = "mesh_model"

    def __init__(
        self,
        root: Optional[str] = None,
        mode: str = "w",
        config_fn: Optional[str] = None,
        data_libs: Optional[List[str]] = None,
        logger=logger,
    ):
        """Initialize a MeshModel for models with an unstructured grid."""
        super().__init__(
            root=root,
            mode=mode,
            config_fn=config_fn,
            data_libs=data_libs,
            logger=logger,
        )

    ## general setup methods
    def setup_mesh2d(
        self,
        region: dict,
        res: Optional[float] = None,
        crs: Optional[Union[int, str]] = None,
        grid_name: str = "mesh2d",
    ) -> xu.UgridDataset:
        """HYDROMT CORE METHOD: Create a 2D unstructured mesh or read an existing 2D mesh according to UGRID conventions.

        Grids are read according to UGRID conventions. A 2D unstructured mesh
        will be created as a 2D rectangular grid from a geometry (geom_fn) or bbox.
        If an existing 2D mesh is given, then no new mesh will be generated but an
        extent can be extracted using the `bounds` argument of region.

        Note that only existing meshes with only a 2D grid can be read.

        Adds/Updates model layers:

        * **grid_name** mesh topology: add grid_name 2D topology to mesh object

        Parameters
        ----------
        region : dict
            Dictionary describing region of interest, bounds can be provided
            for type 'mesh'. In case of 'mesh', if the file includes several grids,
            the specific 2D grid can be selected using the 'grid_name' argument.
            CRS for 'bbox' and 'bounds' should be 4326; e.g.:

            * {'bbox': [xmin, ymin, xmax, ymax]}

            * {'geom': 'path/to/polygon_geometry'}

            * {'mesh': 'path/to/2dmesh_file'}

            * {'mesh': 'path/to/mesh_file', 'grid_name': 'mesh2d', 'bounds': [xmin, ymin, xmax, ymax]}
        res: float
            Resolution used to generate 2D mesh [unit of the CRS], required if region
            is not based on 'mesh'.
        crs : EPSG code, int, optional
            Optional EPSG code of the model or "utm" to let hydromt find the closest
            projected CRS. If None using the one from region, and else 4326.
        grid_name : str, optional
            Name of the 2D grid in mesh, by default "mesh2d".

        Returns
        -------
        mesh2d : xu.UgridDataset
            Generated mesh2d.
        """  # noqa: E501
        self.logger.info("Preparing 2D mesh.")

        # Create mesh2d
        mesh2d = workflows.create_mesh2d(
            region=region,
            res=res,
            crs=crs,
            logger=self.logger,
        )
        # Add mesh2d to self.mesh
        self.set_mesh(mesh2d, grid_name=grid_name)

        # This setup method returns mesh2d so that it can be wrapped for models
        # which require more information
        return mesh2d

    ## I/O
    def read(
        self,
        components: List = None,
    ) -> None:
        """Read the complete model schematization and configuration from model files.

        Parameters
        ----------
        components : List, optional
            List of model components to read, each should have an associated
            read_<component> method. By default ['config', 'mesh', 'geoms',
            'tables', 'forcing', 'states', 'results']
        """
        components = components or [
            "config",
            "mesh",
            "geoms",
            "tables",
            "forcing",
            "states",
            "results",
        ]
        super().read(components=components)

    def write(
        self,
        components: List = None,
    ) -> None:
        """Write the complete model schematization and configuration to model files.

        Parameters
        ----------
        components : List, optional
            List of model components to write, each should have an
            associated write_<component> method. By default ['config', 'mesh',
            'geoms', 'tables', 'forcing', 'states']
        """
        components = components or [
            "config",
            "mesh",
            "geoms",
            "tables",
            "forcing",
            "states",
        ]
        super().write(components=components)

    # MeshModel specific methods

    # MeshModel properties
    @property
    def bounds(self) -> Dict:
        """Returns model mesh bounds (None if the mesh is not set)."""
        if self.mesh is not None:
            return self.mesh.ugrid.bounds

    @property
    def crs(self) -> Optional[CRS]:
        """Returns model mesh crs.

        All grids in the mesh must share the same CRS; the shared CRS is returned,
        or None if the mesh is not set or has no grids.

        Raises
        ------
        ValueError
            If the grids in the mesh do not all have the same CRS.
        """
        if self.mesh is None:
            return None
        grid_crs = self.mesh.ugrid.crs
        crs_values = list(grid_crs.values())
        if not crs_values:
            return None
        # Compare every grid against the first one so that the check does not
        # depend on grid order (e.g. a missing CRS on the first grid).
        crs = crs_values[0]
        if any(v != crs for v in crs_values[1:]):
            raise ValueError(f"Mesh crs is not uniform, please check {grid_crs}")
        return crs

    @property
    def region(self) -> gpd.GeoDataFrame:
        """Returns geometry of region of the model area of interest based on mesh total bounds."""  # noqa: E501
        region = gpd.GeoDataFrame()
        if "region" in self.geoms:
            region = self.geoms["region"]
        elif self.mesh is not None:
            # Derive the region from the mesh extent and cache it in geoms.
            region = gpd.GeoDataFrame(
                geometry=[box(*self.mesh.ugrid.total_bounds)], crs=self.crs
            )
            self.set_geoms(region, name="region")
        return region