HoneyTian's picture
first commit
e94100d
raw
history blame
2.14 kB
import importlib
import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Type, TypeVar
logger = logging.getLogger("toolbox")

# Type variable so that ``register``/``by_name`` are typed in terms of the
# class they are invoked on.
T = TypeVar("T")


class Registrable(object):
    """Mixin giving a base class a name -> subclass registry.

    A base class inherits from ``Registrable``; implementations are attached
    with the ``@BaseClass.register("name")`` decorator and looked up again with
    ``BaseClass.by_name("name")``. Each registry is keyed by the exact class
    that ``register``/``by_name``/``list_available`` is called on, so different
    base classes may reuse the same names independently.
    """

    # base class -> {registered name -> subclass}. One dict shared by every
    # Registrable subclass; defaultdict creates an empty per-class mapping on
    # first access.
    _registry: Dict[Type, Dict[str, Type]] = defaultdict(dict)
    # Name that ``list_available`` puts first; it must itself be registered.
    default_implementation: Optional[str] = None
    # Overwritten on each subclass by ``register`` with its registered name.
    register_name: str = "unknown"

    @classmethod
    def register(
        cls: Type[T], name: str, exist_ok: bool = False
    ) -> Callable[[Type[T]], Type[T]]:
        """Return a class decorator that registers a subclass under ``name``.

        :param name: key under which the decorated class is stored for ``cls``.
        :param exist_ok: if True, an existing registration with the same name
            is replaced; otherwise a name clash is an error.
        :return: a decorator that stores the class in ``cls``'s registry, sets
            its ``register_name`` attribute to ``name``, and returns the class
            itself unchanged.
        :raises ValueError: when the decorator is applied, if ``name`` is
            already registered for ``cls`` and ``exist_ok`` is False.
        """
        registry = Registrable._registry[cls]

        def add_subclass_to_registry(subclass: Type[T]) -> Type[T]:
            if name in registry:
                existing = registry[name].__name__
                if not exist_ok:
                    raise ValueError(
                        f"Cannot register {name} as {cls.__name__}; "
                        f"name already in use for {existing}"
                    )
                logger.debug(
                    "%s has already been registered as %s, but exist_ok=True, "
                    "so overwriting with %s",
                    name,
                    existing,
                    subclass.__name__,
                )
            # Tag the subclass only once registration is certain to succeed,
            # so a rejected duplicate is left unmodified.
            setattr(subclass, "register_name", name)
            registry[name] = subclass
            return subclass

        return add_subclass_to_registry

    @classmethod
    def by_name(cls: Type[T], name: str) -> Type[T]:
        """Return the class registered under ``name`` for ``cls``.

        :param name: a name previously passed to ``cls.register``.
        :return: the registered class.
        :raises ValueError: if ``name`` is not registered for ``cls``.
        """
        registry = Registrable._registry[cls]
        try:
            return registry[name]
        except KeyError:
            available = ", ".join(registry)
            # ``from None``: the KeyError is an implementation detail.
            raise ValueError(
                f"{name} is not a registered name for {cls.__name__}. "
                f"The available names are: [{available}]"
            ) from None

    @classmethod
    def list_available(cls) -> List[str]:
        """List the names registered for ``cls`` in registration order.

        When ``cls.default_implementation`` is set, that name is moved to the
        front of the list.

        :raises ValueError: if ``default_implementation`` is set but was never
            registered for ``cls``.
        """
        keys = list(Registrable._registry[cls])
        default = cls.default_implementation
        if default is None:
            return keys
        if default not in keys:
            raise ValueError(f"Default implementation {default} is not registered")
        return [default] + [k for k in keys if k != default]