Source code for dpyverification.datasources.base

"""Module with the base class that all datasources should inherit from."""

import hashlib
from abc import abstractmethod
from os import R_OK, access
from typing import ClassVar, Self

import xarray
import xarray as xr

from dpyverification.base import Base
from dpyverification.configuration.base import (
    BaseDatasourceConfig,
)
from dpyverification.configuration.utils import TimePeriod
from dpyverification.constants import FORECAST_DATA_TYPES, DataType, StandardDim

# Public API of this module. BaseDatasourceConfig is imported from
# dpyverification.configuration.base and is listed here alongside the base
# class so that both names are exported by `from ... import *`.
__all__ = [
    "BaseDatasource",
    "BaseDatasourceConfig",
]


class BaseDatasource(Base):
    """Class to inherit from, defines the required methods and attributes.

    Subclasses implement :meth:`fetch_data` and declare which data types they
    support through ``supported_data_types``. Callers normally use
    :meth:`get_data`, which wraps :meth:`fetch_data` with an on-disk cache,
    validation of the ``data_type`` attribute and filtering on the configured
    verification period.
    """

    # Identifier of the datasource implementation; empty on the base class.
    # NOTE(review): presumably overridden by each subclass - confirm.
    kind: str = ""
    # Configuration class that belongs to this datasource.
    config_class: type[BaseDatasourceConfig] = BaseDatasourceConfig
    # Data types accepted by the ``data_type`` setter. It is empty here, so the
    # base class itself rejects every data type; subclasses must override it.
    supported_data_types: ClassVar[set[DataType]] = set()

    def __init__(self, config: BaseDatasourceConfig) -> None:
        """Store the config and validate its data type.

        Parameters
        ----------
        config
            Configuration of this datasource.

        Raises
        ------
        NotImplementedError
            If ``config.data_type`` is not in ``supported_data_types``.
        """
        self.config: BaseDatasourceConfig = config
        # Assigning through the property setter validates the data type.
        self.data_type = config.data_type
        # Empty placeholder until fetch_data() / get_data() fills it.
        self.data_array = xarray.DataArray()

    @property
    def data_type(self) -> str:
        """Whether the instance represents sim or obs data."""
        # NOTE(review): the getter reads from the config while the setter
        # stores ``_data_type``; the two only agree as long as the setter is
        # called with ``config.data_type`` (as __init__ does).
        return self.config.data_type

    @data_type.setter
    def data_type(self, new_data_type: DataType) -> None:
        """Validate and store the data type.

        Raises
        ------
        NotImplementedError
            If ``new_data_type`` is not in ``supported_data_types``.
        """
        if new_data_type not in self.supported_data_types:
            # Implicit string concatenation (no commas): the message must be
            # one string, not a tuple of fragments.
            msg = (
                f"Data type '{new_data_type}' is not supported "
                f"by {self.__class__.__name__}"
            )
            raise NotImplementedError(msg)
        self._data_type = new_data_type

    @abstractmethod
    def fetch_data(self) -> Self:
        """Fetch data from datasource.

        :meth:`get_data` reads ``self.data_array`` after calling this method,
        so implementations are expected to assign the fetched data there.
        """

    @staticmethod
    def _drop_times_outside_vp(
        da: xr.DataArray,
        verification_period_on_time: TimePeriod,
    ) -> xr.DataArray:
        """Mask times outside of verification period with inclusive endpoints.

        Parameters
        ----------
        da
            Data array with a ``StandardDim.time`` coordinate and the
            ``StandardDim.forecast_reference_time`` and
            ``StandardDim.forecast_period`` dimensions.
        verification_period_on_time
            Period whose ``start_datetime64`` and ``end_datetime64`` bound the
            values that are kept.

        Returns
        -------
        xr.DataArray
            The masked array, with forecast reference times and forecast
            periods dropped where all values are NaN.
        """
        time_coord = da[StandardDim.time]
        # Mask values outside of verification period
        filtered = da.where(
            (time_coord >= verification_period_on_time.start_datetime64)
            & (time_coord <= verification_period_on_time.end_datetime64),
        )
        # Drop NaN values along frt and fp dims, if all values are NaN
        return filtered.dropna(dim=StandardDim.forecast_reference_time, how="all").dropna(
            dim=StandardDim.forecast_period,
            how="all",
        )

    def get_data(self) -> Self:
        """Get cached data, or fetch and cache.

        The cache file lives in ``config.general.cache_dir`` and is named
        ``<ClassName>_<sha256 of the config JSON>.nc``. When that file exists
        it is opened and used as-is. Otherwise the data is fetched, validated,
        renamed and filtered according to the config, written to the cache
        file and re-opened from there.

        Returns
        -------
        Self
            This instance, with ``data_array`` set.

        Raises
        ------
        NotADirectoryError
            If the cache path exists but is not a readable directory.
        ValueError
            If the fetched data array has no ``data_type`` attribute, or its
            value differs from the configured data type.
        """
        config_json = self.config.model_dump_json().encode("utf-8")
        config_hash = hashlib.sha256(config_json).hexdigest()
        cache_dir = self.config.general.cache_dir

        # Create cache if not exists
        if not cache_dir.exists():
            # exist_ok guards against another process creating the directory
            # between the existence check and this call.
            cache_dir.mkdir(parents=True, exist_ok=True)
        # If it exists, it must be a directory AND be readable
        elif not (cache_dir.is_dir() and access(cache_dir, R_OK)):
            msg = "Cache directory is not an accessible directory."
            raise NotADirectoryError(msg)

        # Define file path for caching
        cached_data_array_path = cache_dir / f"{self.__class__.__name__}_{config_hash}.nc"

        if cached_data_array_path.exists():
            self.data_array = xr.open_dataarray(cached_data_array_path)
            return self

        # Go fetch and cache
        self.fetch_data()
        data_array_original = self.data_array

        # Check that the datatype is defined, and consistent with the config
        if "data_type" not in data_array_original.attrs:  # type:ignore[misc]
            msg = "The fetched data array does not have a 'data_type' attribute."
            raise ValueError(msg)
        if data_array_original.attrs["data_type"] != self.config.data_type:  # type:ignore[misc]
            msg = (
                f"The data type of the fetched data array "
                f"({data_array_original.attrs['data_type']}) does not match the configured data "  # type:ignore[misc]
                f"type ({self.config.data_type})."
            )
            raise ValueError(msg)

        # Make sure the name of the array is set to the configured source
        data_array_original.name = self.config.source

        # Apply re-naming based on configured id mapping, if not None
        if self.config.id_mapping is not None:
            data_array_original = self.config.id_mapping.rename_data_array(data_array_original)

        # Additional layer to filter time, frt and fp properly according to config.
        if data_array_original.attrs["data_type"] in FORECAST_DATA_TYPES:  # type:ignore[misc]
            # Select only relevant forecast periods for simulations
            data_array_original = data_array_original.sel(
                forecast_period=self.config.forecast_periods.timedelta64,
            )
            # Mask and drop time values outside of the configured vp
            data_array_original = self._drop_times_outside_vp(
                da=data_array_original,
                verification_period_on_time=self.config.verification_period_on_time,
            )

        if data_array_original.attrs["data_type"] == DataType.observed_historical:  # type:ignore[misc]
            # Historical data type: select the time values inside the
            # configured vp with a slice on the time dimension
            data_array_original = data_array_original.sel(
                {
                    StandardDim.time: slice(  # type:ignore[misc]
                        self.config.verification_period_on_time.start,
                        self.config.verification_period_on_time.end,
                    ),
                },
            )

        # Cache
        data_array_original.to_netcdf(cached_data_array_path)

        # Re-open to read from cache and prevent links to original files from which the dataarray
        # was loaded
        data_array_reloaded = xr.open_dataarray(cached_data_array_path)

        # Explicitly close original backing files
        if hasattr(data_array_original, "close"):
            data_array_original.close()

        # Re-assign from cache
        self.data_array = data_array_reloaded
        return self