Source code for dpyverification.datamodel.main

"""Module with the dpyverification internal DataModel."""

from collections.abc import Iterable

import xarray as xr
from pydantic import ValidationError

from dpyverification.configuration.utils import VerificationPair
from dpyverification.constants import (
    FORECAST_DATA_TYPES,
    HISTORICAL_DATA_TYPES,
    DataType,
    StandardDim,
)
from dpyverification.datasources.inputschemas import INPUT_SCHEMAS

__all__ = ["InputDataset", "OutputDataset"]


@xr.register_dataarray_accessor("verification")  # type:ignore[no-untyped-call, misc]
class InputDataArrayExtension:
    """xr.DataArray accessor representing a specific data type.

    xr.register_dataarray_accessor is the recommended way to extend xr.DataArray.
    see: https://docs.xarray.dev/en/stable/internals/extending-xarray.html. It's used here to
    extend the input data arrays so we can directly access properties (like data type and
    source) and the validation method that checks the input array against a schema.
    """

    def __init__(self, xarray_obj: xr.DataArray) -> None:
        self._obj = xarray_obj

    @property
    def data_type(self) -> DataType:
        """The data type of the array.

        Raises:
            ValueError: If no ``data_type`` attribute is set on the array.
        """
        # NOTE: annotation fixed from ``str`` to ``DataType`` — the method constructs
        # and returns a DataType member, not a plain string.
        if "data_type" not in self._obj.attrs:  # type:ignore[misc]
            msg = f"No data type set on {self._obj} attrs."
            raise ValueError(msg)
        return DataType(self._obj.attrs["data_type"])  # type:ignore[misc]

    @property
    def is_thresholds(self) -> bool:
        """Boolean indicating this array is a thresholds array."""
        return self.data_type == DataType.threshold

    @property
    def is_historical(self) -> bool:
        """Boolean indicating this array is a historical."""
        return self.data_type in HISTORICAL_DATA_TYPES

    @property
    def is_forecast(self) -> bool:
        """Boolean indicating this array is a forecast."""
        return self.data_type in FORECAST_DATA_TYPES

    @property
    def source(self) -> str:
        """The source name."""
        return str(self._obj.name)

    def validate(self) -> None:
        """Validate the data according to schema.

        Raises:
            ValueError: If the array does not conform to the schema registered
                for its data type (chained from the pydantic ValidationError).
        """
        schema = INPUT_SCHEMAS[self.data_type]  # type:ignore[index] # str is compatible with StrEnum index

        try:
            schema.model_validate(self._obj.to_dict(data=False))  # type:ignore[misc]
        except ValidationError as exc:
            # BUG FIX: a trailing comma previously turned ``msg`` into a 1-tuple,
            # so the raised ValueError carried a tuple instead of the message string.
            msg = f"Validation failed for data_type '{self.data_type}'.\n{exc}"
            raise ValueError(msg) from exc


