Source code for hydromt.models.model_plugins
"""Implementation of the mechanism to access the plugin entrypoints."""
import logging
from typing import Dict, Iterator, List
from hydromt._compat import Distribution, EntryPoint, EntryPoints, entry_points
from .. import __version__, _compat
from .model_api import Model
logger = logging.getLogger(__name__)
__all__ = ["ModelCatalog"]
# Generic model classes that ship with hydromt itself, keyed by the name under
# which they are exposed. Each value is an entrypoint object reference in
# "module:attribute" notation.
LOCAL_EPS = dict(
    grid_model="hydromt.models.model_grid:GridModel",
    vector_model="hydromt.models.model_vector:VectorModel",
    mesh_model="hydromt.models.model_mesh:MeshModel",
    network_model="hydromt.models.model_network:NetworkModel",
)
def get_general_eps() -> Dict:
    """Build entrypoints for the generic model classes listed in ``LOCAL_EPS``.

    The ``mesh_model`` entry is left out when ``_compat.HAS_XUGRID`` is falsy.
    Every entrypoint is registered in the ``hydromt.models`` group and is
    linked to the installed ``hydromt`` distribution.

    Returns
    -------
    eps : dict
        Entrypoints dict, keyed by model name.
    """
    eps = {}
    hydromt_dist = Distribution.from_name("hydromt")
    for model_name in LOCAL_EPS:
        mesh_unavailable = model_name == "mesh_model" and not _compat.HAS_XUGRID
        if mesh_unavailable:
            continue
        entry = EntryPoint(
            name=model_name,
            value=LOCAL_EPS[model_name],
            group="hydromt.models",
        )
        entry._for(hydromt_dist)  # attach the distribution info
        eps[model_name] = entry
    return eps
def _discover() -> EntryPoints:
    """Return the entrypoints registered under the 'hydromt.models' group."""
    model_group = "hydromt.models"
    return entry_points(group=model_group)
def get_plugin_eps(logger=logger) -> Dict:
    """Discover hydromt model plugins based on 'hydromt.models' entrypoints.

    An entrypoint is skipped (with a warning) when its name was already
    discovered or when it clashes with one of the local generic models in
    ``LOCAL_EPS``.

    Parameters
    ----------
    logger : logger object, optional
        The logger object used for logging messages. If not provided, the default
        logger will be used.

    Returns
    -------
    eps : dict
        Entrypoints dict, keyed by entrypoint name.
    """
    eps = {}
    for ep in _discover():
        name = ep.name
        if name in eps or name in LOCAL_EPS:
            # ep.value is already the full "module:attribute" reference, so it
            # identifies the skipped plugin without repeating the module path.
            logger.warning(f"Duplicated model plugin '{name}'; skipping {ep.value}")
            continue
        if ep.dist:
            # report both name and version of the providing distribution
            dist_info = f"{ep.dist.name} {ep.dist.version}"
        else:
            # no distribution attached to the entrypoint: fall back to the
            # hydromt version
            dist_info = __version__
        logger.debug(f"Discovered model plugin '{name} = {ep.value}' ({dist_info})")
        eps[name] = ep
    return eps
def load(ep, logger=logger) -> Model:
    """Load entrypoint and return plugin model class.

    Parameters
    ----------
    ep : entrypoint
        discovered entrypoint
    logger : logger object, optional
        The logger object used for logging messages. If not provided, the default
        logger will be used.

    Returns
    -------
    model_class : Model
        plugin model class (the class itself, not an instance)

    Raises
    ------
    ImportError
        If loading the entrypoint fails with a ModuleNotFoundError or
        AttributeError; the original error is chained as the cause.
    ValueError
        If the loaded object is not a subclass of ``Model``.
    """
    _str = f"{ep.name} = {ep.value}"
    # Keep the try body limited to the call that performs the import.
    try:
        model_class = ep.load()
    except (ModuleNotFoundError, AttributeError) as err:
        raise ImportError(
            f"Error while loading model plugin '{_str}' ({err})"
        ) from err
    # issubclass() raises TypeError for non-class objects (e.g. a function),
    # so check for a class first and report both cases the same way.
    if not (isinstance(model_class, type) and issubclass(model_class, Model)):
        raise ValueError(f"Model plugin type not recognized '{_str}'")
    logger.debug(f"Loaded model plugin {_str}")
    return model_class
[docs]
class ModelCatalog:
    """The model catalogue provides access to plugins and their Model classes.

    Entrypoints are collected lazily from two sources: plugins (via
    ``get_plugin_eps``) and the local generic models (via
    ``get_general_eps``). Each source is queried at most once per catalog
    instance; model classes are loaded on first request and cached.
    """

    def __init__(self):
        """Initiate the catalog object."""
        self._eps = {}  # entrypoints, keyed by model name
        self._cls = {}  # loaded model classes, keyed by model name
        self._plugins = []  # names of plugins
        self._general = []  # names of local model classes
        # Track per source whether it has been queried. An empty name list
        # cannot serve as the marker: it is also the valid result when no
        # plugins are found.
        self._plugins_loaded = False
        self._general_loaded = False

    @property
    def eps(self) -> Dict:
        """Return dictionary with available model entrypoints."""
        # Always go through both properties; each is a no-op once its source
        # has been queried. Testing `self._eps` for emptiness instead would
        # return an incomplete dict when only one of `plugins` / `generic`
        # had been accessed before.
        _ = self.plugins  # discover plugins
        _ = self.generic  # get generic local model classes
        return self._eps

    @property
    def cls(self) -> Dict:
        """Return dictionary with available model classes."""
        for name, ep in self.eps.items():
            if name not in self._cls:
                self._cls[name] = load(ep)
        return self._cls

    @property
    def plugins(self) -> List:
        """Return list with names of model plugins."""
        if not self._plugins_loaded:
            eps = get_plugin_eps()
            self._plugins = list(eps)
            self._eps.update(eps)
            self._plugins_loaded = True
        return self._plugins

    @property
    def generic(self) -> List:
        """Return list with names of generic models."""
        if not self._general_loaded:
            eps = get_general_eps()
            self._general = list(eps)
            self._eps.update(eps)
            self._general_loaded = True
        return self._general

    def load(self, name) -> Model:
        """Return the model class registered under ``name``.

        The class is loaded from its entrypoint on first use and cached.
        Raises ValueError (from ``__getitem__``) for an unknown name.
        """
        if name not in self._cls:
            # `load` here is the module-level loader, not this method
            self._cls[name] = load(self[name])
        return self._cls[name]

    def __str__(self):
        """Generate string representation containing the registered entrypoints."""
        plugin_lines = []
        for name in self.plugins:
            dist = self.eps[name].dist
            if dist:
                plugin_lines.append(f" - {name} ({dist.name} {dist.version})\n")
            else:
                # entrypoint without distribution info: list the name only
                plugin_lines.append(f" - {name}\n")
        plugins = "".join(plugin_lines)
        generic = "".join(f" - {name}\n" for name in self.generic)
        return (
            f"model plugins:\n{plugins}generic models (hydromt {__version__})"
            f":\n{generic}"
        )

    def __getitem__(self, name) -> EntryPoint:
        """Return the entrypoint with the provided name.

        Raises ValueError if no entrypoint is registered under ``name``.
        """
        if name not in self.eps:
            raise ValueError(f"Unknown model {name}; select from {self.eps.keys()}")
        return self._eps[name]

    def __iter__(self) -> Iterator:
        """Return an iterator over the names of the registered entrypoints."""
        return iter(self.eps)