# Source code for dpyverification.configuration.default.scores

"""A module for default implementation of scores."""

from enum import Enum
from typing import Annotated, Literal

from pydantic import BaseModel, Field, RootModel

from dpyverification.configuration.base import (
    BaseCategoricalScoreConfig,
    BaseEvent,
    BaseScoreConfig,
)
from dpyverification.constants import (
    ScoreKind,
    StandardDim,
    SupportedCategoricalScores,
    SupportedContinuousScore,
)


class EventOperator(Enum):
    """Comparison operators that can be selected in an event definition.

    Each member's value is the lowercase string form of its name; that string
    is what ``EventOperator(<value>)`` accepts when looking a member up.
    """

    # Strict comparisons.
    GREATER_THAN = "greater_than"
    LESS_THAN = "less_than"

    # Inclusive comparisons.
    GREATER_THAN_OR_EQUAL_TO = "greater_than_or_equal_to"
    LESS_THAN_OR_EQUAL_TO = "less_than_or_equal_to"


class ReduceDimsForecast(BaseModel):
    """Mixin model declaring which forecast dimensions a score reduces over.

    ``reduce_dims`` may only contain the station, forecast reference time and
    forecast period dimensions, and defaults to an empty list (nothing is
    reduced).
    """

    # Only these three dimensions are accepted; ``StandardDim.variable`` is not
    # part of the Literal, so a validated config cannot list it here.
    reduce_dims: Annotated[
        list[
            Literal[
                StandardDim.station,
                StandardDim.forecast_reference_time,
                StandardDim.forecast_period,
            ]
        ],
        Field(default_factory=list),
    ]

    @property
    def preserve_dims(self) -> list[StandardDim]:
        """Return the dimensions that are not listed in ``reduce_dims``.

        The result keeps a fixed order: variable, station, forecast reference
        time, forecast period.
        """
        candidates = (
            StandardDim.variable,
            StandardDim.station,
            StandardDim.forecast_reference_time,
            StandardDim.forecast_period,
        )
        kept: list[StandardDim] = []
        for dim in candidates:
            if dim in self.reduce_dims:
                # Requested for reduction, so it is not preserved.
                continue
            kept.append(dim)
        return kept


class IdMap(RootModel[dict[str, dict[str, str]]]):
    """Mapping from internal IDs to external IDs per data source.

    The root value is a nested dict: the outer key is the internal ID and the
    inner dict maps a data source name to the external ID used by that source.
    """

    def get_external_to_internal_mapping(self, data_source: str) -> dict[str, str]:
        """Return external → internal mapping for this data source.

        Args:
            data_source: Key to look up in every inner dict.

        Returns:
            A dict keyed by the external ID for ``data_source``, with the
            internal ID as value. If two internal IDs share an external ID,
            the one that comes later in the root dict wins.

        Raises:
            KeyError: If any internal ID has no entry for ``data_source``.
        """
        inverted: dict[str, str] = {}
        for internal_id, external_by_source in self.root.items():
            external_id = external_by_source[data_source]
            inverted[external_id] = internal_id
        return inverted


class RankHistogramConfig(BaseScoreConfig, ReduceDimsForecast):
    """A rank histogram config element.

    Combines the base score configuration with the ``reduce_dims`` option
    inherited from ``ReduceDimsForecast``.
    """

    # Must be exactly ScoreKind.rank_histogram for this config to validate.
    score_adapter: Literal[ScoreKind.rank_histogram]


class CrpsForEnsembleConfig(BaseScoreConfig, ReduceDimsForecast):
    """Configuration for CRPS for ensemble.

    For reference, see:
    https://scores.readthedocs.io/en/stable/api.html#scores.probability.crps_for_ensemble
    """

    # Must be exactly ScoreKind.crps_for_ensemble for this config to validate.
    score_adapter: Literal[ScoreKind.crps_for_ensemble]
    # Default is given in assignment form, matching the other defaulted fields
    # in this module and keeping it visible to static type checkers.
    method: Annotated[
        Literal["ecdf", "fair"],
        Field(
            description="Method to compute the cumulative distribution function from an ensemble.",
        ),
    ] = "ecdf"


class CrpsCDFConfig(BaseScoreConfig, ReduceDimsForecast):
    """Configuration for CRPS for CDF.

    For reference, see:
    https://scores.readthedocs.io/en/stable/api.html#scores.probability.crps_cdf
    """

    # Must be exactly ScoreKind.crps_cdf for this config to validate.
    score_adapter: Literal[ScoreKind.crps_cdf]
    # Either "exact" or "trapz"; defaults to "exact".
    integration_method: Annotated[
        Literal["exact", "trapz"],
        Field(
            description="The method of integration. 'exact' computes the exact integral, "
            "'trapz' uses a trapezoidal rule and is an approximation of the CRPS.",
        ),
    ] = "exact"


class ContinuousScoresConfig(BaseScoreConfig, ReduceDimsForecast):
    """Configure multiple continuous scores.

    Combines the base score configuration with the ``reduce_dims`` option
    inherited from ``ReduceDimsForecast``.
    """

    # Must be exactly ScoreKind.continuous_scores for this config to validate.
    score_adapter: Literal[ScoreKind.continuous_scores]
    # Required list of continuous scores to compute; no default is declared.
    scores: list[SupportedContinuousScore]


class ThresholdEvent(BaseEvent):
    """An event definition for a threshold.

    An event pairs a threshold id with the operator used to compare data
    against that threshold. Both fields are required.
    """

    threshold: Annotated[
        str,
        Field(description="Threshold id to use in event definition."),
    ]
    operator: Annotated[
        EventOperator,
        Field(description="The operator to use for creating the events."),
    ]


class CategoricalScoresConfig(BaseCategoricalScoreConfig, ReduceDimsForecast):
    """Config to compute categorical scores, based on an event definition.

    ``scores`` and ``events`` are required; ``return_contingency_table``
    defaults to ``True``.
    """

    # Must be exactly ScoreKind.categorical_scores for this config to validate.
    score_adapter: Literal[ScoreKind.categorical_scores]
    scores: Annotated[
        list[SupportedCategoricalScores],
        Field(
            description="For reference, see: https://scores.readthedocs.io/en/stable/api.html#module-scores.categorical.",
        ),
    ]
    events: Annotated[
        list[ThresholdEvent],
        Field(
            description="A list of threshold event definitions. For each event, a categorical "
            "score will be computed. A threshold event is defined by a threshold and an operator. "
            "The threshold is a string that corresponds to a threshold id defined in the "
            "configuration. The operator defines how the threshold is applied to the data to "
            "create the event. For example, if the threshold is '10' and the operator is "
            "'greater_than', the event will be created by applying the operator to the data "
            "and the threshold, i.e. data > 10. ",
        ),
    ]
    return_contingency_table: Annotated[
        bool,
        Field(description="Whether to return the contingency table in the output."),
    ] = True