class InputDataset:
    """Class containing simulations and observations.

    InputDataset has functionality to retrieve verification pairs for
    computation of scores per pair. It is the central object used in the
    verification pipeline.
    """

    def __init__(self, data: Iterable[xr.DataArray]) -> None:
        """Initialize the InputDataset.

        Each input array is validated against its schema, then stored in a
        dictionary keyed by source name and valued by the xr.DataArray.
        """
        self.datastore: dict[str, xr.DataArray] = {}

        for array in data:
            array.verification.validate()  # type:ignore[misc]
            self.datastore[array.verification.source] = array  # type:ignore[misc]

    @staticmethod
    def map_historical_into_forecast_space(
        obs: xr.DataArray,
        sim: xr.DataArray,
    ) -> xr.DataArray:
        """Transform array of historical data into forecast structure.

        Given an observation array with dimension 'time' and a simulation array
        with dimensions 'forecast_reference_time' and 'forecast_period', project
        the observed values onto the simulation array.

        This method is called at runtime when the pipeline starts a score
        computation on forecast data. On the fly, the observation array is
        mapped to the forecast structure, so data are aligned along the same
        dimensions.
        """
        # Stack the simulation's two forecast time axes into one 'z' axis.
        stacked = sim[StandardDim.time].stack(  # type:ignore[misc]
            z=(StandardDim.forecast_reference_time, StandardDim.forecast_period),
        )

        # Re-index the observations onto the stacked forecast times.
        aligned = obs.reindex(time=stacked.to_numpy())  # type:ignore[misc]

        # The MultiIndex of 'z' carries the forecast coordinates; attach them
        # explicitly to the aligned observations so each time step knows its
        # forecast_reference_time and forecast_period.
        multi_index = stacked.indexes["z"]  # type:ignore[misc]
        aligned = aligned.assign_coords(
            forecast_reference_time=(  # type:ignore[misc]
                StandardDim.time,
                multi_index.get_level_values(StandardDim.forecast_reference_time),  # type:ignore[misc]
            ),
            forecast_period=(  # type:ignore[misc]
                StandardDim.time,
                multi_index.get_level_values(StandardDim.forecast_period),  # type:ignore[misc]
            ),
        )

        # Make 'time' a MultiIndex of (forecast_reference_time, forecast_period)
        # and unstack it to obtain the two-dimensional forecast layout.
        indexed = aligned.set_index(
            time=(StandardDim.forecast_reference_time, StandardDim.forecast_period),
        )
        projected = indexed.unstack(StandardDim.time)

        # Preserve the original observation attrs on the projected array.
        projected.attrs = obs.attrs
        return projected

    def get_pair(
        self,
        verification_pair: VerificationPair,
    ) -> tuple[xr.DataArray, xr.DataArray]:
        """Return observations and simulations for a given verification pair.

        This method is called by the verification pipeline at runtime to
        retrieve the correct data for one of the configured verification pairs.
        """
        obs = self.datastore[verification_pair.obs]
        sim = self.datastore[verification_pair.sim]

        # Historical simulations (and observations) are already aligned along
        # 'time'; only forecasts require mapping into forecast space.
        if not sim.verification.is_forecast:  # type:ignore[misc]
            return obs, sim
        return self.map_historical_into_forecast_space(obs, sim), sim

    def get_thresholds_array(self) -> xr.DataArray:
        """Get the thresholds array from the input dataset.

        Raises:
            ValueError: If no thresholds array is present in the datastore.
        """
        thresholds = next(
            (
                array
                for array in self.datastore.values()
                if array.verification.is_thresholds  # type:ignore[misc]
            ),
            None,
        )
        if thresholds is None:
            msg = (
                "No thresholds array found in the input dataset, but required for computing "
                "categorical scores."
            )
            raise ValueError(msg)
        return thresholds
class OutputDataset:
    """The internal output dataset.

    Contains input data, results from verification scores and metadata.
    """

    def __init__(self, input_dataset: InputDataset) -> None:
        self.input_dataset = input_dataset
        # Internal datastore mapping each VerificationPair's pair_id to an
        # xr.Dataset that accumulates all results from the various scores
        # computed for that pair.
        self.datastore: dict[str, xr.Dataset] = {}

    def add_score(self, score: xr.DataArray | xr.Dataset, verification_pair_id: str) -> None:
        """Add a score results to the datastore."""
        # Normalize to xr.Dataset so all stored values share one type.
        if isinstance(score, xr.DataArray):  # type:ignore[misc]
            score = score.to_dataset()

        if verification_pair_id in self.datastore:
            # This pair already has results: merge the new score into them.
            self.datastore[verification_pair_id] = xr.merge(
                [self.datastore[verification_pair_id], score],  # type:ignore[list-item, assignment]
            )
        else:
            # First result for this pair: store it directly.
            self.datastore[verification_pair_id] = score

    def get_output_dataset(
        self,
        verification_pair: VerificationPair,
        *,
        include_input_data: bool = True,
    ) -> xr.Dataset:
        """Get the output dataset for a given verification pair."""
        results = self.datastore.get(verification_pair.id)

        # Results exist and the caller wants them without the input data.
        if results is not None and not include_input_data:
            return results

        obs, sim = self.input_dataset.get_pair(verification_pair)

        # No results in the datastore: return only the input dataset.
        if results is None:
            return xr.merge([obs, sim])  # type:ignore[list-item, return-value]

        # Results plus input data.
        return xr.merge([obs, sim, results])  # type:ignore[list-item, return-value]