"""
Categorical scores, based on a 2x2 contingency table.
For verification of non-probabilistic forecasts, and historical simulations of discrete variables.
For references, see: https://scores.readthedocs.io/en/stable/included.html#categorical.
"""
import operator
from collections.abc import Callable
from enum import StrEnum
from typing import ClassVar
import xarray as xr
from scores.categorical import ( # type:ignore[import-untyped]
BasicContingencyManager,
BinaryContingencyManager,
)
from dpyverification.configuration.default.scores import (
BaseEvent,
CategoricalScoresConfig,
EventOperator,
ThresholdEvent,
)
from dpyverification.constants import DataType, SupportedCategoricalScores
from dpyverification.scores.base import BaseCategoricalScore
# Public names of this module. CategoricalScoresConfig is imported from
# dpyverification.configuration.default.scores and re-exported here.
__all__ = [
    "CategoricalScores",
    "CategoricalScoresConfig",
    "create_binary_array",
]
def get_categorical_score(score_name: SupportedCategoricalScores) -> type:
    """Look up a categorical score on the contingency manager of the scores package.

    The enum value of ``score_name`` is used as the attribute name on
    ``BasicContingencyManager``; the attribute is returned without being called.

    NOTE(review): the return annotation is ``type``, yet the result is invoked with a
    contingency manager instance elsewhere in this module, so it is presumably an
    unbound method -- confirm before tightening the annotation.

    Raises:
        AttributeError: If ``BasicContingencyManager`` has no attribute of that name.
    """
    attribute_name = score_name.value
    score_method = getattr(BasicContingencyManager, attribute_name)  # type:ignore[misc]
    return score_method  # type:ignore[no-any-return, misc]
def get_event_operator(
operator_name: EventOperator,
) -> Callable[[xr.DataArray, xr.DataArray], xr.DataArray]:
"""Get an event operator function based on the operator name."""
if operator_name == EventOperator.GREATER_THAN:
return operator.gt # type:ignore[misc]
if operator_name == EventOperator.LESS_THAN:
return operator.lt # type:ignore[misc]
if operator_name == EventOperator.GREATER_THAN_OR_EQUAL_TO:
return operator.ge # type:ignore[misc]
if operator_name == EventOperator.LESS_THAN_OR_EQUAL_TO:
return operator.le # type:ignore[misc]
msg = f"Unsupported operator: {operator_name}" # type:ignore[unreachable] # runtime check
raise ValueError(msg)
class CategoricalScoreDim(StrEnum):
"""Names of dimensions added when computing a categorical score."""
EVENT_THRESHOLD = "event_threshold"
EVENT_OPERATOR = "event_operator"
# [docs]  (Sphinx source-link residue; commented out so it is not evaluated as code)
def create_binary_array(
    data: xr.DataArray,
    thresholds: xr.DataArray,
    operator: Callable[[xr.DataArray, xr.DataArray], xr.DataArray],
) -> xr.DataArray:
    """Turn data into binary events by comparing it against thresholds.

    Args:
        data: Values to classify.
        thresholds: Threshold values the data is compared with.
        operator: Comparison applied as ``operator(data, thresholds)``.

    Returns:
        The result of the comparison on the aligned inputs.

    Raises:
        ValueError: If the comparison does not return an ``xr.DataArray``.
    """
    # Inner-join alignment: only index labels present in both inputs are kept.
    aligned_data, aligned_thresholds = xr.align(data, thresholds, join="inner")
    events = operator(aligned_data, aligned_thresholds)
    if not isinstance(events, xr.DataArray):  # type:ignore[misc]
        msg = "Failed to create a binary xr.DataArray based on data and thresholds."  # type:ignore[unreachable] # runtime check
        raise ValueError(msg)
    return events
