"""Implementation for the geodataset DataAdapter."""
from logging import getLogger
from typing import Dict, List, Optional, Union, cast
import numpy as np
import pyproj
import xarray as xr
from hydromt._typing import (
Geom,
NoDataStrategy,
Predicate,
SourceMetadata,
TimeRange,
Variables,
exec_nodata_strat,
)
from hydromt._typing.type_def import Number
from hydromt._utils import (
_has_no_data,
_rename_vars,
_set_metadata,
_set_vector_nodata,
_shift_dataset_time,
_single_var_as_array,
_slice_temporal_dimension,
)
from hydromt.data_catalog.adapters.data_adapter_base import DataAdapterBase
from hydromt.gis.raster import GEO_MAP_COORD
# Module-level logger used by the adapter's debug and warning messages.
logger = getLogger(__name__)

# Public API of this module.
__all__ = ["GeoDatasetAdapter"]
class GeoDatasetAdapter(DataAdapterBase):
"""DatasetAdapter for GeoDatasets."""
@staticmethod
def _validate_spatial_coords(
    ds: Optional[xr.Dataset],
) -> Optional[xr.Dataset]:
    """Ensure the dataset exposes its vector geometry as coordinates.

    Promotes the ``GEO_MAP_COORD`` data variable (if present) to a coordinate,
    detects the spatial dimensions via the ``vector`` accessor, adds an integer
    index coordinate for the index dimension when it is missing, and marks the
    x/y/index variables as coordinates.

    Parameters
    ----------
    ds : Optional[xr.Dataset]
        Dataset to validate; ``None`` is passed through unchanged.

    Returns
    -------
    Optional[xr.Dataset]
        The dataset with its spatial coordinates set, or None if ``ds`` is None.

    Raises
    ------
    ValueError
        If no spatial geometry dimension can be found.
    """
    if ds is None:
        return None
    if GEO_MAP_COORD in ds.data_vars:
        ds = ds.set_coords(GEO_MAP_COORD)
    try:
        ds.vector.set_spatial_dims()
        idim = ds.vector.index_dim
        if idim not in ds:  # set coordinates for index dimension if missing
            # Dataset.sizes maps dimension name -> length; indexing Dataset.dims
            # for lengths is deprecated in recent xarray versions.
            # Assigned in place so the accessor state set above is kept.
            ds[idim] = xr.IndexVariable(idim, np.arange(ds.sizes[idim]))
        coords = [ds.vector.x_name, ds.vector.y_name, idim]
        coords = [item for item in coords if item is not None]
        ds = ds.set_coords(coords)
    except ValueError as err:
        # keep the underlying cause visible in the traceback
        raise ValueError("GeoDataset: No spatial geometry dimension found") from err
    return ds
@staticmethod
def _set_crs(
    ds: Optional[xr.Dataset],
    crs: Union[str, int, None] = None,
) -> Optional[xr.Dataset]:
    """Make sure the dataset carries a CRS.

    The CRS stored in the data takes precedence. ``crs`` (e.g. from the data
    catalog) is only applied when the data has none; if both are defined and
    differ, a warning is logged and the data's own CRS is kept.

    Parameters
    ----------
    ds : Optional[xr.Dataset]
        Dataset to check; ``None`` is passed through unchanged.
    crs : Union[str, int, None], optional
        Fallback CRS in any form accepted by ``pyproj.CRS.from_user_input``.

    Returns
    -------
    Optional[xr.Dataset]
        The dataset (CRS set on it when it had none), or None if ``ds`` is None.

    Raises
    ------
    ValueError
        If neither the data nor ``crs`` define a CRS.
    """
    if ds is None:
        return None
    data_crs = ds.vector.crs
    if data_crs is None:
        if crs is None:
            raise ValueError("GeoDataset: CRS not defined in data catalog or data.")
        ds.vector.set_crs(crs)
        return ds
    if crs is not None and data_crs != pyproj.CRS.from_user_input(crs):
        logger.warning(
            "GeoDataset: CRS from data catalog does not match CRS of"
            " data. The original CRS will be used. Please check your data catalog."
        )
    return ds
