"""Implementations relating to merging of datasets and arrays."""
import numpy as np
import rasterio
from pyproj import CRS
from hydromt.gis.raster import full_from_transform
__all__ = ["merge"]
[docs]
def merge(
    data_arrays,
    dst_crs=None,
    dst_bounds=None,
    dst_res=None,
    align=True,
    mask=None,
    merge_method="first",
    **kwargs,
):
    """Merge multiple tiles to a single DataArray.
    If mismatching grid CRS or resolution, tiles are reprojected
    to match the output DataArray grid. If no destination
    grid is defined it is based on the first DataArray in the list.
    Based on :py:meth:`rasterio.merge.merge`.
    Arguments
    ---------
    data_arrays: list of xarray.DataArray
        Tiles to merge
    dst_crs: pyproj.CRS, int
        CRS (or EPSG code) of destination grid
    dst_bounds: list of float
        Bounding box [xmin, ymin, xmax, ymax] of destination grid
    dst_res: float
        Resolution of destination grid
    align : bool, optional
        If True, align grid with dst_res
    mask: geopands.GeoDataFrame, optional
        Mask of destination area of interest.
        Used to determine dst_crs, dst_bounds and dst_res if missing.
    merge_method: {'first','last','min','max','new'}, callable
        Merge method:
        * first: reverse painting
        * last: paint valid new on top of existing
        * min: pixel-wise min of existing and new
        * max: pixel-wise max of existing and new
        * new: assert no pixel overlap
    **kwargs:
        Key-word arguments passed to
        :py:meth:`~hydromt.raster.RasterDataArray.reproject`
    Returns
    -------
    da_out: xarray.DataArray
        Merged tiles.
    """
    # define merge method
    if merge_method == "first":
        def copyto(old_data, new_data, old_nodata, new_nodata):
            dmask = np.logical_and(old_nodata, ~new_nodata)
            old_data[dmask] = new_data[dmask]
    elif merge_method == "last":
        def copyto(old_data, new_data, old_nodata, new_nodata):
            dmask = ~new_nodata
            old_data[dmask] = new_data[dmask]
    elif merge_method == "min":
        def copyto(old_data, new_data, old_nodata, new_nodata):
            dmask = np.logical_and(~old_nodata, ~new_nodata)
            old_data[dmask] = np.minimum(old_data[dmask], new_data[dmask])
            dmask = np.logical_and(old_nodata, ~new_nodata)
            old_data[dmask] = new_data[dmask]
    elif merge_method == "max":
        def copyto(old_data, new_data, old_nodata, new_nodata):
            dmask = np.logical_and(~old_nodata, ~new_nodata)
            old_data[dmask] = np.maximum(old_data[dmask], new_data[dmask])
            dmask = np.logical_and(old_nodata, ~new_nodata)
            old_data[dmask] = new_data[dmask]
    elif merge_method == "new":
        def copyto(old_data, new_data, old_nodata, new_nodata):
            data_overlap = np.logical_and(~old_nodata, ~new_nodata)
            assert not np.any(data_overlap)
            mask = ~new_nodata
            old_data[mask] = new_data[mask]
    elif callable(merge_method):
        copyto = merge_method
    # resolve arguments
    if data_arrays[0].ndim != 2:
        raise ValueError("Mosaic is only implemented for 2D DataArrays.")
    if dst_crs is None and (dst_bounds is not None or dst_res is not None):
        raise ValueError("dst_bounds and dst_res not understood without dst_crs.")
    idx0 = 0
    if mask is not None and (dst_res is None or dst_crs is None):
        # select array with largest overlap of valid cells if mask
        areas = []
        for da in data_arrays:
            da_clipped = da.raster.clip_geom(geom=mask, mask=True)
            n = da_clipped.raster.mask_nodata().notnull().sum().load().item()
            areas.append(n)
        idx0 = np.argmax(areas)
        # if single layer with overlap of valid cells: return clip
        if np.sum(np.array(areas) > 0) == 1:
            return data_arrays[idx0].raster.clip_geom(mask)
    da0 = data_arrays[idx0]  # used for default dst_crs and dst_res
    # dst CRS
    if dst_crs is None:
        dst_crs = da0.raster.crs
    else:
        dst_crs = CRS.from_user_input(dst_crs)
    # dst res
    if dst_res is None:
        dst_res = da0.raster.res
    if isinstance(dst_res, float):
        x_res, y_res = dst_res, -dst_res
    elif isinstance(dst_res, (tuple, list)) and len(dst_res) == 2:
        x_res, y_res = dst_res
    else:
        raise ValueError("dst_res not understood.")
    assert x_res > 0
    assert y_res < 0
    dst_res = (x_res, -y_res)  # NOTE: y_res is multiplied with -1 in rasterio!
    # dst bounds
    if dst_bounds is None and mask is not None:
        w, s, e, n = mask.to_crs(dst_crs).buffer(2 * abs(x_res)).total_bounds
        # align with da0 grid
        w0, s0, e0, n0 = da0.raster.bounds
        w = w0 + np.round((w - w0) / abs(x_res)) * abs(x_res)
        n = n0 + np.round((n - n0) / abs(y_res)) * abs(y_res)
        e = w + int(round((e - w) / abs(x_res))) * abs(x_res)
        s = n - int(round((n - s) / abs(y_res))) * abs(y_res)
        align = False  # don't align with resolution
    elif dst_bounds is None:
        for i, da in enumerate(data_arrays):
            if i == 0:
                w, s, e, n = da.raster.transform_bounds(dst_crs)
            else:
                w1, s1, e1, n1 = da.raster.transform_bounds(dst_crs)
                w = min(w, w1)
                s = min(s, s1)
                e = max(e, e1)
                n = max(n, n1)
    else:
        w, s, e, n = dst_bounds
    # check N>S orientation
    top, bottom = (n, s) if y_res < 0 else (s, n)
    dst_bounds = w, bottom, e, top
    # dst transform
    width = int(round((e - w) / abs(x_res)))
    height = int(round((n - s) / abs(y_res)))
    transform = rasterio.transform.from_bounds(*dst_bounds, width, height)
    # align with dst_res
    if align:
        transform, width, height = rasterio.warp.aligned_target(
            transform, width, height, dst_res
        )
    # creat output array
    nodata = da0.raster.nodata
    isnan = np.isnan(nodata)
    dtype = da0.dtype
    da_out = full_from_transform(
        transform=transform,
        shape=(height, width),
        nodata=nodata,
        dtype=dtype,
        name=da0.name,
        attrs=da0.attrs,
        crs=dst_crs,
    )
    ys = da_out.raster.ycoords.values
    xs = da_out.raster.xcoords.values
    dest = da_out.values
    kwargs.update(dst_crs=dst_crs, dst_res=dst_res, align=True)
    sf_lst = []
    for _, da in enumerate(data_arrays):
        if not da.raster.aligned_grid(da_out):
            # clip with buffer
            src_bbox = da_out.raster.transform_bounds(da.raster.crs)
            da = da.raster.clip_bbox(src_bbox, buffer=10)
            if np.any([da[dim].size == 0 for dim in da.raster.dims]):
                continue  # out of bounds
            # reproject
            da = da.raster.reproject(**kwargs)
        # clip to dst_bounds
        da = da.raster.clip_bbox(dst_bounds)
        if np.any([da[dim].size == 0 for dim in da.raster.dims]):
            continue  # out of bounds
        if "source_file" in da.attrs:
            sf_lst.append(da.attrs["source_file"])
        # merge overlap
        w0, s0, e0, n0 = da.raster.bounds
        if y_res < 0:
            top = np.nonzero(ys <= n0)[0][0] if n0 < ys[0] else 0
            bottom = np.nonzero(ys < s0)[0][0] if s0 > ys[-1] else None
        else:
            top = np.nonzero(ys > n0)[0][0] if n0 < ys[-1] else 0
            bottom = np.nonzero(ys >= s0)[0][0] if s0 > ys[0] else None
        left = np.nonzero(xs >= w0)[0][0] if w0 > xs[0] else 0
        right = np.nonzero(xs > e0)[0][0] if e0 < xs[-1] else None
        y_slice = slice(top, bottom)
        x_slice = slice(left, right)
        region = dest[y_slice, x_slice]
        temp = np.ma.masked_equal(da.values.astype(dtype), nodata)
        if isnan:
            region_nodata = np.isnan(region)
            temp_nodata = np.isnan(temp)
        else:
            region_nodata = region == nodata
            temp_nodata = temp.mask
        copyto(region, temp, region_nodata, temp_nodata)
    # set attrs
    da_out.attrs.update(source_file="; ".join(sf_lst))
    da_out.raster.set_crs(dst_crs)
    return da_out.raster.reset_spatial_dims_attrs()