import os
import yaml
import torch
import torch.nn as nn
import numpy as np
from .diffusion import GaussianDiffusion
from .wavenet import WaveNet
from .vocoder import Vocoder

class DotDict(dict):
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
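
# Illustrative example: DotDict makes a nested dict (such as a parsed config.yaml)
# attribute-accessible, and missing keys resolve to None rather than raising KeyError:
#   cfg = DotDict({'model': {'n_spk': 2}})
#   cfg.model.n_spk    # -> 2
#   cfg.model.missing  # -> None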

def load_model_vocoder(
        model_path,
        device='cpu',
        config_path=None):
    if config_path is None:
        config_file = os.path.join(os.path.split(model_path)[0], 'config.yaml')
    else:
        config_file = config_path
    with open(config_file, "r") as config:
        args = yaml.safe_load(config)
    args = DotDict(args)

    # load vocoder
    vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=device)

    # load model
    model = Unit2Mel(
        args.data.encoder_out_channels,
        args.model.n_spk,
        args.model.use_pitch_aug,
        vocoder.dimension,
        args.model.n_layers,
        args.model.n_chans,
        args.model.n_hidden)
    print(' [Loading] ' + model_path)
    ckpt = torch.load(model_path, map_location=torch.device(device))
    model.to(device)
    model.load_state_dict(ckpt['model'])
    model.eval()
    return model, vocoder, args
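
# Usage sketch (illustrative; the checkpoint path is hypothetical). When no explicit
# config_path is given, a config.yaml is expected next to the checkpoint:
#   model, vocoder, args = load_model_vocoder('exp/diffusion/model_100000.pt', device='cuda')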

class Unit2Mel(nn.Module):
    def __init__(
            self,
            input_channel,
            n_spk,
            use_pitch_aug=False,
            out_dims=128,
            n_layers=20,
            n_chans=384,
            n_hidden=256):
        super().__init__()
        self.unit_embed = nn.Linear(input_channel, n_hidden)
        self.f0_embed = nn.Linear(1, n_hidden)
        self.volume_embed = nn.Linear(1, n_hidden)
        if use_pitch_aug:
            self.aug_shift_embed = nn.Linear(1, n_hidden, bias=False)
        else:
            self.aug_shift_embed = None
        self.n_spk = n_spk
        if n_spk is not None and n_spk > 1:
            self.spk_embed = nn.Embedding(n_spk, n_hidden)
        self.n_hidden = n_hidden
        # diffusion
        self.decoder = GaussianDiffusion(WaveNet(out_dims, n_layers, n_chans, n_hidden), out_dims=out_dims)
        self.input_channel = input_channel
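
    # Note: each conditioning stream (content units, f0, volume, optional pitch-aug
    # shift and speaker embedding) is projected to the same n_hidden width, so they
    # can simply be summed before entering the diffusion decoder (see forward()).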
    def init_spkembed(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=None,
                      gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True):
        '''
        input:
            B x n_frames x n_unit
        return:
            dict of B x n_frames x feat
        '''
        # log-compressed f0 conditioning: ln(1 + f0 / 700), a mel-like frequency warp
        x = self.unit_embed(units) + self.f0_embed((1 + f0 / 700).log()) + self.volume_embed(volume)
        if self.n_spk is not None and self.n_spk > 1:
            if spk_mix_dict is not None:
                spk_embed_mix = torch.zeros((1, 1, self.n_hidden))
                for k, v in spk_mix_dict.items():
                    spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
                    spk_embeddd = self.spk_embed(spk_id_torch)
                    self.speaker_map[k] = spk_embeddd
                    spk_embed_mix = spk_embed_mix + v * spk_embeddd
                x = x + spk_embed_mix
            else:
                x = x + self.spk_embed(spk_id - 1)
        self.speaker_map = self.speaker_map.unsqueeze(0)
        self.speaker_map = self.speaker_map.detach()
        return x.transpose(1, 2)

    def init_spkmix(self, n_spk):
        self.speaker_map = torch.zeros((n_spk, 1, 1, self.n_hidden))
        hubert_hidden_size = self.input_channel
        n_frames = 10
        hubert = torch.randn((1, n_frames, hubert_hidden_size))
        mel2ph = torch.arange(end=n_frames).unsqueeze(0).long()
        f0 = torch.randn((1, n_frames))
        volume = torch.randn((1, n_frames))
        spks = {}
        for i in range(n_spk):
            spks.update({i: 1.0 / float(self.n_spk)})
        orgouttt = self.init_spkembed(hubert, f0.unsqueeze(-1), volume.unsqueeze(-1), spk_mix_dict=spks)
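
    # init_spkmix appears intended to be called once before multi-speaker inference or
    # export: it allocates speaker_map with shape [n_spk, 1, 1, n_hidden] and fills it
    # through a dummy init_spkembed() pass, so that forward() can blend speakers with a
    # static weighted sum over the cached embeddings.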

    def forward(self, units, f0, volume, spk_id=None, spk_mix_dict=None, aug_shift=None,
                gt_spec=None, infer=True, infer_speedup=10, method='dpm-solver', k_step=300, use_tqdm=True):
        '''
        input:
            B x n_frames x n_unit
        return:
            dict of B x n_frames x feat
        '''
        x = self.unit_embed(units) + self.f0_embed((1 + f0 / 700).log()) + self.volume_embed(volume)
        if self.n_spk is not None and self.n_spk > 1:
            if spk_mix_dict is not None:
                for k, v in spk_mix_dict.items():
                    spk_id_torch = torch.LongTensor(np.array([[k]])).to(units.device)
                    x = x + v * self.spk_embed(spk_id_torch)
            else:
                if spk_id.shape[1] > 1:
                    g = spk_id.reshape((spk_id.shape[0], spk_id.shape[1], 1, 1, 1))  # [N, S, B, 1, 1]
                    g = g * self.speaker_map  # [N, S, B, 1, H]
                    g = torch.sum(g, dim=1)  # [N, 1, B, 1, H]
                    g = g.transpose(0, -1).transpose(0, -2).squeeze(0)  # [B, H, N]
                    x = x + g
                else:
                    x = x + self.spk_embed(spk_id)
        if self.aug_shift_embed is not None and aug_shift is not None:
            x = x + self.aug_shift_embed(aug_shift / 5)
        x = self.decoder(x, gt_spec=gt_spec, infer=infer, infer_speedup=infer_speedup, method=method, k_step=k_step, use_tqdm=use_tqdm)
        return x
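
# Inference sketch (illustrative; the path and tensor contents are hypothetical,
# but the shapes follow the docstrings above):
#   model, vocoder, args = load_model_vocoder('exp/diffusion/model_100000.pt', device='cuda')
#   n_frames = 100
#   units  = torch.randn(1, n_frames, args.data.encoder_out_channels).cuda()
#   f0     = torch.full((1, n_frames, 1), 220.0).cuda()   # Hz
#   volume = torch.rand(1, n_frames, 1).cuda()
#   spk_id = torch.LongTensor([[0]]).cuda()
#   mel = model(units, f0, volume, spk_id=spk_id, infer=True,
#               infer_speedup=10, method='dpm-solver', k_step=300)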