Spaces:
Running
on
Zero
Running
on
Zero
### | |
# Author: Kai Li | |
# Date: 2021-06-17 23:08:32 | |
# LastEditors: Please set LastEditors | |
# LastEditTime: 2022-05-26 18:06:22 | |
### | |
import torch | |
import torch.nn as nn | |
from huggingface_hub import PyTorchModelHubMixin | |
def _unsqueeze_to_3d(x):
    """Bring `x` to the 3-D layout [batch, n_chan, time].

    A 1-D input is reshaped to ``[1, 1, time]``; a 2-D input gets a new axis
    inserted at position 1; an input of any other rank is returned as-is.
    """
    rank = x.ndim
    if rank == 2:
        # [batch, time] -> [batch, 1, time]
        return x.unsqueeze(1)
    if rank == 1:
        # [time] -> [1, 1, time]
        return x.reshape(1, 1, -1)
    return x
def pad_to_appropriate_length(x, lcm):
    """Zero-pad the last axis of `x` up to the next multiple of `lcm`.

    Args:
        x: Tensor whose last dimension is to be padded.
        lcm: Integer that the size of the last dimension must become a
            multiple of.

    Returns:
        `x` itself when its last dimension is already a multiple of `lcm`;
        otherwise a new tensor with the same dtype and device as `x` that
        holds `x` followed by zeros along the last axis.
    """
    length = int(x.shape[-1])
    remainder = length % lcm
    if not remainder:
        return x

    # `new_zeros` inherits dtype and device from `x`, so the padded branch no
    # longer forces float32 or allocates on the CPU before moving the buffer.
    padded_shape = list(x.shape[:-1]) + [length + lcm - remainder]
    padded_x = x.new_zeros(padded_shape)
    padded_x[..., :length] = x
    return padded_x
class BaseModel(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/JusperLee/Apollo", pipeline_tag="audio-to-audio"):
    """Base class for the audio models in this package.

    Stores the sample rate and channel count, and provides helpers to load
    weights from a checkpoint and to serialize a model into a plain dict.
    Subclasses must override :meth:`forward` and :meth:`get_model_args`.

    The ``repo_url`` and ``pipeline_tag`` class keywords are forwarded to
    ``PyTorchModelHubMixin``.
    """

    def __init__(self, sample_rate, in_chan=1):
        """Store the model's sample rate and channel count.

        Args:
            sample_rate: Value later returned by :meth:`sample_rate`.
            in_chan: Stored as ``_in_chan``; presumably the number of input
                channels (not used by this base class itself).
        """
        super().__init__()
        self._sample_rate = sample_rate
        self._in_chan = in_chan

    def forward(self, *args, **kwargs):
        """Run the model; must be implemented by subclasses."""
        raise NotImplementedError

    def sample_rate(self,):
        """Return the sample rate given at construction time."""
        return self._sample_rate

    def load_state_dict_in_audio(model, pretrained_dict):
        """Load the ``audio_model.*`` entries of `pretrained_dict` into `model`.

        Because the first parameter is the model, this works both as
        ``BaseModel.load_state_dict_in_audio(model, d)`` and as
        ``model.load_state_dict_in_audio(d)``.

        Args:
            model: Module whose weights are updated in place.
            pretrained_dict: Mapping of parameter names to tensors. Only keys
                starting with ``"audio_model."`` are used, with that prefix
                removed; all other keys are ignored.

        Returns:
            The same `model`, after loading.
        """
        prefix = "audio_model."
        model_dict = model.state_dict()
        # Match on the prefix itself rather than on a substring: stripping a
        # fixed number of characters is only valid when the key starts with it.
        model_dict.update(
            {
                key[len(prefix):]: value
                for key, value in pretrained_dict.items()
                if key.startswith(prefix)
            }
        )
        model.load_state_dict(model_dict)
        return model

    @staticmethod
    def from_pretrain(pretrained_model_conf_or_path, *args, **kwargs):
        """Instantiate a model and load its weights from a saved conf.

        Args:
            pretrained_model_conf_or_path: Either an already-loaded conf dict,
                or anything ``torch.load`` accepts (e.g. a path). The conf must
                contain the keys ``"model_name"`` and ``"state_dict"``.
            *args: Positional arguments forwarded to the model constructor.
            **kwargs: Keyword arguments forwarded to the model constructor.

        Returns:
            The instantiated model with ``conf["state_dict"]`` loaded.
        """
        # Imported at call time; presumably to avoid a circular import with
        # the package ``__init__`` -- TODO confirm.
        from . import get

        if isinstance(pretrained_model_conf_or_path, dict):
            conf = pretrained_model_conf_or_path
        else:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from sources you trust.
            conf = torch.load(pretrained_model_conf_or_path, map_location="cpu")

        # Attempt to find the model and instantiate it.
        model_class = get(conf["model_name"])
        model = model_class(*args, **kwargs)
        model.load_state_dict(conf["state_dict"])
        return model

    def serialize(self):
        """Pack the model into a dict suitable for saving.

        Returns:
            A dict with the keys ``"model_name"`` (class name),
            ``"state_dict"`` (from :meth:`get_state_dict`), ``"model_args"``
            (from :meth:`get_model_args`) and ``"infos"``, which records the
            torch and pytorch_lightning versions under ``"software_versions"``.

        Raises:
            NotImplementedError: If the subclass does not override
                :meth:`get_model_args`.
        """
        import pytorch_lightning as pl  # Not used in torch.hub

        model_conf = dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
        )
        # Additional infos
        infos = dict()
        infos["software_versions"] = dict(
            torch_version=torch.__version__,
            pytorch_lightning_version=pl.__version__,
        )
        model_conf["infos"] = infos
        return model_conf

    def get_state_dict(self):
        """In case the state dict needs to be modified before sharing the model."""
        return self.state_dict()

    def get_model_args(self):
        """Should return args to re-instantiate the class."""
        raise NotImplementedError