Apollo / look2hear /models /base_model.py
Serhiy Stetskovych
Initial code
78e32cc
raw
history blame
2.89 kB
###
# Author: Kai Li
# Date: 2021-06-17 23:08:32
# LastEditors: Please set LastEditors
# LastEditTime: 2022-05-26 18:06:22
###
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
def _unsqueeze_to_3d(x):
    """Return `x` laid out as [batch, n_chan, time].

    A 1-D input becomes shape (1, 1, time), a 2-D input gets a singleton
    channel axis inserted at position 1, and any other rank is returned
    untouched.
    """
    rank = x.ndim
    if rank == 2:
        # [batch, time] -> [batch, 1, time]
        return x.unsqueeze(1)
    if rank == 1:
        # [time] -> [1, 1, time]
        return x.reshape(1, 1, -1)
    return x
def pad_to_appropriate_length(x, lcm):
    """Zero-pad the last axis of `x` up to the next multiple of `lcm`.

    Args:
        x: Tensor whose last dimension is the one to pad.
        lcm: Positive integer the padded length must be divisible by.

    Returns:
        `x` itself when its last dimension is already a multiple of `lcm`;
        otherwise a new tensor with the same leading shape, dtype and device
        as `x`, holding `x` followed by trailing zeros on the last axis.
    """
    length = int(x.shape[-1])
    remainder = length % lcm
    if not remainder:
        return x
    # `new_zeros` inherits dtype and device from `x`, so the padded result is
    # consistent with the unpadded return path (no silent cast to float32) and
    # the buffer is allocated directly on the right device.
    padded_x = x.new_zeros(list(x.shape[:-1]) + [length + lcm - remainder])
    padded_x[..., :length] = x
    return padded_x
class BaseModel(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/JusperLee/Apollo",
    pipeline_tag="audio-to-audio",
):
    """Base class for models: stores audio settings and checkpoint helpers.

    Subclasses are expected to override `forward` and `get_model_args`,
    both of which raise `NotImplementedError` here.

    Args:
        sample_rate: Sample rate value stored on the model.
        in_chan: Number of input channels stored on the model (default 1).
    """

    def __init__(self, sample_rate, in_chan=1):
        super().__init__()
        self._sample_rate = sample_rate
        self._in_chan = in_chan

    def forward(self, *args, **kwargs):
        """Forward pass; must be provided by subclasses."""
        raise NotImplementedError

    def sample_rate(self,):
        """Return the sample rate given at construction time."""
        return self._sample_rate

    @staticmethod
    def load_state_dict_in_audio(model, pretrained_dict):
        """Load the `audio_model.*` entries of `pretrained_dict` into `model`.

        Keys that start with ``"audio_model."`` have that prefix removed and
        overwrite the matching entries of `model.state_dict()`; all other
        keys are ignored.

        Args:
            model: Module to update in place.
            pretrained_dict: Mapping of parameter names to tensors.

        Returns:
            The same `model` object, after loading.
        """
        prefix = "audio_model."
        model_dict = model.state_dict()
        # Match on the leading prefix (not a substring) so that the number of
        # stripped characters always corresponds to what was matched.
        model_dict.update(
            {
                key[len(prefix):]: value
                for key, value in pretrained_dict.items()
                if key.startswith(prefix)
            }
        )
        model.load_state_dict(model_dict)
        return model

    @staticmethod
    def from_pretrain(pretrained_model_conf_or_path, *args, **kwargs):
        """Instantiate a model from a serialized configuration.

        Args:
            pretrained_model_conf_or_path: Either a configuration dict with
                at least the keys ``"model_name"`` and ``"state_dict"``, or
                something `torch.load` can read that yields such a dict.
            *args: Positional arguments forwarded to the model constructor.
            **kwargs: Keyword arguments forwarded to the model constructor.

        Returns:
            The instantiated model with the stored state dict loaded.
        """
        # Local import, kept as in the original; `get` is presumably the
        # package-level lookup from model name to class — confirm in __init__.
        from . import get

        if isinstance(pretrained_model_conf_or_path, dict):
            conf = pretrained_model_conf_or_path
        else:
            # NOTE(security): torch.load is pickle-based and, depending on the
            # installed PyTorch version's `weights_only` default, may execute
            # arbitrary code. Only load checkpoints from trusted sources.
            conf = torch.load(pretrained_model_conf_or_path, map_location="cpu")
        # Attempt to find the model class and instantiate it.
        model_class = get(conf["model_name"])
        model = model_class(*args, **kwargs)
        model.load_state_dict(conf["state_dict"])
        return model

    def serialize(self):
        """Build a dict describing this model for saving/sharing.

        Returns:
            A dict with ``model_name`` (class name), ``state_dict``,
            ``model_args`` and ``infos`` (currently the torch and
            pytorch_lightning versions).
        """
        import pytorch_lightning as pl  # Not used in torch.hub

        model_conf = dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
        )
        # Additional infos
        infos = dict()
        infos["software_versions"] = dict(
            torch_version=torch.__version__,
            pytorch_lightning_version=pl.__version__,
        )
        model_conf["infos"] = infos
        return model_conf

    def get_state_dict(self):
        """In case the state dict needs to be modified before sharing the model."""
        return self.state_dict()

    def get_model_args(self):
        """Should return args to re-instantiate the class."""
        raise NotImplementedError