Source code for hydromt_sfincs.workflows.discharge
"""Workflows for discharge boundary conditions."""
import logging
import geopandas as gpd
import numpy as np
import xarray as xr
# Module-level logger; it is also the default value of the ``logger``
# argument of ``snap_discharge``.
logger = logging.getLogger(__name__)

# Public names exported by ``from <module> import *``.
__all__ = ["snap_discharge"]
[docs]
def snap_discharge(
    ds: xr.Dataset,
    gdf: gpd.GeoDataFrame,
    wdw: int = 1,
    rel_error: float = 0.05,
    abs_error: float = 50,
    uparea_name: str = "uparea",
    discharge_name: str = "discharge",
    logger=logger,
) -> xr.Dataset:
    """Snap point locations to the best matching grid cell and sample ``ds`` there.

    A window of cells around every point in ``gdf`` is sampled from ``ds``.
    One cell per point is then selected from that window:

    * If ``uparea_name`` is a variable in ``ds`` **and** a column in ``gdf``,
      the cell with the smallest absolute difference in upstream area is
      selected (cells with an upstream area <= 0 are ignored). The point is
      kept only if that cell has valid discharge and its upstream area meets
      the relative **or** the absolute error criterion.
    * Otherwise, the cell with valid discharge nearest to the window center
      (Euclidean distance measured in cells) is selected. In this case a
      ``dist`` variable holding these distances is added to the sampled
      dataset.

    A cell has valid discharge if ``discharge_name`` has at least one
    non-null value along the ``time`` dimension. Points for which no valid
    cell is found are dropped from the output and a warning is logged.

    The upstream area variable ``uparea_name`` in ``ds`` and ``gdf`` as well
    as ``abs_error`` should have the same unit (typically km2).

    Parameters
    ----------
    ds: xarray.Dataset
        Dataset with discharge and optional uparea variable.
    gdf: geopandas.GeoDataFrame[Points]
        Dataframe with Point geometries of locations of interest and an
        optional ``uparea_name`` column.
    wdw: int, optional
        Window size in number of cells around each location within which to
        snap; used by both selection methods. By default 1.
        NOTE(review): the code relies on a "wdw" dimension in the sampled
        data, so presumably ``wdw`` must be >= 1 -- confirm against
        ``ds.raster.sample``.
    rel_error, abs_error: float, optional
        Maximum relative error (default 0.05) and absolute error (default 50 km2)
        between the upstream area of the location and the upstream area of the
        best fit grid cell; only used if both ``ds`` and ``gdf`` provide
        ``uparea_name``. Meeting either one of the two criteria is sufficient.
    uparea_name: str, optional
        Name of the upstream area variable in ``ds`` and column in ``gdf``.
        By default "uparea".
    discharge_name: str, optional
        Name of the discharge variable in ``ds``. By default "discharge".
    logger: logging.Logger, optional
        Logger used for debug and warning messages. By default the module
        logger.

    Returns
    -------
    ds: xarray.Dataset
        Sampled dataset at the selected cell of every successfully snapped
        point, i.e. without the "wdw" dimension and with only the valid
        points along the "index" dimension.
    """
    # sample all cells in a window around each point -> dims ("index", "wdw")
    ds_wdw = ds.raster.sample(gdf, wdw=wdw)
    # a window cell is usable only if it has any non-null discharge in time
    valid = ds_wdw[discharge_name].notnull().any("time")

    if uparea_name in ds and uparea_name in gdf.columns:
        logger.debug(
            f"Snapping {discharge_name} points to best matching uparea cell within wdw (size={wdw})."
        )
        # find cells in window with smallest difference in uparea;
        # non-positive uparea cells are masked and can never be the best match
        upa0 = xr.DataArray(gdf[uparea_name], dims=("index",))
        upa_dff = np.abs(
            ds_wdw[uparea_name].where(ds_wdw[uparea_name] > 0).load() - upa0
        )
        # NaN differences are replaced by inf so argmin ignores masked cells
        i_wdw = upa_dff.fillna(np.inf).argmin("wdw")
        # find valid cells based on error criteria (either criterion suffices)
        upa_check = np.logical_or((upa_dff / upa0) <= rel_error, upa_dff <= abs_error)
        valid = np.logical_and(valid, upa_check)
    else:
        logger.debug(
            f"No {uparea_name} variable found in ds or gdf; "
            f"sampling {discharge_name} points from nearest grid cell."
        )
        # calculate distance to center cell (measured in cells)
        ar_wdw = np.abs(np.arange(-wdw, wdw + 1))
        dist = np.hypot(*np.meshgrid(ar_wdw, ar_wdw)).ravel()
        # same distance pattern for every point; stored on the sampled dataset
        ds_wdw["dist"] = xr.Variable(
            ("index", "wdw"), np.tile(dist, (ds_wdw["index"].size, 1))
        )
        # find nearest valid cell in window (invalid cells get infinite distance)
        i_wdw = ds_wdw["dist"].where(valid, np.inf).argmin("wdw").load()

    # filter valid cells: keep only points whose selected cell passed all checks
    idx_valid = np.where(valid.isel(wdw=i_wdw).values)[0]
    if idx_valid.size < gdf.index.size:
        logger.warning(
            f"{idx_valid.size}/{gdf.index.size} {discharge_name} points successfully snapped."
        )
    i_wdw = i_wdw.isel(index=idx_valid)

    # return the sampled data at the selected cell of each valid point
    return ds_wdw.isel(wdw=i_wdw.load(), index=idx_valid)