from typing import Union, Optional, List, Tuple
import logging
import os
from os.path import join, isdir, dirname, isfile
from pathlib import Path
import numpy as np
import xarray as xr
import xugrid as xu
import geopandas as gpd
from shapely.geometry import box, Polygon
from ..raster import GEO_MAP_COORD
from .model_api import Model
from .. import workflows
# Names exported by ``from ... import *``: only the concrete model class.
__all__: List[str] = ["MeshModel"]
# Module logger; it is also the default ``logger`` argument of ``MeshModel``.
logger: logging.Logger = logging.getLogger(name=__name__)
class MeshMixin(object):
    """Mixin adding an unstructured ``mesh`` component to a HydroMT model.

    Mesh data are kept in ``self._mesh`` (``None`` until created, read or set) and are
    exposed through the :py:attr:`mesh` property as a xugrid ``UgridDataset``.

    The methods rely on attributes provided by the class this is mixed into:
    ``logger``, ``data_catalog``, ``region``, ``mesh_gdf``, ``root``, ``_read``,
    ``_read_nc``, ``_assert_read_mode`` and ``_assert_write_mode``.
    """

    # placeholders
    # We cannot initialize an empty xu.UgridDataArray
    _API = {
        "mesh": Union[xu.UgridDataArray, xu.UgridDataset],
    }

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # No mesh until one is created, read or set.
        self._mesh = None

    ## general setup methods
    def setup_mesh_from_raster(
        self,
        raster_fn: str,
        variables: Optional[list] = None,
        fill_method: Optional[str] = None,
        resampling_method: Optional[str] = "mean",
        all_touched: Optional[bool] = True,
    ) -> None:
        """
        This component adds data variable(s) from ``raster_fn`` to mesh object.

        Raster data is interpolated to the mesh grid using the ``resampling_method``.
        If raster is a dataset, all variables will be added unless ``variables`` list
        is specified. A mesh must already exist before calling this method.

        Adds model layers:

        * **raster.name** mesh: data from raster_fn

        Parameters
        ----------
        raster_fn: str
            Source name of raster data in data_catalog.
        variables: list, optional
            List of variables to add to mesh from raster_fn. By default all.
        fill_method : str, optional
            If specified, fills no data values using fill_nodata method. Available
            methods are {'linear', 'nearest', 'cubic', 'rio_idw'}.
        resampling_method: str, optional
            Method to sample from raster data to mesh. By default mean. Options include
            {'count', 'min', 'max', 'sum', 'mean', 'std', 'median', 'q##'}.
        all_touched : bool, optional
            If True, all pixels touched by geometries will be used to define the sample.
            If False, only pixels whose center is within the geometry or that are
            selected by Bresenham's line algorithm will be used. By default True.

        Raises
        ------
        ValueError
            If there is no mesh to sample the raster data onto.
        """
        self.logger.info(f"Preparing mesh data from raster source {raster_fn}")
        # Fail early with a clear message; accessing ``self.mesh`` also triggers the
        # lazy read in read mode so that ``mesh_gdf`` below has data to work with.
        if self.mesh is None:
            raise ValueError(
                "No mesh available: create or read a mesh before adding mesh data."
            )
        # Read raster data and select variables
        ds = self.data_catalog.get_rasterdataset(
            raster_fn, geom=self.region, buffer=2, variables=variables
        )
        if isinstance(ds, xr.DataArray):
            ds = ds.to_dataset()
        if fill_method is not None:
            ds = ds.raster.interpolate_na(method=fill_method)
        # Convert mesh grid as geodataframe for sampling
        # Reprojection happens to gdf inside of zonal_stats method
        ds_sample = ds.raster.zonal_stats(
            gdf=self.mesh_gdf, stats=resampling_method, all_touched=all_touched
        )
        # Outputs are named "<var>_<stat>": restore the original variable names
        rm_dict = {f"{var}_{resampling_method}": var for var in ds.data_vars}
        ds_sample = ds_sample.rename(rm_dict)
        # Convert to UgridDataset on the existing mesh topology
        uds_sample = xu.UgridDataset(ds_sample, grids=self.mesh.ugrid.grid)
        self.set_mesh(uds_sample)

    def setup_mesh_from_raster_reclass(
        self,
        raster_fn: str,
        reclass_table_fn: str,
        reclass_variables: list,
        variable: Optional[str] = None,
        fill_nodata: Optional[str] = None,
        resampling_method: Optional[Union[str, list]] = "mean",
        all_touched: Optional[bool] = True,
        **kwargs,
    ) -> None:
        """
        This component adds data variable(s) to mesh object by reclassifying the data
        in ``raster_fn`` based on ``reclass_table_fn``.

        The reclassified raster data are subsequently interpolated to the mesh using
        ``resampling_method``. A mesh must already exist before calling this method.

        Adds model layers:

        * **reclass_variables** mesh: reclassified raster data interpolated to the model mesh

        Parameters
        ----------
        raster_fn: str
            Source name of raster data in data_catalog. Should be a DataArray. Else use
            **kwargs to select variables/time_tuple in
            :py:meth:`hydromt.data_catalog.get_rasterdataset` method
        reclass_table_fn: str
            Source name of reclassification table of raster_fn in data_catalog.
        reclass_variables: list
            List of reclass_variables from reclass_table_fn table to add to mesh.
            Index column should match values in raster_fn.
        variable: str, optional
            Name of raster dataset variable to use. This is only required when reading
            datasets with multiple variables. By default None.
        fill_nodata : str, optional
            If specified, fills nodata values in `raster_fn` using fill_nodata method
            before reclassifying. Available methods are
            {'linear', 'nearest', 'cubic', 'rio_idw'}.
        resampling_method: str, list, optional
            Method to sample from raster data to mesh. Can be a list with one method per
            variable in ``reclass_variables`` or a single method for all.
            By default mean for all reclass_variables.
            Options include {'count', 'min', 'max', 'sum', 'mean', 'std', 'median', 'q##'}.
        all_touched : bool, optional
            If True, all pixels touched by geometries will be used to define the sample.
            If False, only pixels whose center is within the geometry or that are
            selected by Bresenham's line algorithm will be used. By default True.
        **kwargs
            Passed to ``self.data_catalog.get_rasterdataset``.

        Raises
        ------
        ValueError
            If there is no mesh, if ``resampling_method`` is a list whose length does
            not match ``reclass_variables``, or if ``raster_fn`` is not a single
            variable raster.
        """
        self.logger.info(
            f"Preparing mesh data by reclassifying the data in {raster_fn} based on {reclass_table_fn}."
        )
        if self.mesh is None:
            raise ValueError(
                "No mesh available: create or read a mesh before adding mesh data."
            )
        # Resolve one resampling method per reclass variable before reading any data;
        # a mismatching list would otherwise be silently truncated by zip below.
        if isinstance(resampling_method, str):
            methods = [resampling_method] * len(reclass_variables)
        else:
            methods = list(resampling_method)
            if len(methods) != len(reclass_variables):
                raise ValueError(
                    "resampling_method should be a single method or a list with one "
                    f"method per reclass variable ({len(reclass_variables)}), "
                    f"got {len(methods)}."
                )
        # Read raster data and mapping table
        da = self.data_catalog.get_rasterdataset(
            raster_fn, geom=self.region, buffer=2, variables=variable, **kwargs
        )
        if not isinstance(da, xr.DataArray):
            raise ValueError(
                f"raster_fn {raster_fn} should be a single variable raster. "
                "Please select one using the 'variable' argument"
            )
        df_vars = self.data_catalog.get_dataframe(
            reclass_table_fn, variables=reclass_variables
        )
        if fill_nodata is not None:
            da = da.raster.interpolate_na(method=fill_nodata)
        # Mapping function
        ds_vars = da.raster.reclassify(reclass_table=df_vars, method="exact")
        # Compute every distinct statistic once for all variables; reprojection of
        # the mesh geometries happens inside zonal_stats
        ds_sample = ds_vars.raster.zonal_stats(
            gdf=self.mesh_gdf,
            stats=np.unique(np.atleast_1d(resampling_method)),
            all_touched=all_touched,
        )
        # Keep only the requested "<var>_<method>" output per variable, renamed to <var>
        rm_dict = {f"{var}_{mtd}": var for var, mtd in zip(reclass_variables, methods)}
        ds_sample = ds_sample.rename(rm_dict)
        ds_sample = ds_sample[list(reclass_variables)]
        # Convert to UgridDataset on the existing mesh topology
        uds_sample = xu.UgridDataset(ds_sample, grids=self.mesh.ugrid.grid)
        self.set_mesh(uds_sample)

    @property
    def mesh(self) -> Union[xu.UgridDataArray, xu.UgridDataset]:
        """Model static mesh data as a xugrid UgridDataset, or None if not set.

        When the model is in read mode and no mesh is loaded yet, it is read from
        file on first access.
        """
        if self._mesh is None and self._read:
            self.read_mesh()
        return self._mesh

    def set_mesh(
        self,
        data: Union[xu.UgridDataArray, xu.UgridDataset],
        name: Optional[str] = None,
    ) -> None:
        """Add data to mesh.

        All layers of mesh have identical spatial coordinates in Ugrid conventions.

        Parameters
        ----------
        data: xugrid.UgridDataArray or xugrid.UgridDataset
            new layer to add to mesh
        name: str, optional
            Name of new object layer, this is used to overwrite the name of a
            UgridDataArray.

        Raises
        ------
        ValueError
            If ``data`` is not a UgridDataArray/UgridDataset, or is an unnamed
            UgridDataArray and no ``name`` is given.
        """
        if not isinstance(data, (xu.UgridDataArray, xu.UgridDataset)):
            raise ValueError(
                "New mesh data in set_mesh should be of type xu.UgridDataArray or xu.UgridDataset"
            )
        if isinstance(data, xu.UgridDataArray):
            if name is not None:
                data = data.rename(name)
            elif data.name is None:
                raise ValueError(
                    f"Cannot set mesh from {str(type(data).__name__)} without a name."
                )
            data = data.to_dataset()
        if self._mesh is None:  # NOTE: mesh is initialized with None
            self._mesh = data
        else:
            # Merge variable by variable; existing variables are overwritten.
            for dvar in data.data_vars:
                if dvar in self._mesh:
                    self.logger.warning(f"Replacing mesh parameter: {dvar}")
                self._mesh[dvar] = data[dvar]

    def read_mesh(self, fn: str = "mesh/mesh.nc", **kwargs) -> None:
        """Read model mesh data at <root>/<fn> and add to mesh property.

        Key-word arguments are passed to ``self._read_nc`` (presumably forwarded to
        :py:func:`xarray.open_dataset`; confirm in the base Model class).

        Parameters
        ----------
        fn : str, optional
            filename relative to model root, by default 'mesh/mesh.nc'
        """
        # NOTE(review): bare attribute access; this only guards anything if the base
        # class defines _assert_read_mode as a property that raises in write-only mode.
        self._assert_read_mode
        for ds in self._read_nc(fn, **kwargs).values():
            uds = xu.UgridDataset(ds)
            if ds.rio.crs is not None:  # parse crs
                uds.ugrid.grid.set_crs(ds.raster.crs)
            # The CRS now lives on the grid; drop the spatial reference coordinate
            # (no-op when it is absent).
            uds = uds.drop_vars(GEO_MAP_COORD, errors="ignore")
            self.set_mesh(uds)

    def write_mesh(self, fn: str = "mesh/mesh.nc", **kwargs) -> None:
        """Write model mesh data to netcdf file at <root>/<fn>.

        Key-word arguments are passed to :py:meth:`xarray.Dataset.to_netcdf`.

        Parameters
        ----------
        fn : str, optional
            filename relative to model root, by default 'mesh/mesh.nc'
        """
        if self._mesh is None:
            self.logger.debug("No mesh data found, skip writing.")
            return
        # NOTE(review): bare attribute access, see read_mesh.
        self._assert_write_mode
        # filename
        _fn = join(self.root, fn)
        out_dir = dirname(_fn)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        self.logger.debug(f"Writing file {fn}")
        # Convert to a plain xarray.Dataset for writing
        ds_out = self.mesh.ugrid.to_dataset()
        if self.mesh.ugrid.grid.crs is not None:
            # save crs to spatial_ref coordinate
            ds_out = ds_out.rio.write_crs(self.mesh.ugrid.grid.crs)
        ds_out.to_netcdf(_fn, **kwargs)
