Source code for hydromt.models.model_mesh

from typing import Union, Optional, List, Tuple
import logging
import os
from os.path import join, isdir, dirname, isfile
from pathlib import Path
import numpy as np
import xarray as xr
import xugrid as xu
import geopandas as gpd
from shapely.geometry import box, Polygon

from ..raster import GEO_MAP_COORD
from .model_api import Model
from .. import workflows

__all__ = ["MeshModel"]
logger = logging.getLogger(__name__)


class MeshMixin(object):
    # placeholders
    # We cannot initialize an empty xu.UgridDataArray
    _API = {
        "mesh": Union[xu.UgridDataArray, xu.UgridDataset],
    }

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._mesh = None

    ## general setup methods
    def setup_mesh_from_raster(
        self,
        raster_fn: str,
        variables: Optional[list] = None,
        fill_method: Optional[str] = None,
        resampling_method: Optional[str] = "mean",
        all_touched: Optional[bool] = True,
    ) -> None:
        """
        This component adds data variable(s) from ``raster_fn`` to mesh object.

        Raster data is interpolated to the mesh grid using the ``resampling_method``.
        If raster is a dataset, all variables will be added unless ``variables`` list is specified.

        Adds model layers:

        * **raster.name** mesh: data from raster_fn

        Parameters
        ----------
        raster_fn: str
            Source name of raster data in data_catalog.
        variables: list, optional
            List of variables to add to mesh from raster_fn. By default all.
        fill_method : str, optional
            If specified, fills no data values using fill_nodata method. AVailable methods
            are {'linear', 'nearest', 'cubic', 'rio_idw'}.
        resampling_method: str, optional
            Method to sample from raster data to mesh. By default mean. Options include
            {'count', 'min', 'max', 'sum', 'mean', 'std', 'median', 'q##'}.
        all_touched : bool, optional
            If True, all pixels touched by geometries will used to define the sample.
            If False, only pixels whose center is within the geometry or that are
            selected by Bresenham's line algorithm will be used. By default True.
        """
        self.logger.info(f"Preparing mesh data from raster source {raster_fn}")
        # Read raster data and select variables
        ds = self.data_catalog.get_rasterdataset(
            raster_fn, geom=self.region, buffer=2, variables=variables
        )
        if isinstance(ds, xr.DataArray):
            ds = ds.to_dataset()

        if fill_method is not None:
            ds = ds.raster.interpolate_na(method=fill_method)

        # Convert mesh grid as geodataframe for sampling
        # Reprojection happens to gdf inside of zonal_stats method
        ds_sample = ds.raster.zonal_stats(
            gdf=self.mesh_gdf, stats=resampling_method, all_touched=all_touched
        )
        # Rename variables
        rm_dict = {f"{var}_{resampling_method}": var for var in ds.data_vars}
        ds_sample = ds_sample.rename(rm_dict)
        # Convert to UgridDataset
        uds_sample = xu.UgridDataset(ds_sample, grids=self.mesh.ugrid.grid)

        self.set_mesh(uds_sample)

    def setup_mesh_from_raster_reclass(
        self,
        raster_fn: str,
        reclass_table_fn: str,
        reclass_variables: list,
        variable: Optional[str] = None,
        fill_nodata: Optional[str] = None,
        resampling_method: Optional[Union[str, list]] = "mean",
        all_touched: Optional[bool] = True,
        **kwargs,
    ) -> None:
        """
        This component adds data variable(s) to mesh object by reclassifying the data in ``raster_fn`` based on ``reclass_table_fn``.

        The reclassified raster data are subsequently interpolated to the mesh using ``resampling_method``.

        Adds model layers:

        * **reclass_variables** mesh: reclassified raster data interpolated to the model mesh

        Parameters
        ----------
        raster_fn: str
            Source name of raster data in data_catalog. Should be a DataArray. Else use **kwargs to select variables/time_tuple in
            :py:meth:`hydromt.data_catalog.get_rasterdataset` method
        reclass_table_fn: str
            Source name of reclassification table of raster_fn in data_catalog.
        reclass_variables: list
            List of reclass_variables from reclass_table_fn table to add to mesh. Index column should match values in raster_fn.
        variable: str, optional
            Name of raster dataset variable to use. This is only required when reading datasets with multiple variables.
            By default None.
        fill_method : str, optional
            If specified, fills nodata values in `raster_fn` using fill_nodata method before reclassifying.
            Available methods are {'linear', 'nearest', 'cubic', 'rio_idw'}.
        resampling_method: str, list, optional
            Method to sample from raster data to mesh. Can be a list per variable in ``reclass_variables`` or a
            single method for all. By default mean for all reclass_variables.
            Options include {'count', 'min', 'max', 'sum', 'mean', 'std', 'median', 'q##'}.
        all_touched : bool, optional
            If True, all pixels touched by geometries will used to define the sample.
            If False, only pixels whose center is within the geometry or that are
            selected by Bresenham's line algorithm will be used. By default True.
        """
        self.logger.info(
            f"Preparing mesh data by reclassifying the data in {raster_fn} based on {reclass_table_fn}."
        )
        # Read raster data and mapping table
        da = self.data_catalog.get_rasterdataset(
            raster_fn, geom=self.region, buffer=2, variables=variable, **kwargs
        )
        if not isinstance(da, xr.DataArray):
            raise ValueError(
                f"raster_fn {raster_fn} should be a single variable raster. "
                "Please select one using the 'variable' argument"
            )
        df_vars = self.data_catalog.get_dataframe(
            reclass_table_fn, variables=reclass_variables
        )

        if fill_nodata is not None:
            ds = ds.raster.interpolate_na(method=fill_nodata)

        # Mapping function
        ds_vars = da.raster.reclassify(reclass_table=df_vars, method="exact")

        # Convert mesh grid as geodataframe for sampling
        # Reprojection happens to gdf inside of zonal_stats method
        ds_sample = ds_vars.raster.zonal_stats(
            gdf=self.mesh_gdf,
            stats=np.unique(np.atleast_1d(resampling_method)),
            all_touched=all_touched,
        )
        # Rename variables
        if isinstance(resampling_method, str):
            resampling_method = np.repeat(resampling_method, len(reclass_variables))
        rm_dict = {
            f"{var}_{mtd}": var
            for var, mtd in zip(reclass_variables, resampling_method)
        }
        ds_sample = ds_sample.rename(rm_dict)
        ds_sample = ds_sample[reclass_variables]
        # Convert to UgridDataset
        uds_sample = xu.UgridDataset(ds_sample, grids=self.mesh.ugrid.grid)

        self.set_mesh(uds_sample)

    @property
    def mesh(self) -> Union[xu.UgridDataArray, xu.UgridDataset]:
        """Model static mesh data. Returns a xarray.Dataset."""
        # XU grid data type Xarray dataset with xu sampling.
        if self._mesh is None and self._read:
            self.read_mesh()
        return self._mesh

    def set_mesh(
        self,
        data: Union[xu.UgridDataArray, xu.UgridDataset],
        name: Optional[str] = None,
    ) -> None:
        """Add data to mesh.

        All layers of mesh have identical spatial coordinates in Ugrid conventions.

        Parameters
        ----------
        data: xugrid.UgridDataArray or xugrid.UgridDataset
            new layer to add to mesh
        name: str, optional
            Name of new object layer, this is used to overwrite the name of a UgridDataArray.
        """
        if not isinstance(data, (xu.UgridDataArray, xu.UgridDataset)):
            raise ValueError(
                "New mesh data in set_mesh should be of type xu.UgridDataArray or xu.UgridDataset"
            )
        if isinstance(data, xu.UgridDataArray):
            if name is not None:
                data = data.rename(name)
            elif data.name is None:
                raise ValueError(
                    f"Cannot set mesh from {str(type(data).__name__)} without a name."
                )
            data = data.to_dataset()
        if self._mesh is None:  # NOTE: mesh is initialized with None
            self._mesh = data
        else:
            for dvar in data.data_vars:
                if dvar in self._mesh:
                    self.logger.warning(f"Replacing mesh parameter: {dvar}")
                self._mesh[dvar] = data[dvar]

    def read_mesh(self, fn: str = "mesh/mesh.nc", **kwargs) -> None:
        """Read model mesh data at <root>/<fn> and add to mesh property

        key-word arguments are passed to :py:func:`xr.open_dataset`

        Parameters
        ----------
        fn : str, optional
            filename relative to model root, by default 'mesh/mesh.nc'
        """
        self._assert_read_mode
        for ds in self._read_nc(fn, **kwargs).values():
            uds = xu.UgridDataset(ds)
            if ds.rio.crs is not None:  # parse crs
                uds.ugrid.grid.set_crs(ds.raster.crs)
                uds = uds.drop_vars(GEO_MAP_COORD, errors="ignore")
            self.set_mesh(uds)

    def write_mesh(self, fn: str = "mesh/mesh.nc", **kwargs) -> None:
        """Write model grid data to netcdf file at <root>/<fn>

        key-word arguments are passed to :py:meth:`xarray.Dataset.ugrid.to_netcdf`

        Parameters
        ----------
        fn : str, optional
            filename relative to model root, by default 'grid/grid.nc'
        """
        if self._mesh is None:
            self.logger.debug("No mesh data found, skip writing.")
            return
        self._assert_write_mode
        # filename
        _fn = join(self.root, fn)
        if not isdir(dirname(_fn)):
            os.makedirs(dirname(_fn))
        self.logger.debug(f"Writing file {fn}")
        # ds_new = xu.UgridDataset(grid=ds_out.ugrid.grid) # bug in xugrid?
        ds_out = self.mesh.ugrid.to_dataset()
        if self.mesh.ugrid.grid.crs is not None:
            # save crs to spatial_ref coordinate
            ds_out = ds_out.rio.write_crs(self.mesh.ugrid.grid.crs)
        ds_out.to_netcdf(_fn, **kwargs)