@staticmethod
def _slice_data(
    ds: Optional[Union[xr.Dataset, xr.DataArray]],
    variables: Optional[Variables] = None,
    mask: Optional[Geom] = None,
    predicate: Predicate = "intersects",
    time_range: Optional[TimeRange] = None,
) -> Optional[xr.Dataset]:
    """Filter the GeoDataset.

    Parameters
    ----------
    ds : Optional[Union[xr.Dataset, xr.DataArray]]
        input dataset. A DataArray is converted to a single-variable Dataset
        (named ``"data"`` when the array has no name); the caller's object is
        not modified. ``None`` is passed through as None.
    variables : Optional[Variables], optional
        variable filter (a single name or a list of names), by default None.
        Not applied to DataArray input.
    mask : Optional[Geom], optional
        mask to filter by geometry, by default None
    predicate : str, optional
        predicate to use for the mask filter, by default "intersects"
    time_range : Optional[TimeRange], optional
        filter start and end times, by default None

    Returns
    -------
    Optional[xr.Dataset]
        the filtered GeoDataSet, or None if ``ds`` is None or no data is left
        after slicing.

    Raises
    ------
    ValueError
        if not all variables are found in the data
    """
    if ds is None:
        return None
    if isinstance(ds, xr.DataArray):
        # An unnamed array gets a dummy name (required to create a dataset);
        # it is renamed to the variable in _single_var_as_array. Passing the
        # name to to_dataset avoids mutating the caller's DataArray.
        ds = ds.to_dataset(name="data") if ds.name is None else ds.to_dataset()
    elif variables is not None:
        variables = cast(List, np.atleast_1d(variables).tolist())
        # With one requested variable and at most one data variable, no
        # selection or name check is done here.
        if len(variables) > 1 or len(ds.data_vars) > 1:
            missing = [var for var in variables if var not in ds.data_vars]
            if missing:
                raise ValueError(f"GeoDataset: variables not found {missing}")
            ds = ds[variables]

    maybe_ds: Optional[xr.Dataset] = ds
    if time_range is not None:
        maybe_ds = _slice_temporal_dimension(maybe_ds, time_range)
    if mask is not None:
        maybe_ds = GeoDatasetAdapter._slice_spatial_dimension(
            maybe_ds,
            mask=mask,
            predicate=predicate,
        )
    if _has_no_data(maybe_ds):
        return None
    return maybe_ds
@staticmethod
def _slice_spatial_dimension(
    ds: Optional[xr.Dataset],
    mask: Geom,
    predicate: Predicate,
) -> Optional[xr.Dataset]:
    """Clip ``ds`` to the geometry in ``mask``.

    Parameters
    ----------
    ds : Optional[xr.Dataset]
        dataset to clip; ``None`` is passed through as None.
    mask : Geom
        geometry to clip with; its bounds and EPSG code are logged at debug level.
    predicate : Predicate
        spatial predicate forwarded to ``ds.vector.clip_geom``.

    Returns
    -------
    Optional[xr.Dataset]
        the clipped dataset, or None if ``ds`` is None or clipping left no data.
    """
    if ds is None:
        return None
    bounds_txt = ", ".join(format(value, ".3f") for value in mask.total_bounds)
    epsg_code = mask.crs.to_epsg()
    logger.debug(f"Clip {predicate} [{bounds_txt}] (EPSG:{epsg_code})")
    clipped = ds.vector.clip_geom(mask, predicate=predicate)
    return None if _has_no_data(clipped) else clipped
@staticmethod
def _apply_unit_conversion(
    ds: Optional[xr.Dataset],
    unit_mult: Dict[str, Number],
    unit_add: Dict[str, Number],
) -> Optional[xr.Dataset]:
    """Convert data variables with ``value * mult + add``.

    Only variables that appear in ``unit_mult`` and/or ``unit_add`` and also in
    ``ds.data_vars`` are converted; a missing multiplier defaults to 1 and a
    missing offset to 0. Cells equal to the nodata value (or NaN when no nodata
    value is set) keep that value, and each variable's original attributes are
    restored after conversion. ``ds`` is modified in place and returned.

    Parameters
    ----------
    ds : Optional[xr.Dataset]
        dataset to convert; ``None`` is passed through as None.
    unit_mult : Dict[str, Number]
        multiplication factor per variable name.
    unit_add : Dict[str, Number]
        offset per variable name, added after multiplication.

    Returns
    -------
    Optional[xr.Dataset]
        the converted dataset, or None if ``ds`` is None.
    """
    if ds is None:
        return None
    # dict.fromkeys de-duplicates while keeping a deterministic order, so a
    # variable listed in both mappings is converted and counted only once.
    unit_names = [
        name
        for name in dict.fromkeys([*unit_mult, *unit_add])
        if name in ds.data_vars
    ]
    if unit_names:
        logger.debug(f"Convert units for {len(unit_names)} variables.")
    for name in unit_names:
        m = unit_mult.get(name, 1)
        a = unit_add.get(name, 0)
        da = ds[name]
        attrs = da.attrs.copy()  # xr.where below does not keep attributes
        nodata = da.vector.nodata
        nodata_isnan = nodata is None or np.isnan(nodata)
        if nodata_isnan:
            # nodata value is explicitly set to NaN in case none is provided
            nodata = np.nan
        data_bool = ~np.isnan(da) if nodata_isnan else da != nodata
        ds[name] = xr.where(data_bool, da * m + a, nodata)
        ds[name].attrs.update(attrs)  # set original attributes
    return ds