"""Vulnerability workflows."""
import logging
from typing import Any
import numpy as np
import numpy.typing as npt
import pandas as pd
from barril.units import Scalar
from hydromt_fiat.utils import (
    CURVE,
    CURVE_ID,
    EXPOSURE_LINK,
    EXPOSURE_TYPE,
    create_query,
    standard_unit,
)
__all__ = [
    "process_vulnerability_linking",
    "vulnerability_curves",
]
# Named under the "hydromt" namespace so records propagate to the "hydromt"
# logger through the standard logging hierarchy.
logger = logging.getLogger("hydromt." + __name__)
def process_vulnerability_linking(
    types: list[str] | tuple[str] | npt.NDArray[np.str_],
    vulnerability_linking: pd.DataFrame | None = None,
) -> pd.DataFrame:
    """Process the vulnerability linking table.

    Is created based on the vulnerability data if no initial table is provided.
    A provided table is copied first, so the caller's DataFrame is never
    modified.

    Parameters
    ----------
    types : list | tuple | np.ndarray,
        Types of vulnerability curves.
    vulnerability_linking : pd.DataFrame, optional
        The vulnerability linking table, by default None.

    Returns
    -------
    pd.DataFrame
        Vulnerability linking table without fully duplicated rows, limited to
        the rows whose curve is one of `types`. It carries a curve id column
        (a copy of the curve column) and an exposure type column, which is
        filled with 'damage' when the input table has none.

    Raises
    ------
    KeyError
        If the linking table has no curve column.
    """
    if vulnerability_linking is None:
        # Construct if not provided: every curve type links to itself
        logger.warning("No linking table provided, inferred from vulnerability data")
        vulnerability_linking = pd.DataFrame(
            data={
                EXPOSURE_LINK: types,
                CURVE: types,
            }
        )
    else:
        # Work on a copy; the steps below drop rows and add columns, which
        # must not leak back into the caller's table
        vulnerability_linking = vulnerability_linking.copy()
    # Drop completely duplicate rows
    vulnerability_linking.drop_duplicates(inplace=True)
    if CURVE not in vulnerability_linking:
        raise KeyError("The 'curve' column in not present in the linking table")
    if EXPOSURE_TYPE not in vulnerability_linking:  # default to damage
        vulnerability_linking[EXPOSURE_TYPE] = "damage"
    vulnerability_linking[CURVE_ID] = vulnerability_linking[CURVE]
    # Select with a boolean mask instead of a string-built query: the repr of
    # the types (e.g. np.str_ under NumPy 2, or names containing quotes) is not
    # guaranteed to form a valid query expression
    keep = vulnerability_linking[CURVE].isin(list(types))
    return vulnerability_linking.loc[keep]
def vulnerability_curves(
    vulnerability_data: pd.DataFrame,
    vulnerability_linking: pd.DataFrame | None = None,
    *,
    unit: str = "m",
    index_name: str = "water depth",
    column_oriented: bool = True,
    **select: Any,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Create vulnerability curves from raw data.

    Warning
    -------
    If no exposure type column is present in the vulnerability linking table,
    the exposure type defaults to 'damage'.

    Parameters
    ----------
    vulnerability_data : pd.DataFrame
        The raw vulnerability dataset.
    vulnerability_linking : pd.DataFrame, optional
        The vulnerability linking table, by default None.
    unit : str, optional
        The unit of the vulnerability dataset index, by default "m".
    index_name : str, optional
        The name of the outgoing vulnerability curves dataset index,
        by default "water depth".
    column_oriented : bool, optional
        Whether the vulnerability data is column oriented, i.e. the values of a curve
        are in the same column spanning multiple rows. If False, the values are ought
        to be in the same row spanning multiple columns. By default True.
        When True, the first row after transposing is used as column labels.
    **select : dict, optional
        Keyword arguments to select data from 'vulnerability_data'. They are
        turned into a query with `create_query`, and the selection columns
        are dropped from the resulting curves.

    Returns
    -------
    tuple[pd.DataFrame, pd.DataFrame]
        A tuple containing the vulnerability curves and the updated linking
        table. The curves table holds the `index_name` column (scaled by the
        factor from `standard_unit(Scalar(1.0, unit))`) followed by one float
        column per curve id.

    Raises
    ------
    KeyError
        If the 'curve' column is not present in the vulnerability data or in
        the linking table.
    """
    # Transpose the data if columns oriented; the first row then holds labels
    if column_oriented:
        vulnerability_data = vulnerability_data.transpose()
        vulnerability_data.columns = vulnerability_data.iloc[0]
        # Drop the label row by position so any index labels are supported
        vulnerability_data = vulnerability_data.iloc[1:]
    # Quick check on the data
    if CURVE not in vulnerability_data:
        raise KeyError("The 'curve' column in not present in the vulnerability data")
    # Build a query from the index kwargs
    if select:
        query = create_query(**select)
        vulnerability_data = vulnerability_data.query(query)
    # Sort the linking table
    vulnerability_linking = process_vulnerability_linking(
        types=vulnerability_data[CURVE].values,
        vulnerability_linking=vulnerability_linking,
    )
    # Attach the curve id's for merging (one linking row per curve id)
    vulnerability_data = pd.merge(
        vulnerability_data,
        vulnerability_linking.drop_duplicates(subset=CURVE_ID),
        on=CURVE,
        how="inner",
        validate="many_to_many",
    )
    # Reshape the vulnerability data: keep only the curve id and the values
    columns = list(dict.fromkeys([*select, *vulnerability_linking.columns]))
    columns.remove(CURVE_ID)
    vulnerability_data = vulnerability_data.drop(columns, axis=1)
    vulnerability_data = vulnerability_data.transpose()
    vulnerability_data = vulnerability_data.rename(
        columns=vulnerability_data.loc[CURVE_ID]
    )
    vulnerability_data = vulnerability_data.drop(CURVE_ID)
    vulnerability_data.index.name = index_name
    # Again filter the linking table based on the vulnerability curves,
    # but this time on the curve ID (boolean mask, no string-built query)
    curve_ids = vulnerability_data.columns
    vulnerability_linking = vulnerability_linking.loc[
        vulnerability_linking[CURVE_ID].isin(curve_ids)
    ]
    # At last reset the index
    vulnerability_data.reset_index(inplace=True)
    vulnerability_data = vulnerability_data.astype(float)
    # Scale the data according to the unit
    conversion = standard_unit(Scalar(1.0, unit))
    vulnerability_data[index_name] *= conversion.value
    return vulnerability_data, vulnerability_linking