[docs]class MeshModel(MeshMixin, Model): """Model class Mesh Model for mesh models in HydroMT""" _CLI_ARGS = {"region": "setup_mesh", "res": "setup_mesh"} _NAME = "mesh_model"
[docs] def __init__( self, root: str = None, mode: str = "w", config_fn: str = None, data_libs: List[str] = None, logger=logger, ): """Initialize a MeshModel for distributed models with an unstructured grid.""" super().__init__( root=root, mode=mode, config_fn=config_fn, data_libs=data_libs, logger=logger, )
## general setup methods def setup_mesh( self, region: dict, res: Optional[float] = None, crs: int = None, ) -> xu.UgridDataset: """Creates an 2D unstructured mesh or reads an existing 2D mesh according UGRID conventions. An 2D unstructured mesh will be created as 2D rectangular grid from a geometry (geom_fn) or bbox. If an existing 2D mesh is given, then no new mesh will be generated Note Only existing meshed with only 2D grid can be read. #FIXME: read existing 1D2D network file and extract 2D part. Adds/Updates model layers: * **mesh** mesh topology: add mesh topology to mesh object Parameters ---------- region : dict Dictionary describing region of interest, e.g.: * {'bbox': [xmin, ymin, xmax, ymax]} * {'geom': 'path/to/polygon_geometry'} * {'mesh': 'path/to/2dmesh_file'} res: float Resolution used to generate 2D mesh [unit of the CRS], required if region is not based on 'mesh'. crs : EPSG code, int, optional Optional EPSG code of the model. If None using the one from region, and else 4326. Returns ------- mesh2d : xu.UgridDataset Generated mesh2d. """ self.logger.info(f"Preparing 2D mesh.") if "mesh" not in region: if not isinstance(res, (int, float)): raise ValueError("res argument required") kind, region = workflows.parse_region(region, logger=self.logger) if kind == "bbox": bbox = region["bbox"] geom = gpd.GeoDataFrame(geometry=[box(*bbox)], crs=4326) elif kind == "geom": geom = region["geom"] if geom.crs is None: raise ValueError('Model region "geom" has no CRS') else: raise ValueError( f"Region for mesh must of kind [bbox, geom, mesh], kind {kind} not understood." ) if crs is not None: geom = geom.to_crs(crs) # Generate grid based on res for region bbox xmin, ymin, xmax, ymax = geom.total_bounds # note we flood the number of faces within bounds ncol = int((xmax - xmin) // res) nrow = int((ymax - ymin) // res) dx, dy = res, -res faces = [] for i in range(nrow): top = ymax + i * dy bottom = ymax + (i + 1) * dy for j in range(ncol): left = xmin + j * dx right = xmin + (j + 1) * dx faces.append(box(left, bottom, right, top)) grid = gpd.GeoDataFrame(geometry=faces, crs=geom.crs) # If needed clip to geom if kind != "bbox": # TODO: grid.intersects(geom) does not seem to work ? grid = grid.loc[ gpd.sjoin( grid, geom, how="left", op="intersects" ).index_right.notna() ].reset_index() # Create mesh from grid grid.index.name = "mesh2d_nFaces" mesh2d = xu.UgridDataset.from_geodataframe(grid) mesh2d.ugrid.grid.set_crs(grid.crs) else: mesh2d_fn = region["mesh"] if isinstance(mesh2d_fn, (str, Path)) and isfile(mesh2d_fn): self.logger.info(f"An existing 2D grid is used to prepare 2D mesh.") ds = xr.open_dataset(mesh2d_fn, mask_and_scale=False) elif isinstance(mesh2d_fn, xr.Dataset): ds = mesh2d_fn else: raise ValueError( f"Region 'mesh' file {mesh2d_fn} not found, please check" ) topologies = [ k for k in ds.data_vars if ds[k].attrs.get("cf_role") == "mesh_topology" ] for topology in topologies: topodim = ds[topology].attrs["topology_dimension"] if topodim != 2: # chek if 2d mesh file else throw error raise NotImplementedError( f"{mesh2d_fn} cannot be opened. Please check if the existing grid is " f"an 2D mesh and not 1D2D mesh. This option is not yet available for 1D2D meshes." ) # Continues with a 2D grid mesh2d = xu.UgridDataset(ds) # Check crs and reproject to model crs if crs is None: crs = 4326 if ds.rio.crs is not None: # parse crs mesh2d.ugrid.grid.set_crs(ds.raster.crs) else: # Assume model crs self.logger.warning( f"Mesh data from {mesh2d_fn} doesn't have a CRS. Assuming crs option {crs}" ) mesh2d.ugrid.grid.set_crs(crs) mesh2d = mesh2d.drop_vars(GEO_MAP_COORD, errors="ignore") # Reproject to user crs option if needed if mesh2d.ugrid.grid.crs != crs and crs is not None: self.logger.info(f"Reprojecting mesh to crs {crs}") mesh2d.ugrid.grid.to_crs(self.crs) self.set_mesh(mesh2d) # This setup method returns region so that it can be wrapped for models which require # more information return mesh2d ## I/O def read( self, components: List = [ "config", "mesh", "geoms", "forcing", "states", "results", ], ) -> None: """Read the complete model schematization and configuration from model files. Parameters ---------- components : List, optional List of model components to read, each should have an associated read_<component> method. By default ['config', 'maps', 'mesh', 'geoms', 'forcing', 'states', 'results'] """ super().read(components=components) def write( self, components: List = ["config", "mesh", "geoms", "forcing", "states"], ) -> None: """Write the complete model schematization and configuration to model files. Parameters ---------- components : List, optional List of model components to write, each should have an associated write_<component> method. By default ['config', 'maps', 'mesh', 'geoms', 'forcing', 'states'] """ super().write(components=components) # MeshModel specific methods # MeshModel properties @property def bounds(self) -> Tuple: """Returns model mesh bounds.""" if self._mesh is not None: return self._mesh.ugrid.grid.bounds @property def region(self) -> gpd.GeoDataFrame: """Returns geometry of region of the model area of interest.""" region = gpd.GeoDataFrame() if "region" in self.geoms: region = self.geoms["region"] elif self.mesh is not None: crs = self.mesh.ugrid.grid.crs if crs is None and hasattr(crs, "to_epsg"): crs = crs.to_epsg() # not all CRS have an EPSG code region = gpd.GeoDataFrame(geometry=[box(*self.bounds)], crs=crs) return region @property def mesh_gdf(self) -> gpd.GeoDataFrame: """Returns geometry of mesh as a gpd.GeoDataFrame""" if self._mesh is not None: name = [n for n in self.mesh.data_vars][0] # works better on a DataArray return self._mesh[name].ugrid.to_geodataframe()