# Source code for hydromt_sfincs.workflows.discharge
"""Workflows for discharge boundary conditions."""
import logging
import geopandas as gpd
import numpy as np
import xarray as xr
# Module-level logger; also serves as the default ``logger`` of snap_discharge.
logger = logging.getLogger(__name__)

# Public API of this module.
__all__ = ["snap_discharge"]
def snap_discharge(
    ds: xr.Dataset,
    gdf: gpd.GeoDataFrame,
    wdw: int = 1,
    rel_error: float = 0.05,
    abs_error: float = 50,
    uparea_name: str = "uparea",
    discharge_name: str = "discharge",
    logger: logging.Logger = logger,
) -> xr.Dataset:
    """Snap point locations to the best matching grid cell within a window.

    ``ds`` is sampled in a square window of ``(2 * wdw + 1) ** 2`` cells around
    each point in ``gdf``, and one cell is selected per point:

    * If ``uparea_name`` is present in both ``ds`` and ``gdf``, the valid cell
      with the smallest absolute difference in upstream area is selected. A cell
      is valid if it has at least one non-null discharge value and its upstream
      area is within ``rel_error`` (relative) or ``abs_error`` (absolute) of the
      point's upstream area. Non-positive upstream areas in ``ds`` are treated
      as missing.
    * Otherwise, the nearest cell (distance measured in cells) with at least
      one non-null discharge value is selected.

    Points for which no valid cell exists in the window are dropped and a
    warning is logged. The upstream area variable in ``ds`` and ``gdf`` as well
    as ``abs_error`` should have the same unit (typically km2).

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``discharge_name`` variable that has a "time" dimension
        and an optional ``uparea_name`` variable. Must provide the
        ``ds.raster.sample`` accessor.
    gdf : geopandas.GeoDataFrame
        GeoDataFrame with Point geometries of the locations of interest and an
        optional ``uparea_name`` column.
    wdw : int, optional
        Window size in number of cells around each point to search for the
        best matching cell, by default 1.
    rel_error, abs_error : float, optional
        Maximum relative error (default 0.05) and absolute error (default 50)
        between the point upstream area and the upstream area of the selected
        grid cell. Only used if ``uparea_name`` is in both ``ds`` and ``gdf``.
    uparea_name : str, optional
        Name of the upstream area variable in ``ds`` and column in ``gdf``,
        by default "uparea".
    discharge_name : str, optional
        Name of the discharge variable in ``ds``, by default "discharge".
    logger : logging.Logger, optional
        Logger for debug and warning messages, by default the module logger.

    Returns
    -------
    xarray.Dataset
        Dataset sampled at the selected cell of each successfully snapped
        point, restricted to those points along the "index" dimension. In the
        nearest-cell mode it also contains a "dist" variable (distance to the
        original location in cells).
    """
    ds_wdw = ds.raster.sample(gdf, wdw=wdw)
    # a window cell is only usable if it holds at least one discharge value
    valid = ds_wdw[discharge_name].notnull().any("time")
    if uparea_name in ds and uparea_name in gdf.columns:
        logger.debug(
            f"Snapping {discharge_name} points to best matching uparea cell within wdw (size={wdw})."
        )
        upa0 = xr.DataArray(gdf[uparea_name], dims=("index"))
        # non-positive upstream areas are masked to NaN (treated as missing)
        upa_dff = np.abs(
            ds_wdw[uparea_name].where(ds_wdw[uparea_name] > 0).load() - upa0
        )
        upa_check = np.logical_or((upa_dff / upa0) <= rel_error, upa_dff <= abs_error)
        valid = np.logical_and(valid, upa_check)
        # best matching uparea cell among the valid cells of the window.
        # Invalid cells (including NaN differences) are masked with inf, so the
        # argmin never sees an all-NaN window (which raises in xarray/numpy)
        # and a valid cell is always preferred over an invalid one. Points
        # without any valid cell are dropped below.
        i_wdw = upa_dff.where(valid, np.inf).argmin("wdw").load()
    else:
        logger.debug(
            f"No {uparea_name} variable found in ds or gdf; "
            f"sampling {discharge_name} points from nearest grid cell."
        )
        # distance of each window cell to the window center, measured in cells;
        # raveled to match the flattened "wdw" dimension of the sample
        ar_wdw = np.abs(np.arange(-wdw, wdw + 1))
        dist = np.hypot(*np.meshgrid(ar_wdw, ar_wdw)).ravel()
        ds_wdw["dist"] = xr.Variable(
            ("index", "wdw"), np.tile(dist, (ds_wdw["index"].size, 1))
        )
        # nearest cell with valid discharge
        i_wdw = ds_wdw["dist"].where(valid, np.inf).argmin("wdw").load()
    # keep only the points whose selected cell is valid
    idx_valid = np.where(valid.isel(wdw=i_wdw).values)[0]
    if idx_valid.size < gdf.index.size:
        logger.warning(
            f"{idx_valid.size}/{gdf.index.size} {discharge_name} points successfully snapped."
        )
    i_wdw = i_wdw.isel(index=idx_valid)
    return ds_wdw.isel(wdw=i_wdw, index=idx_valid)