Apollo / look2hear /models /__init__.py
Serhiy Stetskovych
Initial code
78e32cc
raw
history blame
1.26 kB
###
# Author: Kai Li
# Date: 2022-02-12 15:16:35
# Email: [email protected]
# LastEditTime: 2022-10-04 16:24:53
###
from .base_model import BaseModel
from .apollo import Apollo
# Public names re-exported by `from <package>.models import *`.
# Every entry must be bound in this module; an unbound entry (previously
# "GullFullband", which is never imported here) makes a star-import raise
# AttributeError.
__all__ = [
    "BaseModel",
    "Apollo",
]
def register_model(custom_model):
    """Register a custom model so that it can be retrieved with `models.get`.

    The model is added to this module's namespace under its ``__name__``.
    Because :func:`get` resolves names case-insensitively, a name that differs
    from an existing module-level name only by letter case is rejected as well;
    otherwise one of the two would silently shadow the other in lookups.

    Args:
        custom_model: Custom model to register (any object exposing
            ``__name__``, typically a model class).

    Returns:
        ``custom_model`` unchanged, so this function can also be used as a
        class decorator (``@register_model``).

    Raises:
        ValueError: if a module-level name equal to ``custom_model.__name__``
            (ignoring case) already exists.
    """
    name = custom_model.__name__
    # Compare lowercased on both sides: `get` lowercases every key, so e.g.
    # "APOLLO" would otherwise be accepted and collide with "Apollo".
    existing = {key.lower() for key in globals()}
    if name.lower() in existing:
        raise ValueError(
            f"Model {name} already exists. Choose another name."
        )
    globals()[name] = custom_model
    return custom_model
def get(identifier):
    """Look up a model class by name, ignoring letter case.

    The name is matched against every name bound in this module's namespace
    (imported models plus anything added via ``register_model``). If several
    names differ only by case, the one bound last wins.

    Args:
        identifier (str): the model name.

    Returns:
        The object bound to that name in this module, normally a
        :class:`torch.nn.Module` subclass.

    Raises:
        ValueError: if ``identifier`` is not a string, or no module-level
            name matches it.
    """
    # Guard clause: only string identifiers are understood.
    if not isinstance(identifier, str):
        raise ValueError(f"Could not interpret model name : {str(identifier)}")

    lookup = {}
    for name, obj in globals().items():
        lookup[name.lower()] = obj

    found = lookup.get(identifier.lower())
    if found is None:
        raise ValueError(f"Could not interpret model name : {str(identifier)}")
    return found