# Source code for hydromt.models.model_mesh

"""Implementations for model mesh workloads."""
import logging
import os
from os.path import dirname, isdir, isfile, join
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import geopandas as gpd
import numpy as np
import pandas as pd
import xarray as xr
import xugrid as xu
from shapely.geometry import box

from .. import workflows
from ..raster import GEO_MAP_COORD
from .model_api import Model

# Module-level logger; also the default ``logger`` argument of MeshModel.
logger = logging.getLogger(__name__)
# Public API exported by ``from ... import *``.
__all__ = ["MeshModel"]


class MeshMixin(object):
    """Mixin that adds an unstructured ``mesh`` component to a HydroMT model.

    Mesh data is kept as a :py:class:`xugrid.UgridDataset` in ``self._mesh`` and
    exposed through the :py:attr:`mesh` property. The methods below use attributes
    that are not defined here (e.g. ``self.logger``, ``self.data_catalog``,
    ``self.region``, ``self.mesh_gdf``, ``self.root``, ``self._read`` and
    ``self._read_nc``); they must be provided by the class this mixin is combined
    with.
    """

    # placeholders
    # We cannot initialize an empty xu.UgridDataArray
    _API = {
        "mesh": Union[xu.UgridDataArray, xu.UgridDataset],
    }

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Filled by set_mesh (directly, via setup_* methods or via read_mesh).
        self._mesh = None

    ## general setup methods
    def setup_mesh_from_rasterdataset(
        self,
        raster_fn: Union[str, Path, xr.DataArray, xr.Dataset],
        variables: Optional[list] = None,
        fill_method: Optional[str] = None,
        resampling_method: Optional[str] = "mean",
        all_touched: Optional[bool] = True,
        rename: Optional[Dict] = None,
    ) -> List[str]:
        """HYDROMT CORE METHOD: Add data variable(s) from ``raster_fn`` to mesh object.

        Raster data is interpolated to the mesh grid using the ``resampling_method``.
        If raster is a dataset, all variables will be added unless ``variables`` list
        is specified.

        Adds model layers:

        * **raster.name** mesh: data from raster_fn

        Parameters
        ----------
        raster_fn: str, Path, xr.DataArray, xr.Dataset
            Data catalog key, path to raster file or raster xarray data object.
        variables: list, optional
            List of variables to add to mesh from raster_fn. By default all.
        fill_method : str, optional
            If specified, fills no data values using fill_nodata method.
            Available methods are {'linear', 'nearest', 'cubic', 'rio_idw'}.
        resampling_method: str, optional
            Method to sample from raster data to mesh. By default mean. Options include
            {'count', 'min', 'max', 'sum', 'mean', 'std', 'median', 'q##'}.
        all_touched : bool, optional
            If True, all pixels touched by geometries will used to define the sample.
            If False, only pixels whose center is within the geometry or that are
            selected by Bresenham's line algorithm will be used. By default True.
        rename: dict, optional
            Dictionary to rename variable names in raster_fn before adding to mesh
            {'name_in_raster_fn': 'name_in_mesh'}. By default None (no renaming).

        Returns
        -------
        list
            List of variables added to mesh.
        """
        if rename is None:
            rename = dict()
        self.logger.info(f"Preparing mesh data from raster source {raster_fn}")
        # Read raster data and select variables
        ds = self.data_catalog.get_rasterdataset(
            raster_fn, geom=self.region, buffer=2, variables=variables
        )
        if isinstance(ds, xr.DataArray):
            ds = ds.to_dataset()

        if fill_method is not None:
            ds = ds.raster.interpolate_na(method=fill_method)

        # Convert mesh grid as geodataframe for sampling
        # Reprojection happens to gdf inside of zonal_stats method
        ds_sample = ds.raster.zonal_stats(
            gdf=self.mesh_gdf, stats=resampling_method, all_touched=all_touched
        )
        # zonal_stats suffixes each variable with the statistic; strip it again
        rm_dict = {f"{var}_{resampling_method}": var for var in ds.data_vars}
        ds_sample = ds_sample.rename(rm_dict).rename(rename)
        # Convert to UgridDataset
        uds_sample = xu.UgridDataset(ds_sample, grids=self.mesh.ugrid.grid)

        self.set_mesh(uds_sample)

        return list(ds_sample.data_vars.keys())

    def setup_mesh_from_raster_reclass(
        self,
        raster_fn: Union[str, Path, xr.DataArray],
        reclass_table_fn: Union[str, Path, pd.DataFrame],
        reclass_variables: list,
        variable: Optional[str] = None,
        fill_nodata: Optional[str] = None,
        resampling_method: Optional[Union[str, list]] = "mean",
        all_touched: Optional[bool] = True,
        rename: Optional[Dict] = None,
        **kwargs,
    ) -> List[str]:
        """HYDROMT CORE METHOD: Add data variable(s) to mesh object by reclassifying the data in ``raster_fn`` based on ``reclass_table_fn``.

        The reclassified raster data
        are subsequently interpolated to the mesh using `resampling_method`.

        Adds model layers:

        * **reclass_variables** mesh: reclassified raster data interpolated to the
            model mesh

        Parameters
        ----------
        raster_fn : str, Path, xr.DataArray
            Data catalog key, path to the raster file, or raster xarray data object.
            Should be a DataArray. If not, use the `variable` argument for selection.
        reclass_table_fn : str, Path, pd.DataFrame
            Data catalog key, path to the tabular data file, or tabular pandas dataframe
            object for the reclassification table of `raster_fn`.
        reclass_variables : list
            List of reclass_variables from the reclass_table_fn table to add to the
            mesh. The index column should match values in raster_fn.
        variable : str, optional
            Name of the raster dataset variable to use. This is only required when
            reading datasets with multiple variables. By default, None.
        fill_nodata : str, optional
            If specified, fills nodata values in `raster_fn` using the `fill_nodata`
            method before reclassifying. Available methods are
            {'linear', 'nearest', 'cubic', 'rio_idw'}.
        resampling_method : str or list, optional
            Method to sample from raster data to the mesh. Can be a list with one
            method per variable in `reclass_variables` or a single method (str or
            length-1 list) for all. By default, 'mean' is used for all
            `reclass_variables`. Options include {'count', 'min', 'max', 'sum',
            'mean', 'std', 'median', 'q##'}.
        all_touched : bool, optional
            If True, all pixels touched by geometries will be used to define the sample.
            If False, only pixels whose center is within the geometry or that are
            selected by Bresenham's line algorithm will be used. By default, True.
        rename : dict, optional
            Dictionary to rename variable names in `reclass_variables` before adding
            them to the mesh. The dictionary should have the form
            {'name_in_reclass_table': 'name_in_mesh'}. By default, None (no renaming).
        **kwargs : dict
            Additional keyword arguments to be passed to the raster dataset
            retrieval method.

        Returns
        -------
        variable_names : List[str]
            List of added variable names in the mesh.

        Raises
        ------
        ValueError
            If `raster_fn` is not a single variable raster, or if
            `resampling_method` is a list whose length is neither 1 nor the
            number of `reclass_variables`.
        """  # noqa: E501
        if rename is None:
            rename = dict()
        # Normalise to one resampling method per reclass variable; validate early
        # so no data is read for an invalid request.
        methods = [str(mtd) for mtd in np.atleast_1d(resampling_method)]
        stats = np.unique(methods)  # each statistic only needs computing once
        if len(methods) == 1:
            methods = methods * len(reclass_variables)
        elif len(methods) != len(reclass_variables):
            raise ValueError(
                "resampling_method should be a single method or a list with one "
                f"method per reclass variable ({len(reclass_variables)}), "
                f"got {len(methods)}."
            )
        self.logger.info(
            f"Preparing mesh data by reclassifying the data in {raster_fn} "
            f"based on {reclass_table_fn}."
        )
        # Read raster data and mapping table
        da = self.data_catalog.get_rasterdataset(
            raster_fn, geom=self.region, buffer=2, variables=variable, **kwargs
        )
        if not isinstance(da, xr.DataArray):
            raise ValueError(
                f"raster_fn {raster_fn} should be a single variable raster. "
                "Please select one using the 'variable' argument"
            )
        df_vars = self.data_catalog.get_dataframe(
            reclass_table_fn, variables=reclass_variables
        )

        if fill_nodata is not None:
            da = da.raster.interpolate_na(method=fill_nodata)

        # Mapping function
        ds_vars = da.raster.reclassify(reclass_table=df_vars, method="exact")

        # Convert mesh grid as geodataframe for sampling
        # Reprojection happens to gdf inside of zonal_stats method
        ds_sample = ds_vars.raster.zonal_stats(
            gdf=self.mesh_gdf,
            stats=stats,
            all_touched=all_touched,
        )
        # zonal_stats yields "<var>_<stat>" for every stat; keep only the stat
        # chosen per variable and strip the suffix.
        rm_dict = {
            f"{var}_{mtd}": var for var, mtd in zip(reclass_variables, methods)
        }
        # Select the requested variables *before* applying the user ``rename``
        # mapping: selecting afterwards fails for every renamed variable.
        ds_sample = ds_sample.rename(rm_dict)[list(reclass_variables)].rename(rename)
        # Convert to UgridDataset
        uds_sample = xu.UgridDataset(ds_sample, grids=self.mesh.ugrid.grid)

        self.set_mesh(uds_sample)

        return list(ds_sample.data_vars.keys())

    @property
    def mesh(self) -> Union[xu.UgridDataArray, xu.UgridDataset]:
        """Model static mesh data as a xugrid.UgridDataset, or None if not set.

        If no mesh is loaded yet and the model is in read mode, the mesh is read
        from file (see :py:meth:`read_mesh`) on first access.
        """
        if self._mesh is None and self._read:
            self.read_mesh()
        return self._mesh

    def set_mesh(
        self,
        data: Union[xu.UgridDataArray, xu.UgridDataset],
        name: Optional[str] = None,
    ) -> None:
        """Add data to mesh.

        All layers of mesh have identical spatial coordinates in Ugrid conventions.
        A UgridDataArray is converted to a UgridDataset before being added; existing
        variables with the same name are replaced (a warning is logged).

        Parameters
        ----------
        data: xugrid.UgridDataArray or xugrid.UgridDataset
            new layer to add to mesh
        name: str, optional
            Name of new object layer, this is used to overwrite the name of
            a UgridDataArray. Ignored for a UgridDataset.

        Raises
        ------
        ValueError
            If ``data`` is not a UgridDataArray/UgridDataset, or if it is an
            unnamed UgridDataArray and ``name`` is not given.
        """
        if not isinstance(data, (xu.UgridDataArray, xu.UgridDataset)):
            raise ValueError(
                "New mesh data in set_mesh should be of type xu.UgridDataArray"
                " or xu.UgridDataset"
            )
        if isinstance(data, xu.UgridDataArray):
            if name is not None:
                data = data.rename(name)
            elif data.name is None:
                raise ValueError(
                    f"Cannot set mesh from {str(type(data).__name__)} without a name."
                )
            data = data.to_dataset()
        if self._mesh is None:  # NOTE: mesh is initialized with None
            self._mesh = data
        else:
            for dvar in data.data_vars:
                if dvar in self._mesh:
                    self.logger.warning(f"Replacing mesh parameter: {dvar}")
                self._mesh[dvar] = data[dvar]

    def read_mesh(self, fn: str = "mesh/mesh.nc", **kwargs) -> None:
        """Read model mesh data at <root>/<fn> and add to mesh property.

        Each dataset returned by ``self._read_nc`` is converted to a
        xugrid.UgridDataset; if it carries a CRS, that CRS is set on the grid and
        the CRS coordinate variable is dropped.

        Parameters
        ----------
        fn : str, optional
            filename relative to model root, by default 'mesh/mesh.nc'
        **kwargs : dict
            Additional keyword arguments to be passed to the `_read_nc` method.
        """
        self._assert_read_mode
        for ds in self._read_nc(fn, **kwargs).values():
            uds = xu.UgridDataset(ds)
            if ds.rio.crs is not None:  # parse crs
                uds.ugrid.grid.set_crs(ds.raster.crs)
                uds = uds.drop_vars(GEO_MAP_COORD, errors="ignore")
            self.set_mesh(uds)

    def write_mesh(self, fn: str = "mesh/mesh.nc", **kwargs) -> None:
        """Write model mesh data to a netCDF file at <root>/<fn>.

        Nothing is written if no mesh is set. If the mesh grid has a CRS, it is
        stored in the file as a CRS coordinate variable.

        Parameters
        ----------
        fn : str, optional
            Filename relative to the model root directory, by default 'mesh/mesh.nc'.
        **kwargs : dict
            Additional keyword arguments to be passed to the
            `xarray.Dataset.to_netcdf` method.
        """
        if self._mesh is None:
            self.logger.debug("No mesh data found, skip writing.")
            return
        self._assert_write_mode
        # filename
        _fn = join(self.root, fn)
        # exist_ok avoids a race between the existence check and creation
        os.makedirs(dirname(_fn), exist_ok=True)
        self.logger.debug(f"Writing file {fn}")
        ds_out = self.mesh.ugrid.to_dataset()
        if self.mesh.ugrid.grid.crs is not None:
            # save crs to spatial_ref coordinate
            ds_out = ds_out.rio.write_crs(self.mesh.ugrid.grid.crs)
        ds_out.to_netcdf(_fn, **kwargs)


