Spaces:
Running
on
Zero
Running
on
Zero
###
# Author: Kai Li
# Date: 2022-02-12 15:16:35
# Email: [email protected]
# LastEditTime: 2022-10-04 16:24:53
###
from .base_model import BaseModel | |
from .apollo import Apollo | |
# Names exported by a star-import of this package. Every entry must be an
# attribute that actually exists on the module, otherwise the star-import
# raises AttributeError; only BaseModel and Apollo are imported above.
__all__ = [
    "BaseModel",
    "Apollo",
]
def register_model(custom_model):
    """Register a custom model, gettable with `models.get`.

    The model is stored in this module's global namespace under its
    ``__name__``. Because `get` resolves names case-insensitively, the
    collision check is case-insensitive as well: a name that differs from an
    existing global only by letter case is rejected, so a lookup can never
    become ambiguous.

    Args:
        custom_model: Custom model to register. Must expose ``__name__``.

    Raises:
        ValueError: If a global with the same name (ignoring case) already
            exists in this module.
    """
    model_name = custom_model.__name__
    # Compare against the lowercased form of every existing global, which is
    # exactly the key space `get` searches.
    taken_names = {existing.lower() for existing in globals()}
    if model_name.lower() in taken_names:
        raise ValueError(
            f"Model {custom_model.__name__} already exists. Choose another name."
        )
    globals()[model_name] = custom_model
def get(identifier):
    """Return the model registered under a name (case-insensitive).

    The search covers every name in this module's global namespace, which
    includes the imported models and anything added via `register_model`.

    Args:
        identifier (str): the model name.

    Returns:
        The object bound to the matching global name; for registered models
        this is expected to be a :class:`torch.nn.Module` subclass.

    Raises:
        ValueError: If ``identifier`` is not a string, or if no global name
            matches it (or the matching global is ``None``).
    """
    # Anything that is not a string cannot be resolved by name.
    if not isinstance(identifier, str):
        raise ValueError(f"Could not interpret model name : {str(identifier)}")

    wanted = identifier.lower()
    match = None
    # Walk all globals; when several names collide after lowercasing, the
    # last one visited wins.
    for name, obj in globals().items():
        if name.lower() == wanted:
            match = obj

    if match is None:
        raise ValueError(f"Could not interpret model name : {str(identifier)}")
    return match