Spaces:
Sleeping
Sleeping
import importlib
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Type, TypeVar
# Module-wide logger; the fixed "toolbox" name (rather than __name__) means
# every module asking for that name shares the same logger object.
logger = logging.getLogger(name="toolbox")

# Generic placeholder used to annotate the class-level registry methods.
T = TypeVar("T")
class Registrable(object):
    """Base class that lets subclasses be registered and looked up by name.

    Every class that calls ``register`` gets its own ``{name: subclass}``
    table inside the shared ``_registry`` mapping, keyed by the class the
    call was made on.

    Typical use::

        @Base.register("fast")
        class FastImpl(Base):
            ...

        impl_cls = Base.by_name("fast")
    """

    # One table per registering class, shared by the whole hierarchy.
    _registry: Dict[Type, Dict[str, Type]] = defaultdict(dict)
    # Name that ``list_available`` puts first; ``None`` means "no default".
    default_implementation: Optional[str] = None
    # Replaced on each registered subclass with the name it was registered as.
    register_name: str = "unknown"

    # NOTE(review): the three methods below take ``cls`` but the extracted
    # source showed no decorator; ``@classmethod`` is required for the
    # ``@Base.register("name")`` calling convention to work at all.
    @classmethod
    def register(cls: Type[T], name: str, exist_ok=False):
        """Return a class decorator that registers its target under ``name``.

        Args:
            name: Key to store the decorated class under in ``cls``'s table.
            exist_ok: When true, an existing entry for ``name`` is silently
                replaced instead of raising.

        Returns:
            A decorator that records the class and returns it unchanged
            apart from setting its ``register_name`` attribute.

        Raises:
            ValueError: Raised by the returned decorator when ``name`` is
                already taken and ``exist_ok`` is false.
        """
        registry = Registrable._registry[cls]

        def add_subclass_to_registry(subclass: Type[T]):
            if name in registry and not exist_ok:
                message = (f"Cannot register {name} as {cls.__name__}; "
                           f"name already in use for {registry[name].__name__}")
                raise ValueError(message)
            # Tag the class only once registration is certain to succeed, so
            # a rejected class is left untouched.
            setattr(subclass, "register_name", name)
            registry[name] = subclass
            return subclass

        return add_subclass_to_registry

    @classmethod
    def by_name(cls: Type[T], name: str) -> Type[T]:
        """Return the class registered under ``name`` for ``cls``.

        Raises:
            ValueError: If nothing is registered under ``name``; the message
                lists the names that are available.
        """
        registry = Registrable._registry[cls]
        try:
            return registry[name]
        except KeyError:
            # ``from None`` hides the internal KeyError; callers only ever
            # saw ValueError from this method.
            raise ValueError(
                f"{name} is not a registered name for {cls.__name__}. "
                f"the available is: {list(registry)}"
            ) from None

    @classmethod
    def list_available(cls) -> List[str]:
        """Return the names registered for ``cls``, default first.

        Names come back in registration order. When
        ``default_implementation`` is set, that name is moved to the front.

        Raises:
            ValueError: If ``default_implementation`` is set but was never
                registered.
        """
        keys = list(Registrable._registry[cls])
        default = cls.default_implementation
        if default is None:
            return keys
        if default not in keys:
            raise ValueError(f"Default implementation {default} is not registered")
        return [default] + [k for k in keys if k != default]