class MeshModel(MeshMixin, Model):
    """Model class Mesh Model for mesh models in HydroMT."""

    _CLI_ARGS = {"region": "setup_mesh", "res": "setup_mesh"}
    _NAME = "mesh_model"

    def __init__(
        self,
        root: Optional[str] = None,
        mode: str = "w",
        config_fn: Optional[str] = None,
        data_libs: Optional[List[str]] = None,
        logger=logger,
    ):
        """Initialize a MeshModel for distributed models with an unstructured grid.

        All arguments are forwarded unchanged to the base ``Model`` initializer.
        """
        super().__init__(
            root=root,
            mode=mode,
            config_fn=config_fn,
            data_libs=data_libs,
            logger=logger,
        )

    ## general setup methods
    def setup_mesh(
        self,
        region: dict,
        res: Optional[float] = None,
        crs: int = None,
    ) -> xu.UgridDataset:
        """HYDROMT CORE METHOD: Create a 2D unstructured mesh or read an existing 2D mesh according to UGRID conventions.

        Grids are read according to UGRID conventions. A new 2D unstructured mesh
        is created as a regular rectangular grid covering a geometry (geom) or
        bbox. If an existing 2D mesh is given, then no new mesh is generated.

        Note that only existing meshes with only 2D grids can be read.
        #FIXME: read existing 1D2D network file and extract 2D part.

        Adds/Updates model layers:

        * **mesh** mesh topology: add mesh topology to mesh object

        Parameters
        ----------
        region : dict
            Dictionary describing region of interest, e.g.:

            * {'bbox': [xmin, ymin, xmax, ymax]}
            * {'geom': 'path/to/polygon_geometry'}
            * {'mesh': 'path/to/2dmesh_file'}
        res: float
            Resolution used to generate 2D mesh [unit of the CRS], required if region
            is not based on 'mesh'.
        crs : EPSG code, int, optional
            Optional EPSG code of the model. If None, the CRS of the region is used
            (EPSG:4326 for a bbox, and for a 'mesh' without CRS). If given, the
            mesh is (re)projected to this CRS.

        Returns
        -------
        mesh2d : xu.UgridDataset
            Generated mesh2d.

        Raises
        ------
        ValueError
            If ``res`` is missing for a bbox/geom region or too coarse to create a
            single face, if the region kind is not supported, if a 'geom' region
            has no CRS, or if the 'mesh' file cannot be found.
        NotImplementedError
            If the existing mesh contains a topology that is not 2D.
        """  # noqa: E501
        self.logger.info("Preparing 2D mesh.")

        if "mesh" not in region:
            if not isinstance(res, (int, float)):
                raise ValueError("res argument required")
            kind, region = workflows.parse_region(region, logger=self.logger)
            if kind == "bbox":
                bbox = region["bbox"]
                geom = gpd.GeoDataFrame(geometry=[box(*bbox)], crs=4326)
            elif kind == "geom":
                geom = region["geom"]
                if geom.crs is None:
                    raise ValueError('Model region "geom" has no CRS')
            else:
                raise ValueError(
                    f"Region for mesh must of kind [bbox, geom, mesh], kind {kind} "
                    "not understood."
                )
            if crs is not None:
                geom = geom.to_crs(crs)

            # Generate grid based on res for region bbox
            xmin, ymin, xmax, ymax = geom.total_bounds
            # note: we floor the number of faces, so the grid never extends beyond
            # the region bounds (a partial last row/column is dropped)
            ncol = int((xmax - xmin) // res)
            nrow = int((ymax - ymin) // res)
            if ncol < 1 or nrow < 1:
                raise ValueError(
                    f"Resolution {res} is too coarse for region bounds "
                    f"{[xmin, ymin, xmax, ymax]}: no mesh faces can be created."
                )
            dx, dy = res, -res
            # Faces are created row by row, starting at the top-left corner
            faces = []
            for i in range(nrow):
                top = ymax + i * dy
                bottom = ymax + (i + 1) * dy
                for j in range(ncol):
                    left = xmin + j * dx
                    right = xmin + (j + 1) * dx
                    faces.append(box(left, bottom, right, top))
            grid = gpd.GeoDataFrame(geometry=faces, crs=geom.crs)
            # If needed clip to geom
            if kind != "bbox":
                # Keep faces intersecting any region feature. A spatial join
                # repeats a face once per intersecting feature, so membership is
                # tested with isin; a boolean mask built from the join result
                # cannot be aligned on such duplicated index labels.
                intersecting = gpd.sjoin(
                    grid, geom, how="inner", predicate="intersects"
                )
                grid = grid.loc[grid.index.isin(intersecting.index)].reset_index()
            # Create mesh from grid
            grid.index.name = "mesh2d_nFaces"
            mesh2d = xu.UgridDataset.from_geodataframe(grid)
            mesh2d.ugrid.grid.set_crs(grid.crs)

        else:
            mesh2d_fn = region["mesh"]
            if isinstance(mesh2d_fn, (str, Path)) and isfile(mesh2d_fn):
                self.logger.info("An existing 2D grid is used to prepare 2D mesh.")
                ds = xr.open_dataset(mesh2d_fn, mask_and_scale=False)
            elif isinstance(mesh2d_fn, xr.Dataset):
                ds = mesh2d_fn
            else:
                raise ValueError(
                    f"Region 'mesh' file {mesh2d_fn} not found, please check"
                )
            topologies = [
                k for k in ds.data_vars if ds[k].attrs.get("cf_role") == "mesh_topology"
            ]
            for topology in topologies:
                topodim = ds[topology].attrs["topology_dimension"]
                if topodim != 2:  # check if 2d mesh file else throw error
                    raise NotImplementedError(
                        f"{mesh2d_fn} cannot be opened. Please check if the existing"
                        " grid is an 2D mesh and not 1D2D mesh. "
                        " This option is not yet available for 1D2D meshes."
                    )

            # Continues with a 2D grid
            mesh2d = xu.UgridDataset(ds)
            if ds.rio.crs is not None:  # parse crs
                mesh2d.ugrid.grid.set_crs(ds.raster.crs)
            else:
                # Assume the user crs option (EPSG:4326 if not given). Only
                # default here, so that a mesh with its own CRS is not
                # reprojected when no crs was requested.
                if crs is None:
                    crs = 4326
                self.logger.warning(
                    f"Mesh data from {mesh2d_fn} doesn't have a CRS."
                    f" Assuming crs option {crs}"
                )
                mesh2d.ugrid.grid.set_crs(crs)
            mesh2d = mesh2d.drop_vars(GEO_MAP_COORD, errors="ignore")

        # Reproject to user crs option if needed
        if crs is not None and mesh2d.ugrid.grid.crs != crs:
            self.logger.info(f"Reprojecting mesh to crs {crs}")
            # to_crs returns a new object; it must be reassigned to take effect
            mesh2d = mesh2d.ugrid.to_crs(crs)

        self.set_mesh(mesh2d)

        # The mesh is returned so that this method can be wrapped by models
        # which require more information
        return mesh2d

    ## I/O
    def read(
        self,
        components: Optional[List] = None,
    ) -> None:
        """Read the complete model schematization and configuration from model files.

        Parameters
        ----------
        components : List, optional
            List of model components to read, each should have an associated
            read_<component> method. By default ['config', 'mesh', 'geoms',
            'forcing', 'states', 'results']
        """
        if components is None:
            components = ["config", "mesh", "geoms", "forcing", "states", "results"]
        super().read(components=components)

    def write(
        self,
        components: Optional[List] = None,
    ) -> None:
        """Write the complete model schematization and configuration to model files.

        Parameters
        ----------
        components : List, optional
            List of model components to write, each should have an
            associated write_<component> method. By default ['config', 'mesh',
            'geoms', 'forcing', 'states']
        """
        if components is None:
            components = ["config", "mesh", "geoms", "forcing", "states"]
        super().write(components=components)

    # MeshModel specific methods

    # MeshModel properties
    @property
    def bounds(self) -> Tuple:
        """Returns model mesh bounds, or None if no mesh is set."""
        if self._mesh is not None:
            return self._mesh.ugrid.grid.bounds
        return None

    @property
    def region(self) -> gpd.GeoDataFrame:
        """Returns geometry of region of the model area of interest.

        Uses the 'region' geometry if present, otherwise the bounding box of the
        mesh; an empty GeoDataFrame is returned if neither is available.
        """
        region = gpd.GeoDataFrame()
        if "region" in self.geoms:
            region = self.geoms["region"]
        elif self.mesh is not None:
            # geopandas accepts the grid CRS object as is; converting it to an
            # EPSG code is unnecessary and fails for CRSs without one.
            crs = self.mesh.ugrid.grid.crs
            region = gpd.GeoDataFrame(geometry=[box(*self.bounds)], crs=crs)
        return region

    @property
    def mesh_gdf(self) -> gpd.GeoDataFrame:
        """Returns geometry of mesh as a gpd.GeoDataFrame, or None if no mesh."""
        if self._mesh is None:
            return None
        data_vars = list(self._mesh.data_vars)
        if data_vars:
            # works better on a DataArray
            return self._mesh[data_vars[0]].ugrid.to_geodataframe()
        # Mesh without data variables (e.g. created from a bbox region): build
        # the face geometries from the grid itself. Assumes a 2D mesh; the index
        # is named after the face dimension like the DataArray conversion —
        # NOTE(review): confirm downstream zonal_stats relies only on that.
        grid = self._mesh.ugrid.grid
        gdf = gpd.GeoDataFrame(
            geometry=grid.to_shapely(grid.face_dimension), crs=grid.crs
        )
        gdf.index.name = grid.face_dimension
        return gdf