# Source code for xugrid.core.wrap

"""
Wrap in advance instead of overloading __getattr__.

This allows for tab completion and documentation.
"""

from __future__ import annotations

import types
import warnings
from collections import ChainMap
from functools import wraps
from itertools import chain
from typing import List, Sequence, Union

import xarray as xr
from numpy.typing import ArrayLike
from pandas import RangeIndex

import xugrid
from xugrid.conversion import grid_from_dataset, grid_from_geodataframe
from xugrid.core.utils import unique_grids
from xugrid.ugrid.ugrid2d import Ugrid2d
from xugrid.ugrid.ugridbase import AbstractUgrid, UgridType, align

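# The entire ``xugrid`` package is imported above (``import xugrid``) to avoid a
# circular import of UgridDatasetAccessor and UgridDataArrayAccessor. Note:
# attributes such as ``xugrid.core.dataset_accessor`` can only be used inside
# functions, since function bodies only run when called, i.e. after the
# import cycle has resolved.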


def maybe_xugrid(obj, topology, old_indexes=None):
    """
    Re-wrap an xarray result as a xugrid object when it has UGRID dimensions.

    Parameters
    ----------
    obj: object
        Typically the result of an xarray operation. Anything that is not an
        ``xarray.DataArray`` or ``xarray.Dataset`` is returned unchanged.
    topology: grid, or list/set/tuple of grids
        Candidate grids. A grid applies when one of its dimensions occurs in
        ``obj.dims``.
    old_indexes: optional
        Passed on to ``align``.

    Returns
    -------
    UgridDataArray or UgridDataset when at least one grid applies, otherwise
    ``obj`` unchanged.
    """
    if not isinstance(obj, (xr.DataArray, xr.Dataset)):
        return obj

    # Normalize a single grid to a one-element collection, then map every
    # dimension name to the grid that owns it.
    candidates = topology if isinstance(topology, (list, set, tuple)) else [topology]
    dim_to_grid = {dim: grid for grid in candidates for dim in grid.dims}

    matched = unique_grids(
        [dim_to_grid[dim] for dim in obj.dims if dim in dim_to_grid]
    )
    if len(matched) == 0:
        return obj

    result, aligned = align(obj, matched, old_indexes)

    if isinstance(result, xr.DataArray):
        # A DataArray can only be associated with a single topology.
        if len(aligned) > 1:
            raise RuntimeError("This shouldn't happen. Please open an issue.")
        return UgridDataArray(result, aligned[0])

    if isinstance(result, xr.Dataset):
        return UgridDataset(result, aligned)

    return None


def maybe_xarray(arg):
    """
    Unwrap a xugrid object to its underlying xarray object.

    UgridDataArray and UgridDataset instances are replaced by their ``.obj``;
    every other value is returned as-is.
    """
    is_xugrid = isinstance(arg, (UgridDataArray, UgridDataset))
    return arg.obj if is_xugrid else arg


def wraps_xarray(method):
    """
    Wrap an xarray callable so it can be used on xugrid objects.

    UgridDataArray / UgridDataset arguments (positional or keyword) are
    replaced by their underlying xarray objects before ``method`` is called.
    If the first positional argument is a xugrid object, i.e. ``method`` was
    called as an instance method, the result is re-wrapped with that object's
    grids and indexes via ``maybe_xugrid``. Otherwise the result is returned
    as-is.

    Parameters
    ----------
    method: callable

    Returns
    -------
    wrapped: callable
        Carries the name and docstring of ``method`` (``functools.wraps``).
    """

    @wraps(method)
    def wrapped(*args, **kwargs):
        # For regular methods the first positional argument is the instance.
        # Forwarded static/class methods may be called without any positional
        # argument; indexing args[0] would then raise IndexError.
        self = args[0] if args else None
        xr_args = [maybe_xarray(arg) for arg in args]
        xr_kwargs = {k: maybe_xarray(v) for k, v in kwargs.items()}
        result = method(*xr_args, **xr_kwargs)

        # Sidestep staticmethods, classmethods: in that case self will not be a
        # xugrid type.
        if self is None or not isinstance(self, (UgridDataArray, UgridDataset)):
            return result
        return maybe_xugrid(result, self.grids, self.obj.indexes)

    return wrapped


def wrap_accessor(accessor):
    """
    Create a property getter that instantiates ``accessor`` on xarray data.

    xugrid arguments are unwrapped to their xarray objects (via
    ``maybe_xarray``) before ``accessor`` is called. The accessor's result is
    returned without re-wrapping.
    """
    # TODO: This will not add dynamic accessors; those must be included at
    # runtime instead?

    def instantiate(*args, **kwargs):
        # Note: deliberately no docstring here; ``wrap`` sets the property's
        # documentation from the accessor itself.
        unwrapped_args = [maybe_xarray(value) for value in args]
        unwrapped_kwargs = {
            key: maybe_xarray(value) for key, value in kwargs.items()
        }
        return accessor(*unwrapped_args, **unwrapped_kwargs)

    return instantiate


def wrap(
    target_class_dict,
    source_class,
):
    """
    Populate a class namespace with wrappers for the members of ``source_class``.

    Intended to be called from a class body with ``vars()`` as the namespace,
    so the forwarded members exist on the class itself (which enables tab
    completion and documentation).

    * Functions and bound methods are wrapped with ``wraps_xarray``.
    * Properties become new properties whose getter and setter are wrapped
      with ``wraps_xarray``.
    * Class-valued attributes (xarray accessors) become properties that
      instantiate the accessor via ``wrap_accessor``.

    Members shared by all Python objects, and a few xarray internals, are
    skipped (see ``remove`` below).

    Parameters
    ----------
    target_class_dict: dict
        The class namespace to fill in place.
    source_class: type
        The class whose members are forwarded, e.g. ``xr.DataArray``.
    """
    FuncType = (types.FunctionType, types.MethodType)

    # Not every member of the xarray class can be forwarded: overriding the
    # object machinery below would break the wrapper objects. The rich
    # comparisons (__eq__, __lt__, ...), __repr__ and __str__ are deliberately
    # NOT in this set, so they are forwarded.
    remove = {
        # These members are shared by all objects:
        "__class__",
        "__delattr__",
        "__dict__",
        "__dir__",
        "__doc__",
        "__format__",
        "__getattribute__",
        "__hash__",
        "__init__",
        "__init_subclass__",
        "__module__",
        "__new__",
        "__reduce__",
        "__reduce_ex__",
        "__setattr__",
        "__sizeof__",
        "__subclasshook__",
        "__weakref__",
        # These are additionally included in xarray:
        "__getattr__",
        "__slots__",
        "__annotations__",
    }

    attr_names = set(dir(source_class)) - remove
    all_attrs = {k: getattr(source_class, k) for k in attr_names}

    methods = {k: v for k, v in all_attrs.items() if isinstance(v, FuncType)}
    for name, method in methods.items():
        wrapped = wraps_xarray(method)
        wrapped.__doc__ = method.__doc__
        target_class_dict[name] = wrapped

    properties = {k: v for k, v in all_attrs.items() if isinstance(v, property)}
    for name, prop in properties.items():
        target_class_dict[name] = property(
            fget=wraps_xarray(prop.__get__),
            fset=wraps_xarray(prop.__set__),
            doc=prop.__doc__,
        )

    # Accessors are returned as classes when looked up on the class itself.
    accessors = {k: v for k, v in all_attrs.items() if isinstance(v, type)}
    for name, accessor in accessors.items():
        wrapped = property(wrap_accessor(accessor))
        wrapped.__doc__ = accessor.__doc__
        target_class_dict[name] = wrapped


class DataArrayForwardMixin:
    """
    Forward the members of ``xarray.DataArray`` onto the subclass.

    The namespace is filled by ``wrap`` while this class body executes, so the
    forwarded members are defined on the class itself.
    """

    wrap(source_class=xr.DataArray, target_class_dict=locals())


class DatasetForwardMixin:
    """
    Forward the members of ``xarray.Dataset`` onto the subclass.

    The namespace is filled by ``wrap`` while this class body executes, so the
    forwarded members are defined on the class itself.
    """

    wrap(source_class=xr.Dataset, target_class_dict=locals())


def assign_ugrid_coords(obj, grids):
    """
    Give every UGRID dimension of ``obj`` a default integer coordinate.

    Parameters
    ----------
    obj: xr.DataArray or xr.Dataset
    grids: sequence of grids
        Each grid provides a ``sizes`` mapping of dimension name to length.
        If a dimension occurs in several grids, the size from the first grid
        wins.

    Returns
    -------
    The result of ``obj.assign_coords``, where each dimension of ``obj`` that
    belongs to a grid gets a ``RangeIndex(0, size)`` coordinate.
    """
    grid_dims = ChainMap(*(grid.sizes for grid in grids))
    # Iterate over obj.dims rather than a set intersection, so coordinates are
    # assigned in a deterministic order (set order of strings varies between
    # interpreter runs because of hash randomization).
    ugrid_coords = {
        dim: RangeIndex(0, grid_dims[dim]) for dim in obj.dims if dim in grid_dims
    }
    return obj.assign_coords(ugrid_coords)


class UgridDataArray(DataArrayForwardMixin):
    """
    Wrap an ``xarray.DataArray`` together with its UGRID topology.

    Methods, properties and accessors of ``xarray.DataArray`` are forwarded
    via ``DataArrayForwardMixin``. Any other attribute is looked up on the
    wrapped DataArray by ``__getattr__``. Results that carry UGRID dimensions
    are re-wrapped as xugrid objects.
    """

    def __init__(self, obj: xr.DataArray, grid: UgridType):
        """
        Parameters
        ----------
        obj: xr.DataArray
            The data. Its UGRID dimensions receive a default integer index
            coordinate.
        grid: Ugrid1d or Ugrid2d
            The topology associated with ``obj``.

        Raises
        ------
        TypeError
            If ``obj`` is not an ``xarray.DataArray``, or ``grid`` is not a
            UGRID topology.
        """
        if not isinstance(obj, xr.DataArray):
            raise TypeError(
                f"obj must be xarray.DataArray. Received instead: {type(obj).__name__}"
            )
        if not isinstance(grid, AbstractUgrid):
            raise TypeError(
                "grid must be Ugrid1d or Ugrid2d. Received instead: "
                f"{type(grid).__name__}"
            )
        self._grid = grid
        self._obj = assign_ugrid_coords(obj, [grid])
        # The new object should close the resources of the original one.
        self._obj.set_close(obj._close)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails. The private storage
        # attributes must not be delegated. If they are missing (e.g. on an
        # instance created via __new__ without running __init__), delegating
        # would recurse through self.obj indefinitely.
        if attr in ("_obj", "_grid"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        result = getattr(self.obj, attr)
        return maybe_xugrid(result, [self.grid])

    @property
    def obj(self):
        """The wrapped ``xarray.DataArray``."""
        return self._obj

    @property
    def grid(self):
        """The UGRID topology of this DataArray."""
        return self._grid

    @property
    def grids(self) -> List[UgridType]:
        """The UGRID topology of this DataArray, as a one-element list."""
        return [self._grid]

    @property
    def ugrid(self):
        """
        UGRID Accessor. This "accessor" makes operations using the UGRID
        topology available.
        """
        return xugrid.core.dataarray_accessor.UgridDataArrayAccessor(
            self.obj, self.grid
        )

    @staticmethod
    def from_structured2d(
        da: xr.DataArray,
        x: str | None = None,
        y: str | None = None,
        x_bounds: xr.DataArray = None,
        y_bounds: xr.DataArray = None,
    ) -> "UgridDataArray":
        """
        Create a UgridDataArray from a (structured) xarray DataArray.

        The spatial dimensions are flattened into a single UGRID face dimension.
        By default, this method looks for:

        1. "x" and "y" dimensions
        2. "longitude" and "latitude" dimensions
        3. "axis" attributes of "X" or "Y" on coordinates
        4. "standard_name" attributes of "longitude", "latitude",
           "projection_x_coordinate", or "projection_y_coordinate" on coordinate
           variables

        Parameters
        ----------
        da : xr.DataArray
            The structured data array to convert. The last two dimensions must be
            the y and x dimensions (in that order).
        x : str, optional
            Name of the UGRID x-coordinate, or x-dimension if bounds are provided.
            Defaults to None.
        y : str, optional
            Name of the UGRID y-coordinate, or y-dimension if bounds are provided.
            Defaults to None.
        x_bounds : xr.DataArray, optional
            Bounds for x-coordinates. Required for non-monotonic coordinates.
            Must be given together with ``y_bounds``. Defaults to None.
        y_bounds : xr.DataArray, optional
            Bounds for y-coordinates. Required for non-monotonic coordinates.
            Must be given together with ``x_bounds``. Defaults to None.

        Returns
        -------
        UgridDataArray
            The unstructured grid data array.

        Raises
        ------
        ValueError
            If ``da`` has fewer than two dimensions, if only one of
            ``x_bounds`` and ``y_bounds`` is given, or if bounds are given
            without both ``x`` and ``y``.

        Notes
        -----
        When using bounds, they should have one of these shapes:

        * x bounds: (M, 2) or (N, M, 4)
        * y bounds: (N, 2) or (N, M, 4)

        where N is the number of rows (along y) and M is columns (along x).
        Cells with NaN bounds coordinates are omitted.

        Examples
        --------
        Basic usage with default coordinate detection:

        >>> uda = xugrid.UgridDataArray.from_structured2d(data_array)

        Specifying explicit coordinate names:

        >>> uda = xugrid.UgridDataArray.from_structured2d(
        ...     data_array,
        ...     x="longitude",
        ...     y="latitude"
        ... )

        Using bounds for curvilinear grids:

        >>> uda = xugrid.UgridDataArray.from_structured2d(
        ...     data_array,
        ...     x="x_dim",
        ...     y="y_dim",
        ...     x_bounds=x_bounds_array,
        ...     y_bounds=y_bounds_array
        ... )
        """
        if da.ndim < 2:
            raise ValueError(
                "DataArray must have at least two spatial dimensions. "
                f"Found: {da.dims}."
            )
        # A single bounds array cannot define the cells; refuse it instead of
        # silently ignoring it.
        if (x_bounds is None) != (y_bounds is None):
            raise ValueError("x_bounds and y_bounds must be provided together")

        if x_bounds is not None and y_bounds is not None:
            if x is None or y is None:
                raise ValueError("x and y must be provided for bounds")
            yx = (y, x)
            grid, index = Ugrid2d.from_structured_bounds(
                x_bounds=x_bounds.transpose(y, x, ...).to_numpy(),
                y_bounds=y_bounds.transpose(y, x, ...).to_numpy(),
                return_index=True,
            )
        else:
            # Possibly rely on inference of x and y dims.
            grid, yx = Ugrid2d.from_structured(da, x, y, return_dims=True)
            index = slice(None, None)

        # Flatten (y, x) into the face dimension and keep only the faces
        # selected by ``index``.
        face_da = (
            da.stack(  # noqa: PD013
                {grid.face_dimension: yx}, create_index=False
            )
            .isel({grid.face_dimension: index})
            .drop_vars(yx, errors="ignore")
        )
        return UgridDataArray(face_da, grid)

    @staticmethod
    def from_structured(
        da: xr.DataArray,
        x: str | None = None,
        y: str | None = None,
        x_bounds: xr.DataArray = None,
        y_bounds: xr.DataArray = None,
    ) -> "UgridDataArray":
        """Deprecated alias of :meth:`UgridDataArray.from_structured2d`."""
        warnings.warn(
            "UgridDataArray.from_structured is deprecated and will be removed. "
            "Use UgridDataArray.from_structured2d instead.",
            FutureWarning,
            stacklevel=2,
        )
        return UgridDataArray.from_structured2d(da, x, y, x_bounds, y_bounds)

    @staticmethod
    def from_data(data: ArrayLike, grid: UgridType, facet: str) -> UgridDataArray:
        """
        Create a UgridDataArray from a grid and a 1D array of values.

        Parameters
        ----------
        data: array like
            Values for this array. Must be a ``numpy.ndarray`` or castable to
            it.
        grid: Ugrid1d, Ugrid2d
        facet: str
            With which facet to associate the data. Options for Ugrid1d are,
            ``"node"`` or ``"edge"``. Options for Ugrid2d are ``"node"``,
            ``"edge"``, or ``"face"``.

        Returns
        -------
        uda: UgridDataArray
        """
        return grid.create_data_array(data=data, facet=facet)
class UgridDataset(DatasetForwardMixin):
    """
    Wrap an ``xarray.Dataset`` together with one or more UGRID topologies.

    Methods, properties and accessors of ``xarray.Dataset`` are forwarded via
    ``DatasetForwardMixin``. Any other attribute is looked up on the wrapped
    Dataset by ``__getattr__``. Results that carry UGRID dimensions are
    re-wrapped as xugrid objects.
    """

    def __init__(
        self,
        obj: xr.Dataset = None,
        grids: Union[UgridType, Sequence[UgridType]] = None,
    ):
        """
        Parameters
        ----------
        obj: xr.Dataset, optional
            The data. Its UGRID topology and connectivity variables (as
            reported by the ``ugrid_roles`` accessor) are dropped from the
            wrapped Dataset. Defaults to an empty Dataset.
        grids: Ugrid1d, Ugrid2d, or a list/tuple/set of these, optional
            The topologies. If None, they are built from ``obj`` with
            ``grid_from_dataset``.

        Raises
        ------
        ValueError
            If both ``obj`` and ``grids`` are None.
        TypeError
            If ``obj`` is not an ``xarray.Dataset``, or a grid is not a UGRID
            topology.
        """
        if obj is None and grids is None:
            raise ValueError("At least either obj or grids is required")

        if obj is None:
            original = ds = xr.Dataset()
        else:
            if not isinstance(obj, xr.Dataset):
                raise TypeError(
                    "obj must be xarray.Dataset. Received instead: "
                    f"{type(obj).__name__}"
                )
            connectivity_vars = [
                name
                for v in obj.ugrid_roles.connectivity.values()
                for name in v.values()
            ]
            ds = obj.drop_vars(obj.ugrid_roles.topology + connectivity_vars)
            original = obj

        if grids is None:
            topologies = obj.ugrid_roles.topology
            grids = [grid_from_dataset(obj, topology) for topology in topologies]
        else:
            # Make sure it's a new list
            if isinstance(grids, (list, tuple, set)):
                grids = list(grids)
            else:  # not iterable
                grids = [grids]
            # Now typecheck
            for grid in grids:
                if not isinstance(grid, AbstractUgrid):
                    raise TypeError(
                        "grid must be Ugrid1d or Ugrid2d. "
                        f"Received instead: {type(grid).__name__}"
                    )

        self._grids = grids
        self._obj = assign_ugrid_coords(ds, grids)
        # We've created a new object; the file handle will be associated with the original.
        # set_close makes sure that when close is called on the UgridDataset, that the
        # file will actually be closed (via the original).
        self._obj.set_close(original._close)

    @property
    def obj(self):
        """The wrapped ``xarray.Dataset``."""
        return self._obj

    @property
    def grid(self) -> UgridType:
        """The single topology of this dataset (checked by the accessor)."""
        # We need to do some checking. Don't duplicate that logic.
        return self.ugrid.grid

    @property
    def grids(self) -> List[UgridType]:
        """The list of topologies of this dataset."""
        return self._grids

    @property
    def ugrid(self):
        """
        UGRID Accessor. This "accessor" makes operations using the UGRID
        topology available.
        """
        return xugrid.core.dataset_accessor.UgridDatasetAccessor(self.obj, self.grids)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails. The private storage
        # attributes must not be delegated. If they are missing (e.g. on an
        # instance created via __new__ without running __init__), delegating
        # would recurse through self.obj indefinitely.
        if attr in ("_obj", "_grids"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        result = getattr(self.obj, attr)
        return maybe_xugrid(result, self.grids)

    def __setitem__(self, key, value):
        """
        Set a variable, registering the grid of a UgridDataArray if needed.

        If a UgridDataArray's grid shares no dimension with the existing
        grids, its grid is appended to this dataset's grids. If it does share
        a dimension, the existing grid that owns one of the shared dimensions
        must equal ``value.grid``. Otherwise a ValueError is raised.
        """
        # TODO: check with topology
        if isinstance(value, UgridDataArray):
            append = True
            # Check if the dimensions occur in self.
            # if they don't, the grid should be added.
            if self.grids is not None:
                alldims = set(chain.from_iterable(grid.dims for grid in self.grids))
                matching_dims = set(value.grid.dims).intersection(alldims)
                if matching_dims:
                    append = False
                    # If they do match: the grids should match.
                    grids = {dim: grid for grid in self.grids for dim in grid.dims}
                    firstdim = next(iter(matching_dims))
                    grid_to_check = grids[firstdim]
                    if not grid_to_check.equals(value.grid):
                        raise ValueError(
                            "Grids share dimension names but are not identical. "
                            f"Matching dimensions: {matching_dims}"
                        )
            self.obj[key] = value.obj
            if append:
                self._grids.append(value.grid)
        else:
            self.obj[key] = value

    @staticmethod
    def from_geodataframe(geodataframe: "geopandas.GeoDataFrame"):  # type: ignore # noqa
        """
        Convert a geodataframe into the appropriate Ugrid topology and dataset.

        Parameters
        ----------
        geodataframe: gpd.GeoDataFrame

        Returns
        -------
        dataset: UGridDataset
        """
        grid = grid_from_geodataframe(geodataframe)
        ds = xr.Dataset.from_dataframe(geodataframe.drop("geometry", axis=1))
        return UgridDataset(ds, [grid])

    @staticmethod
    def from_structured2d(
        dataset: xr.Dataset, topology: dict | None = None
    ) -> "UgridDataset":
        """
        Create a UgridDataset from a (structured) xarray Dataset.

        The spatial dimensions are flattened into a single UGRID face dimension.
        By default, this method looks for:

        1. "x" and "y" dimensions
        2. "longitude" and "latitude" dimensions
        3. "axis" attributes of "X" or "Y" on coordinates
        4. "standard_name" attributes of "longitude", "latitude",
           "projection_x_coordinate", or "projection_y_coordinate" on coordinate
           variables

        Parameters
        ----------
        dataset : xr.Dataset
            The structured dataset to convert.
        topology : dict, optional
            Either:

            * A mapping of topology name to (x, y) coordinate names
            * A mapping of topology name to a dict containing:
                - "x": x-dimension name
                - "y": y-dimension name
                - "bounds_x": x-bounds variable name
                - "bounds_y": y-bounds variable name

            The bounds keys may also be spelled "x_bounds" and "y_bounds".
            Both bounds must be given together, along with "x" and "y".

            Defaults to {"mesh2d": (None, None)}.

        Returns
        -------
        UgridDataset
            The unstructured grid dataset.

        Raises
        ------
        ValueError
            If only one bounds variable is given for a topology, or bounds are
            given without both "x" and "y".
        TypeError
            If a topology value is neither a dict nor a tuple.

        Notes
        -----
        When using bounds, they should have one of these shapes:

        * x bounds: (M, 2) or (N, M, 4)
        * y bounds: (N, 2) or (N, M, 4)

        where N is the number of rows (along y) and M is columns (along x).
        Cells with NaN bounds coordinates are omitted.

        Examples
        --------
        Basic usage with default coordinate names:

        >>> uds = xugrid.UgridDataset.from_structured2d(dataset)

        Specifying custom coordinate names:

        >>> uds = xugrid.UgridDataset.from_structured2d(
        ...     dataset,
        ...     topology={"my_mesh2d": {"x": "xc", "y": "yc"}}
        ... )

        Multiple grid topologies in a single dataset:

        >>> uds = xugrid.UgridDataset.from_structured2d(
        ...     dataset,
        ...     topology={
        ...         "mesh2d_xy": {"x": "x", "y": "y"},
        ...         "mesh2d_lonlat": {"x": "lon", "y": "lat"}
        ...     }
        ... )

        Using bounds for non-monotonic coordinates (e.g., curvilinear grids):

        >>> uds = xugrid.UgridDataset.from_structured2d(
        ...     dataset,
        ...     topology={
        ...         "my_mesh2d": {
        ...             "x": "M",
        ...             "y": "N",
        ...             "bounds_x": "grid_x",
        ...             "bounds_y": "grid_y"
        ...         }
        ...     }
        ... )
        """
        if topology is None:
            # By default, set None. This communicates to
            # Ugrid2d.from_structured to infer x and y dims.
            topology = {"mesh2d": (None, None)}

        grids = []
        dss = []
        xy_vars = set()  # store x, y, x_bounds, y_bounds to drop.
        for name, args in topology.items():
            x_bounds = None
            y_bounds = None
            if isinstance(args, dict):
                x = args.get("x")
                y = args.get("y")
                # The documented keys are "bounds_x"/"bounds_y"; the
                # "x_bounds"/"y_bounds" spelling is accepted as well.
                x_bounds_name = args.get("bounds_x", args.get("x_bounds"))
                y_bounds_name = args.get("bounds_y", args.get("y_bounds"))
                if (x_bounds_name is None) != (y_bounds_name is None):
                    raise ValueError(
                        "x_bounds and y_bounds must be provided together"
                    )
                if x_bounds_name is not None:
                    if x is None or y is None:
                        raise ValueError("x and y must be provided for bounds")
                    x_bounds = dataset[x_bounds_name]
                    y_bounds = dataset[y_bounds_name]
                    xy_vars.update((x_bounds_name, y_bounds_name))
            elif isinstance(args, tuple):
                x, y = args
            else:
                raise TypeError(
                    "Expected dict or tuple in topology, received: "
                    f"{type(args).__name__}"
                )

            if x_bounds is not None and y_bounds is not None:
                stackdims = (y, x)
                grid, index = Ugrid2d.from_structured_bounds(
                    x_bounds.transpose(*stackdims, ...).to_numpy(),
                    y_bounds.transpose(*stackdims, ...).to_numpy(),
                    name=name,
                    return_index=True,
                )
            else:
                grid, stackdims = Ugrid2d.from_structured(
                    dataset, x=x, y=y, name=name, return_dims=True
                )
                index = slice(None, None)

            # Use subset to check that ALL dims of stackdims are present in the
            # variable.
            checkdims = set(stackdims)
            xy_vars.update(checkdims)
            ugrid_vars = [
                var_name
                for var_name, var in dataset.data_vars.items()
                if checkdims.issubset(var.dims) and var_name not in xy_vars
            ]
            dss.append(
                dataset[ugrid_vars]  # noqa: PD013
                .stack({grid.face_dimension: stackdims})
                .isel({grid.face_dimension: index})
                .drop_vars([*stackdims, grid.face_dimension])
            )
            grids.append(grid)

        # Add the original dataset to include all non-UGRID variables.
        dss.append(dataset.drop_vars(xy_vars, errors="ignore"))
        # Then merge with compat="override". This'll pick the first available
        # variable: i.e. it will prioritize the UGRID form.
        merged = xr.merge(dss, compat="override")
        return UgridDataset(merged, grids)

    @staticmethod
    def from_structured(
        dataset: xr.Dataset, topology: dict | None = None
    ) -> "UgridDataset":
        """Deprecated alias of :meth:`UgridDataset.from_structured2d`."""
        warnings.warn(
            "UgridDataset.from_structured is deprecated and will be removed. "
            "Use UgridDataset.from_structured2d instead.",
            FutureWarning,
            stacklevel=2,
        )
        return UgridDataset.from_structured2d(dataset, topology)