def set_event_coordinates_on_result(
    data_array: xr.Dataset,
    threshold: str,
    operator: EventOperator,
) -> xr.Dataset:
    """Label a score result with the event it was computed for.

    Two length-one dimensions, named after the members of ``CategoricalScoreDim``,
    are added and given the event threshold and the operator's enum member name as
    their single coordinate value.

    Args:
        data_array: Score result to label (annotated as a dataset despite its name).
        threshold: Value stored on the event-threshold coordinate.
        operator: Event operator; its ``name`` is stored on the event-operator coordinate.

    Returns:
        The expanded result carrying both event coordinates.
    """
    threshold_dim = CategoricalScoreDim.EVENT_THRESHOLD
    operator_dim = CategoricalScoreDim.EVENT_OPERATOR

    # First create the two size-1 dimensions, then attach their labels.
    expanded = data_array.expand_dims({threshold_dim: 1, operator_dim: 1})
    event_coords = {  # type:ignore[misc]
        operator_dim: [operator.name],  # type:ignore[misc]
        threshold_dim: [threshold],  # type:ignore[misc]
    }
    return expanded.assign_coords(event_coords)  # type:ignore[misc]
# [docs]  (Sphinx source-link residue; commented out so it is not evaluated as code)
class CategoricalScores(BaseCategoricalScore):
    """Categorical scores derived from the 2x2 contingency table.

    For reference: https://scores.readthedocs.io/en/stable/included.html#categorical
    """

    # Identifier of this score implementation.
    kind = "categorical_scores"
    # Configuration model that belongs to this score.
    config_class = CategoricalScoresConfig
    # Data types this score declares support for.
    supported_data_types: ClassVar[set[DataType]] = {
        DataType.simulated_forecast_single,
    }

    def __init__(self, config: CategoricalScoresConfig) -> None:
        """Store the configuration that drives the score computation."""
        self.config: CategoricalScoresConfig = config

    def compute_score_for_single_event(
        self,
        obs: xr.DataArray,
        sim: xr.DataArray,
        thresholds: xr.DataArray,
        event: BaseEvent,
    ) -> xr.Dataset | xr.DataArray:
        """Compute every configured categorical score for one event.

        Observations and simulations are converted to binary events with the event's
        operator and the given thresholds, combined into a contingency table, and each
        score listed in the configuration is evaluated on that table.

        Args:
            obs: Observed values.
            sim: Simulated values.
            thresholds: Threshold values both inputs are compared with.
            event: Event definition; must be a ``ThresholdEvent``.

        Returns:
            The merged scores (plus the contingency table when
            ``config.return_contingency_table`` is ``True``), labelled with the
            event threshold and operator.

        Raises:
            TypeError: If ``event`` is not a ``ThresholdEvent``.
        """
        if not isinstance(event, ThresholdEvent):
            msg = f"Unsupported event type: {type(event)}. Expected a ThresholdEvent."
            raise TypeError(msg)

        compare = get_event_operator(event.operator)

        # Both series are binarised identically; observations first, then simulations.
        obs_events, sim_events = [
            create_binary_array(values, thresholds=thresholds, operator=compare)
            for values in (obs, sim)
        ]

        contingency = BinaryContingencyManager(  # type:ignore[misc]
            fcst_events=sim_events,
            obs_events=obs_events,
        ).transform(  # type:ignore[misc]
            preserve_dims=self.config.preserve_dims,
        )

        results = []
        for score in self.config.scores:
            score_array = get_categorical_score(score)(contingency)  # type:ignore[misc]
            # Name each array after its score so it becomes a distinct variable on merge.
            score_array.name = str(score.value)  # type:ignore[misc]
            results.append(score_array)  # type:ignore[misc]

        if self.config.return_contingency_table is True:
            table: xr.DataArray = contingency.get_table()  # type:ignore[misc]
            table.name = "contingency_table"
            results.append(table)  # type:ignore[misc]

        merged: xr.Dataset = xr.merge(results)  # type:ignore[misc, assignment]
        return set_event_coordinates_on_result(
            merged,
            threshold=event.threshold,
            operator=event.operator,
        )