Source code for dpyverification.scores.base
"""An abstract implementation of a calculation."""
from abc import abstractmethod
from typing import ClassVar
import xarray as xr
from dpyverification.base import Base
from dpyverification.configuration.base import (
BaseCategoricalScoreConfig,
BaseEvent,
BaseScoreConfig,
)
from dpyverification.constants import DataType
# Names exported on a star-import of this module.
__all__ = ["BaseScore", "BaseScoreConfig"]
[docs]
class BaseScore(Base):
    """An abstract calculation class.

    Subclasses are expected to set ``kind``, ``config_class`` and
    ``supported_data_types`` and to implement :meth:`compute`.
    :meth:`validate_and_compute` checks the data type of the simulations
    against ``supported_data_types`` before delegating to :meth:`compute`.
    """

    kind = ""  # to be defined by subclasses
    config_class: type[BaseScoreConfig] = BaseScoreConfig  # to be defined by subclasses
    # Data types accepted by the score. Empty here, so every data type is
    # rejected until a subclass overrides it.
    supported_data_types: ClassVar[set[DataType]] = set()

    def __init__(self, config: BaseScoreConfig) -> None:
        """Store the score configuration on the instance.

        Parameters
        ----------
        config
            Configuration of the score; kept as ``self.config``.
        """
        self.config: BaseScoreConfig = config

    @abstractmethod
    def compute(
        self,
        obs: xr.DataArray,
        sim: xr.DataArray,
    ) -> xr.DataArray | xr.Dataset:
        """Abstract calculation.

        Parameters
        ----------
        obs
            The observations.
        sim
            The simulations.

        Returns
        -------
        xr.DataArray | xr.Dataset
            The computed score.
        """

    def validate_and_compute(
        self,
        obs: xr.DataArray,
        sim: xr.DataArray,
    ) -> xr.DataArray | xr.Dataset:
        """Validate and compute.

        Parameters
        ----------
        obs
            The observations, passed unchanged to :meth:`compute`.
        sim
            The simulations. Its data type is read from
            ``sim.verification.data_type`` (presumably a project-registered
            xarray accessor -- confirm against the accessor definition).

        Returns
        -------
        xr.DataArray | xr.Dataset
            The result of :meth:`compute`. An unnamed ``xr.DataArray`` is
            named after ``self.config.score_adapter``.

        Raises
        ------
        ValueError
            If the data type of ``sim`` is not in ``supported_data_types``.
        """
        data_type: DataType = sim.verification.data_type  # type:ignore[misc]
        if data_type not in self.supported_data_types:
            # The parentheses are required: without them only the first
            # fragment is assigned and the rest of the message is discarded.
            msg = (
                f"The data type '{data_type}' is not supported by "
                f"{self.__class__.__name__}. Supported types: "
                f"{sorted(self.supported_data_types)}."
            )
            raise ValueError(msg)
        result = self.compute(obs, sim)
        # Give unnamed arrays a name so they can be identified downstream.
        if isinstance(result, xr.DataArray) and result.name is None:  # type:ignore[misc]
            result.name = self.config.score_adapter
        return result
class BaseCategoricalScore(Base):
    """An abstract calculation class for categorical scores.

    Subclasses are expected to set ``kind``, ``config_class`` and
    ``supported_data_types`` and to implement
    :meth:`compute_score_for_single_event`. :meth:`validate_and_compute`
    calls that method once per event in ``self.config.events`` and combines
    the per-event results.
    """

    kind = ""  # to be defined by subclasses
    config_class: type[BaseCategoricalScoreConfig] = (
        BaseCategoricalScoreConfig  # to be defined by subclasses
    )
    # Data types accepted by the score. Empty here, so every data type is
    # rejected until a subclass overrides it.
    supported_data_types: ClassVar[set[DataType]] = set()

    def __init__(self, config: BaseCategoricalScoreConfig) -> None:
        """Store the score configuration on the instance.

        Parameters
        ----------
        config
            Configuration of the score; kept as ``self.config``.
        """
        self.config: BaseCategoricalScoreConfig = config

    @abstractmethod
    def compute_score_for_single_event(
        self,
        obs: xr.DataArray,
        sim: xr.DataArray,
        thresholds: xr.DataArray,
        event: BaseEvent,
    ) -> xr.DataArray | xr.Dataset:
        """Abstract calculation.

        Parameters
        ----------
        obs
            The observations.
        sim
            The simulations.
        thresholds
            The thresholds.
        event
            The single event to compute the score for.

        Returns
        -------
        xr.DataArray | xr.Dataset
            The score computed for ``event``.
        """

    def validate_and_compute(
        self,
        obs: xr.DataArray,
        sim: xr.DataArray,
        thresholds: xr.DataArray,
    ) -> xr.DataArray | xr.Dataset:
        """Validate and compute.

        Parameters
        ----------
        obs
            The observations, passed unchanged to
            :meth:`compute_score_for_single_event`.
        sim
            The simulations. Its data type is read from
            ``sim.verification.data_type`` (presumably a project-registered
            xarray accessor -- confirm against the accessor definition).
        thresholds
            The thresholds, passed unchanged to
            :meth:`compute_score_for_single_event`.

        Returns
        -------
        xr.DataArray | xr.Dataset
            The per-event results combined with ``xr.combine_by_coords``. An
            unnamed ``xr.DataArray`` is named after
            ``self.config.score_adapter``.

        Raises
        ------
        ValueError
            If the data type of ``sim`` is not in ``supported_data_types``.
        TypeError
            If an entry of ``self.config.events`` is not a ``BaseEvent``.
        """
        data_type: DataType = sim.verification.data_type  # type:ignore[misc]
        if data_type not in self.supported_data_types:
            # The parentheses are required: without them only the first
            # fragment is assigned and the rest of the message is discarded.
            msg = (
                f"The data type '{data_type}' is not supported by "
                f"{self.__class__.__name__}. Supported types: "
                f"{sorted(self.supported_data_types)}."
            )
            raise ValueError(msg)
        results: list[xr.DataArray | xr.Dataset] = []
        for event in self.config.events:
            if not isinstance(event, BaseEvent):
                msg = f"Unsupported event type: {type(event)}. Expected a BaseEvent."  # type:ignore[unreachable] # runtime check
                raise TypeError(msg)
            result_for_a_single_event = self.compute_score_for_single_event(
                obs,
                sim,
                thresholds=thresholds,
                event=event,
            )
            results.append(result_for_a_single_event)
        # Merge the per-event results into one object.
        result = xr.combine_by_coords(results)
        # Give unnamed arrays a name so they can be identified downstream.
        if isinstance(result, xr.DataArray) and result.name is None:  # type:ignore[misc]
            result.name = self.config.score_adapter
        return result