class MeshModel(MeshMixin, Model):
    """Model class Mesh Model for mesh models in HydroMT"""

    # Arguments accepted from the command line and the setup method that consumes
    # them (presumably used by the HydroMT CLI; confirm in the base Model).
    _CLI_ARGS = {"region": "setup_mesh", "res": "setup_mesh"}
    _NAME = "mesh_model"

    def __init__(
        self,
        root: Optional[str] = None,
        mode: str = "w",
        config_fn: Optional[str] = None,
        data_libs: Optional[List[str]] = None,
        logger=logger,
    ):
        """Initialize a MeshModel for distributed models with an unstructured grid."""
        super().__init__(
            root=root,
            mode=mode,
            config_fn=config_fn,
            data_libs=data_libs,
            logger=logger,
        )

    ## general setup methods
    def setup_mesh(
        self,
        region: dict,
        res: Optional[float] = None,
        crs: Optional[int] = None,
    ) -> xu.UgridDataset:
        """Create a 2D unstructured mesh or read an existing 2D mesh following UGRID conventions.

        A 2D unstructured mesh is created as a 2D rectangular grid from a geometry
        (geom) or bbox. If an existing 2D mesh is given, no new mesh is generated.
        Note that only existing meshes containing a 2D grid only can be read.
        FIXME: read existing 1D2D network file and extract 2D part.

        Adds/Updates model layers:

        * **mesh** mesh topology: add mesh topology to mesh object

        Parameters
        ----------
        region : dict
            Dictionary describing region of interest, e.g.:

            * {'bbox': [xmin, ymin, xmax, ymax]}
            * {'geom': 'path/to/polygon_geometry'}
            * {'mesh': 'path/to/2dmesh_file'}
        res: float
            Resolution used to generate 2D mesh [unit of the CRS], required if region
            is not based on 'mesh'. Must be positive.
        crs : EPSG code, int, optional
            Optional EPSG code of the model; the mesh is reprojected to it if needed.
            If None, the CRS of the region (or mesh file) is used, and 4326 if that
            has none (bbox coordinates are always interpreted as EPSG:4326).

        Returns
        -------
        mesh2d : xu.UgridDataset
            Generated mesh2d.

        Raises
        ------
        ValueError
            If ``res`` is missing or not positive, the region kind is not supported,
            the geometry has no CRS, the region is smaller than ``res``, or the mesh
            file cannot be found.
        NotImplementedError
            If the existing mesh file is not a pure 2D mesh.
        """
        self.logger.info("Preparing 2D mesh.")
        if "mesh" not in region:
            if not isinstance(res, (int, float)):
                raise ValueError("res argument required")
            if res <= 0:
                raise ValueError(f"res argument should be positive, got {res}")
            kind, region = workflows.parse_region(region, logger=self.logger)
            if kind == "bbox":
                bbox = region["bbox"]
                geom = gpd.GeoDataFrame(geometry=[box(*bbox)], crs=4326)
            elif kind == "geom":
                geom = region["geom"]
                if geom.crs is None:
                    raise ValueError('Model region "geom" has no CRS')
            else:
                raise ValueError(
                    f"Region for mesh must of kind [bbox, geom, mesh], kind {kind} not understood."
                )
            if crs is not None:
                geom = geom.to_crs(crs)

            # Generate grid based on res for region bbox
            xmin, ymin, xmax, ymax = geom.total_bounds
            # note: the number of faces within the bounds is floored, so the grid
            # does not extend past the region bounds
            ncol = int((xmax - xmin) // res)
            nrow = int((ymax - ymin) // res)
            if ncol == 0 or nrow == 0:
                raise ValueError(
                    f"Resolution {res} is too coarse for the region extent; "
                    "no mesh faces can be generated."
                )
            # Row-major faces starting at the top-left corner, rows running downwards
            faces = [
                box(
                    xmin + j * res,
                    ymax - (i + 1) * res,
                    xmin + (j + 1) * res,
                    ymax - i * res,
                )
                for i in range(nrow)
                for j in range(ncol)
            ]
            grid = gpd.GeoDataFrame(geometry=faces, crs=geom.crs)
            # If needed clip to geom
            if kind != "bbox":
                # GeoSeries.intersects(GeoDataFrame) compares element-wise on the
                # index, hence the spatial join. An inner join + isin keeps every face
                # exactly once, even if it intersects several polygons of geom.
                joined = gpd.sjoin(grid, geom, how="inner", predicate="intersects")
                grid = grid.loc[grid.index.isin(joined.index)].reset_index()
            # Create mesh from grid
            grid.index.name = "mesh2d_nFaces"
            mesh2d = xu.UgridDataset.from_geodataframe(grid)
            mesh2d.ugrid.grid.set_crs(grid.crs)
        else:
            mesh2d_fn = region["mesh"]
            if isinstance(mesh2d_fn, (str, Path)) and isfile(mesh2d_fn):
                self.logger.info("An existing 2D grid is used to prepare 2D mesh.")
                ds = xr.open_dataset(mesh2d_fn, mask_and_scale=False)
            elif isinstance(mesh2d_fn, xr.Dataset):
                ds = mesh2d_fn
            else:
                raise ValueError(
                    f"Region 'mesh' file {mesh2d_fn} not found, please check"
                )
            topologies = [
                k for k in ds.data_vars if ds[k].attrs.get("cf_role") == "mesh_topology"
            ]
            for topology in topologies:
                topodim = ds[topology].attrs["topology_dimension"]
                if topodim != 2:  # check if 2d mesh file else throw error
                    raise NotImplementedError(
                        f"{mesh2d_fn} cannot be opened. Please check if the existing grid is "
                        f"an 2D mesh and not 1D2D mesh. This option is not yet available for 1D2D meshes."
                    )
            # Continues with a 2D grid
            mesh2d = xu.UgridDataset(ds)
            if ds.rio.crs is not None:  # parse crs
                mesh2d.ugrid.grid.set_crs(ds.raster.crs)
            else:
                # Assume the user crs option, else 4326. Only default here so that a
                # mesh with its own CRS is not reprojected when crs is None.
                assumed_crs = crs if crs is not None else 4326
                self.logger.warning(
                    f"Mesh data from {mesh2d_fn} doesn't have a CRS. Assuming crs option {assumed_crs}"
                )
                mesh2d.ugrid.grid.set_crs(assumed_crs)
            mesh2d = mesh2d.drop_vars(GEO_MAP_COORD, errors="ignore")

        # Reproject to user crs option if needed (the generated grid is already in crs)
        if crs is not None and mesh2d.ugrid.grid.crs != crs:
            self.logger.info(f"Reprojecting mesh to crs {crs}")
            mesh2d = mesh2d.ugrid.to_crs(crs)

        self.set_mesh(mesh2d)
        # This setup method returns the mesh so that it can be wrapped for models
        # which require more information
        return mesh2d

    ## I/O
    def read(
        self,
        components: Optional[List] = None,
    ) -> None:
        """Read the complete model schematization and configuration from model files.

        Parameters
        ----------
        components : List, optional
            List of model components to read, each should have an associated
            read_<component> method. By default
            ['config', 'mesh', 'geoms', 'forcing', 'states', 'results']
        """
        # Fresh list per call instead of a shared mutable default argument.
        if components is None:
            components = ["config", "mesh", "geoms", "forcing", "states", "results"]
        super().read(components=components)

    def write(
        self,
        components: Optional[List] = None,
    ) -> None:
        """Write the complete model schematization and configuration to model files.

        Parameters
        ----------
        components : List, optional
            List of model components to write, each should have an associated
            write_<component> method. By default
            ['config', 'mesh', 'geoms', 'forcing', 'states']
        """
        # Fresh list per call instead of a shared mutable default argument.
        if components is None:
            components = ["config", "mesh", "geoms", "forcing", "states"]
        super().write(components=components)

    # MeshModel specific methods

    # MeshModel properties
    @property
    def bounds(self) -> Optional[Tuple]:
        """Returns model mesh bounds, or None if there is no mesh."""
        if self._mesh is None:
            return None
        return self._mesh.ugrid.grid.bounds

    @property
    def region(self) -> gpd.GeoDataFrame:
        """Returns geometry of region of the model area of interest.

        Uses the "region" geometry if present, else the bounding box of the mesh,
        else an empty GeoDataFrame.
        """
        region = gpd.GeoDataFrame()
        if "region" in self.geoms:
            region = self.geoms["region"]
        elif self.mesh is not None:
            crs = self.mesh.ugrid.grid.crs
            if crs is not None and hasattr(crs, "to_epsg"):
                # Prefer an EPSG code, but not all CRS have one: keep the CRS then.
                epsg = crs.to_epsg()
                if epsg is not None:
                    crs = epsg
            region = gpd.GeoDataFrame(geometry=[box(*self.bounds)], crs=crs)
        return region

    @property
    def mesh_gdf(self) -> Optional[gpd.GeoDataFrame]:
        """Returns geometry of mesh as a gpd.GeoDataFrame, or None if there is no mesh."""
        if self._mesh is None:
            return None
        data_vars = list(self._mesh.data_vars)
        if data_vars:
            # works better on a DataArray than on the whole dataset
            return self._mesh[data_vars[0]].ugrid.to_geodataframe()
        # A mesh without data variables (presumably e.g. right after setup_mesh with a
        # bbox region, whose GeoDataFrame only holds geometry) has no DataArray to
        # convert: build the face polygons from the topology itself.
        grid = self._mesh.ugrid.grid
        gdf = gpd.GeoDataFrame(
            geometry=grid.to_shapely(grid.face_dimension), crs=grid.crs
        )
        gdf.index.name = grid.face_dimension
        return gdf