Source code for xugrid.core.wrap

"""
Wrap in advance instead of overloading __getattr__.

This allows for tab completion and documentation.
"""
from __future__ import annotations

import types
from collections import ChainMap
from functools import wraps
from itertools import chain
from typing import List, Sequence, Union

import xarray as xr
from pandas import RangeIndex

import xugrid
from xugrid.conversion import grid_from_dataset, grid_from_geodataframe
from xugrid.core.utils import unique_grids
from xugrid.ugrid.ugrid2d import Ugrid2d
from xugrid.ugrid.ugridbase import AbstractUgrid, UgridType, align

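# The ``import xugrid`` above imports the entire package rather than the
# accessor classes themselves: importing UgridDatasetAccessor and
# UgridDataArrayAccessor directly would be a circular import. Note: they can
# therefore only be referenced inside functions, as
# ``xugrid.core.<module>.<Accessor>`` (that code runs at call time, not at
# import time).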


def maybe_xugrid(obj, topology, old_indexes=None):
    """
    Re-wrap an xarray result as a xugrid object if it uses UGRID dimensions.

    Parameters
    ----------
    obj: Any
        Result of an xarray operation. Anything that is not an
        ``xr.DataArray`` or ``xr.Dataset`` is returned untouched.
    topology: grid, or list, set, tuple of grids
        The grid(s) whose dimensions are looked for in ``obj.dims``.
    old_indexes: optional
        Forwarded to ``align`` together with ``obj`` and the matching grids.

    Returns
    -------
    wrapped: UgridDataArray, UgridDataset, or ``obj`` itself
        ``obj`` is returned as-is when none of its dimensions belongs to a
        grid in ``topology``.
    """
    if not isinstance(obj, (xr.DataArray, xr.Dataset)):
        return obj

    # Topology can either be a sequence of grids or a grid: treat a single
    # grid as a collection of one.
    if isinstance(topology, (list, set, tuple)):
        candidates = topology
    else:
        candidates = (topology,)

    grid_of_dim = {}
    for grid in candidates:
        for dim in grid.dimensions:
            grid_of_dim[dim] = grid

    matching = [grid_of_dim[dim] for dim in obj.dims if dim in grid_of_dim]
    item_grids = unique_grids(matching)
    if len(item_grids) == 0:
        return obj

    result, aligned = align(obj, item_grids, old_indexes)

    if isinstance(result, xr.DataArray):
        # A DataArray is wrapped together with exactly one grid.
        if len(aligned) > 1:
            raise RuntimeError("This shouldn't happen. Please open an issue.")
        return UgridDataArray(result, aligned[0])

    if isinstance(result, xr.Dataset):
        return UgridDataset(result, aligned)

    return None


def maybe_xarray(arg):
    """
    Unwrap a xugrid object to the xarray object it holds.

    Any argument that is not a UgridDataArray or UgridDataset is returned
    unchanged.
    """
    is_wrapped = isinstance(arg, (UgridDataArray, UgridDataset))
    return arg.obj if is_wrapped else arg


def wraps_xarray(method):
    """
    Wrap an xarray callable so that it accepts and returns xugrid objects.

    Every positional and keyword argument is unwrapped with ``maybe_xarray``
    before ``method`` is called. If the first positional argument was a
    UgridDataArray or UgridDataset, the result is passed through
    ``maybe_xugrid`` with that object's grids and indexes; otherwise the
    result of ``method`` is returned as-is.

    Parameters
    ----------
    method: callable
        The xarray function or method to forward to.

    Returns
    -------
    wrapped: callable
        Carries the metadata of ``method`` (via ``functools.wraps``).
    """

    @wraps(method)
    def wrapped(*args, **kwargs):
        unwrapped_args = [maybe_xarray(arg) for arg in args]
        unwrapped_kwargs = {k: maybe_xarray(v) for k, v in kwargs.items()}
        result = method(*unwrapped_args, **unwrapped_kwargs)

        # Sidestep staticmethods, classmethods: in that case the first
        # argument will not be a xugrid type. Such callables may also be
        # invoked with keyword arguments only, so ``args`` can be empty.
        if args and isinstance(args[0], (UgridDataArray, UgridDataset)):
            caller = args[0]
            return maybe_xugrid(result, caller.grids, caller.obj.indexes)
        return result

    return wrapped


def wrap_accessor(accessor):
    """
    Return a callable that forwards to ``accessor`` with xarray arguments.

    All positional and keyword arguments are unwrapped with ``maybe_xarray``
    first; the return value of ``accessor`` is handed back unchanged.
    """
    # TODO: This will not add dynamic accessors, those must be included at
    # runtime instead?

    def wrapped(*args, **kwargs):
        plain_args = [maybe_xarray(item) for item in args]
        plain_kwargs = {key: maybe_xarray(value) for key, value in kwargs.items()}
        return accessor(*plain_args, **plain_kwargs)

    return wrapped


def wrap(
    target_class_dict,
    source_class,
):
    """
    Fill a class namespace with wrapped members of an xarray class.

    Three kinds of members of ``source_class`` are forwarded:

    * functions and bound methods, wrapped with ``wraps_xarray``;
    * properties, rebuilt with a wrapped getter and setter;
    * classes (accessors), exposed as a property via ``wrap_accessor``.

    Parameters
    ----------
    target_class_dict: dict
        Namespace to populate, e.g. ``vars()`` inside a class body. It is
        modified in place.
    source_class: type
        The class whose members are forwarded.

    Returns
    -------
    None
    """
    FuncType = (types.FunctionType, types.MethodType)

    # Set every method, property from the xarray object to the UgridDataArray,
    # UgridDataset. Don't set everything, as this will break the objects.
    #
    # The first part of the set below can be regenerated with:
    #
    # class Empty:
    #     pass
    #
    # keep = {
    #     "__eq__",
    #     "__ge__",
    #     "__gt__",
    #     "__le__",
    #     "__lt__",
    #     "__ne__",
    #     "__repr__",
    #     "__str__",
    # }
    #
    # remove = set(dir(Empty)) - keep

    remove = {
        # These members are shared by all objects:
        "__class__",
        "__delattr__",
        "__dict__",
        "__dir__",
        "__doc__",
        "__format__",
        "__getattribute__",
        "__hash__",
        "__init__",
        "__init_subclass__",
        "__module__",
        "__new__",
        "__reduce__",
        "__reduce_ex__",
        "__setattr__",
        "__sizeof__",
        "__subclasshook__",
        # NOTE: the trailing comma matters. Without it, this entry is
        # concatenated with the next string literal into one bogus name.
        "__weakref__",
        # These are additionally included in xarray:
        "__getattr__",
        "__slots__",
        "__annotations__",
    }

    attr_names = set(dir(source_class)) - remove
    all_attrs = {k: getattr(source_class, k) for k in attr_names}

    methods = {k: v for k, v in all_attrs.items() if isinstance(v, FuncType)}
    for name, method in methods.items():
        wrapped = wraps_xarray(method)
        setattr(wrapped, "__doc__", method.__doc__)
        target_class_dict[name] = wrapped

    properties = {k: v for k, v in all_attrs.items() if isinstance(v, property)}
    for name, prop in properties.items():
        # Go through the descriptor protocol of the original property, so
        # the wrapped getter/setter receive the instance as first argument.
        wrapped = property(
            fget=wraps_xarray(prop.__get__),
            fset=wraps_xarray(prop.__set__),
            doc=prop.__doc__,
        )
        target_class_dict[name] = wrapped

    accessors = {k: v for k, v in all_attrs.items() if isinstance(v, type)}
    for name, accessor in accessors.items():
        wrapped = property(wrap_accessor(accessor))
        setattr(wrapped, "__doc__", accessor.__doc__)
        target_class_dict[name] = wrapped

    return


class DataArrayForwardMixin:
    """Mixin whose namespace is populated with wrapped ``xr.DataArray`` members."""

    # Executed once, while the class body runs: ``vars()`` is the namespace
    # of this class, which ``wrap`` fills in place.
    wrap(vars(), xr.DataArray)


class DatasetForwardMixin:
    """Mixin whose namespace is populated with wrapped ``xr.Dataset`` members."""

    # Executed once, while the class body runs: ``vars()`` is the namespace
    # of this class, which ``wrap`` fills in place.
    wrap(vars(), xr.Dataset)


def assign_ugrid_coords(obj, grids):
    """
    Assign a ``RangeIndex`` coordinate to every grid dimension found in ``obj``.

    Parameters
    ----------
    obj: object with ``dims`` and ``assign_coords``
    grids: sequence of grids
        Each grid provides a ``dimensions`` mapping of dimension name to
        size. When several grids share a dimension name, the first grid in
        the sequence determines the size.

    Returns
    -------
    The result of ``obj.assign_coords``.
    """
    sizes = ChainMap(*[grid.dimensions for grid in grids])
    present = set(sizes.keys()).intersection(obj.dims)

    range_coords = {}
    for dim in present:
        range_coords[dim] = RangeIndex(0, sizes[dim])

    return obj.assign_coords(range_coords)


class UgridDataArray(DataArrayForwardMixin):
    """
    Pairs an ``xr.DataArray`` with the UGRID topology (grid) it is defined on.

    The members of ``xr.DataArray`` are made available through
    ``DataArrayForwardMixin``; anything not found there is looked up on the
    wrapped DataArray by ``__getattr__``.
    """

    def __init__(self, obj: xr.DataArray, grid: UgridType):
        """
        Parameters
        ----------
        obj: xr.DataArray
        grid: Ugrid1d or Ugrid2d

        Raises
        ------
        TypeError
            If ``obj`` is not an ``xr.DataArray`` or ``grid`` is not an
            ``AbstractUgrid``.
        """
        if not isinstance(obj, xr.DataArray):
            raise TypeError(
                "obj must be xarray.DataArray. Received instead: "
                f"{type(obj).__name__}"
            )
        if not isinstance(grid, AbstractUgrid):
            raise TypeError(
                "grid must be Ugrid1d or Ugrid2d. Received instead: "
                f"{type(grid).__name__}"
            )

        self._grid = grid
        self._obj = assign_ugrid_coords(obj, [grid])

    def __getattr__(self, attr):
        # Only called when regular attribute lookup fails: fetch the
        # attribute from the wrapped DataArray and re-wrap it if needed.
        result = getattr(self.obj, attr)
        return maybe_xugrid(result, [self.grid])

    @property
    def obj(self):
        """The wrapped ``xr.DataArray``."""
        return self._obj

    @property
    def grid(self):
        """The grid of this object."""
        return self._grid

    @property
    def grids(self) -> List[UgridType]:
        """The grid of this object, as a single-element list."""
        return [self._grid]

    @property
    def ugrid(self):
        """
        UGRID Accessor. This "accessor" makes operations using the UGRID
        topology available.
        """
        return xugrid.core.dataarray_accessor.UgridDataArrayAccessor(
            self.obj, self.grid
        )

    @staticmethod
    def from_structured(
        da: xr.DataArray,
        x: str | None = None,
        y: str | None = None,
    ) -> "UgridDataArray":
        """
        Create a UgridDataArray from a (structured) xarray DataArray.

        The spatial dimensions are flattened into a single UGRID face
        dimension.

        By default, this method looks for the "x" and "y" coordinates and
        assumes they are one-dimensional. To convert rotated or curvilinear
        coordinates, provide the names of the x and y coordinates.

        Parameters
        ----------
        da: xr.DataArray
            Last two dimensions must be ``("y", "x")``.
        x: str, default: None
            Which coordinate to use as the UGRID x-coordinate.
        y: str, default: None
            Which coordinate to use as the UGRID y-coordinate.

        Returns
        -------
        unstructured: UgridDataArray

        Raises
        ------
        ValueError
            If the last two dimensions are not ``("y", "x")``, if only one
            of ``x`` and ``y`` is provided, or if the ``x`` coordinate is
            neither 1D nor 2D.
        """
        if da.dims[-2:] != ("y", "x"):
            raise ValueError('Last two dimensions of da must be ("y", "x")')
        if (x is None) ^ (y is None):
            raise ValueError("Provide both x and y, or neither.")

        if x is None:
            grid = Ugrid2d.from_structured(da)
        else:
            # Find out if it's multi-dimensional
            xdim = da[x].ndim
            if xdim == 1:
                grid = Ugrid2d.from_structured(da, x=x, y=y)
            elif xdim == 2:
                grid = Ugrid2d.from_structured_multicoord(da, x=x, y=y)
            else:
                raise ValueError(f"x and y must be 1D or 2D. Found: {xdim}")

        dims = da.dims[:-2]
        # A dimension does not necessarily have a coordinate: only carry
        # over the ones that exist instead of failing with a KeyError.
        coords = {k: da.coords[k] for k in dims if k in da.coords}
        face_da = xr.DataArray(
            da.data.reshape(*da.shape[:-2], -1),
            coords=coords,
            dims=[*dims, grid.face_dimension],
            name=da.name,
        )
        return UgridDataArray(face_da, grid)
class UgridDataset(DatasetForwardMixin):
    """
    Pairs an ``xr.Dataset`` with the UGRID topologies (grids) it is defined on.

    The members of ``xr.Dataset`` are made available through
    ``DatasetForwardMixin``; anything not found there is looked up on the
    wrapped Dataset by ``__getattr__``.
    """

    def __init__(
        self,
        obj: xr.Dataset = None,
        grids: Union[UgridType, Sequence[UgridType]] = None,
    ):
        """
        Parameters
        ----------
        obj: xr.Dataset, optional
            The topology and connectivity variables listed by its
            ``ugrid_roles`` accessor are dropped from the stored dataset.
            If omitted, an empty dataset is used.
        grids: grid, or list, tuple, set of grids, optional
            If omitted, the grids are derived from ``obj`` with
            ``grid_from_dataset``.

        Raises
        ------
        ValueError
            If both ``obj`` and ``grids`` are None.
        TypeError
            If ``obj`` is not an ``xr.Dataset`` or any entry of ``grids`` is
            not an ``AbstractUgrid``.
        """
        if obj is None and grids is None:
            raise ValueError("At least either obj or grids is required")

        if obj is None:
            ds = xr.Dataset()
        else:
            if not isinstance(obj, xr.Dataset):
                raise TypeError(
                    "obj must be xarray.Dataset. Received instead: "
                    f"{type(obj).__name__}"
                )
            connectivity_vars = [
                name
                for v in obj.ugrid_roles.connectivity.values()
                for name in v.values()
            ]
            ds = obj.drop_vars(obj.ugrid_roles.topology + connectivity_vars)

        if grids is None:
            topologies = obj.ugrid_roles.topology
            grids = [grid_from_dataset(obj, topology) for topology in topologies]
        else:
            # Make sure it's a new list
            if isinstance(grids, (list, tuple, set)):
                grids = list(grids)
            else:  # not iterable
                grids = [grids]
            # Now typecheck
            for grid in grids:
                if not isinstance(grid, AbstractUgrid):
                    raise TypeError(
                        "grid must be Ugrid1d or Ugrid2d. "
                        f"Received instead: {type(grid).__name__}"
                    )

        self._grids = grids
        self._obj = assign_ugrid_coords(ds, grids)

    @property
    def obj(self):
        """The wrapped ``xr.Dataset``."""
        return self._obj

    @property
    def grid(self) -> UgridType:
        # We need to do some checking. Don't duplicate that logic.
        return self.ugrid.grid

    @property
    def grids(self) -> List[UgridType]:
        """The list of grids of this object."""
        return self._grids

    @property
    def ugrid(self):
        """
        UGRID Accessor. This "accessor" makes operations using the UGRID
        topology available.
        """
        return xugrid.core.dataset_accessor.UgridDatasetAccessor(self.obj, self.grids)

    def __getattr__(self, attr):
        # Only called when regular attribute lookup fails: fetch the
        # attribute from the wrapped Dataset and re-wrap it if needed.
        result = getattr(self.obj, attr)
        return maybe_xugrid(result, self.grids)

    def __setitem__(self, key, value):
        """
        Store ``value`` under ``key`` in the wrapped dataset.

        For a UgridDataArray, its grid is appended to ``grids`` when it
        shares no dimension name with the grids already present. When it
        does share a dimension name, the grid owning that dimension must
        equal the grid of ``value``.

        Raises
        ------
        ValueError
            If dimension names are shared but the grids are not equal.
        """
        # TODO: check with topology
        if not isinstance(value, UgridDataArray):
            self.obj[key] = value
            return

        # ``self.grids`` is always a list (see __init__), so no None check
        # is needed here.
        grid_of_dim = {dim: grid for grid in self.grids for dim in grid.dimensions}
        matching_dims = set(value.grid.dimensions).intersection(grid_of_dim)
        if matching_dims:
            # If the dimensions do match: the grids should match.
            firstdim = next(iter(matching_dims))
            if not grid_of_dim[firstdim].equals(value.grid):
                raise ValueError(
                    "Grids share dimension names but are not identical. "
                    f"Matching dimensions: {matching_dims}"
                )

        self.obj[key] = value.obj
        if not matching_dims:
            # None of the dimensions occur in self: this is a new grid.
            self._grids.append(value.grid)

    @staticmethod
    def from_geodataframe(geodataframe: "geopandas.GeoDataFrame"):  # type: ignore # noqa
        """
        Convert a geodataframe into the appropriate Ugrid topology and dataset.

        Parameters
        ----------
        geodataframe: gpd.GeoDataFrame

        Returns
        -------
        dataset: UGridDataset
        """
        grid = grid_from_geodataframe(geodataframe)
        ds = xr.Dataset.from_dataframe(geodataframe.drop("geometry", axis=1))
        return UgridDataset(ds, [grid])