import torch
import torch.nn as nn
import numpy as np
from torch.nn import Conv1d
from torch.nn import ConvTranspose1d
from torch.nn.utils import weight_norm
from torch.nn.utils import remove_weight_norm
from .nsf import SourceModuleHnNSF
from .bigv import init_weights, SnakeBeta, AMPBlock
from .alias import Activation1d
class SpeakerAdapter(nn.Module):

    def __init__(self,
                 speaker_dim,
                 adapter_dim,
                 epsilon=1e-5
                 ):
        super(SpeakerAdapter, self).__init__()
        self.speaker_dim = speaker_dim
        self.adapter_dim = adapter_dim
        self.epsilon = epsilon
        self.W_scale = nn.Linear(self.speaker_dim, self.adapter_dim)
        self.W_bias = nn.Linear(self.speaker_dim, self.adapter_dim)
        self.reset_parameters()

    def reset_parameters(self):
        # zero weights with unit scale / zero bias make the adapter start out
        # as a plain layer norm, independent of the speaker embedding
        torch.nn.init.constant_(self.W_scale.weight, 0.0)
        torch.nn.init.constant_(self.W_scale.bias, 1.0)
        torch.nn.init.constant_(self.W_bias.weight, 0.0)
        torch.nn.init.constant_(self.W_bias.bias, 0.0)

    def forward(self, x, speaker_embedding):
        x = x.transpose(1, -1)
        # per-frame normalization over the channel dimension
        mean = x.mean(dim=-1, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
        std = (var + self.epsilon).sqrt()
        y = (x - mean) / std
        # speaker-conditioned affine transform
        scale = self.W_scale(speaker_embedding)
        bias = self.W_bias(speaker_embedding)
        y *= scale.unsqueeze(1)
        y += bias.unsqueeze(1)
        y = y.transpose(1, -1)
        return y
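

# A minimal usage sketch for SpeakerAdapter (added commentary, not part of
# the original file). The sizes below (256-dim speaker embedding, 192 feature
# channels, 100 frames) are illustrative assumptions only.
def _speaker_adapter_demo():
    adapter = SpeakerAdapter(speaker_dim=256, adapter_dim=192)
    x = torch.randn(1, 192, 100)   # [batch, channels, frames]
    spk = torch.randn(1, 256)      # [batch, speaker_dim]
    y = adapter(x, spk)
    assert y.shape == x.shape      # the adapter preserves the feature shape
    return y
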
class Generator(torch.nn.Module):
    # This is the main BigVGAN model. It applies an anti-aliased periodic
    # activation inside the residual blocks.
    def __init__(self, hp):
        super(Generator, self).__init__()
        self.hp = hp
        self.num_kernels = len(hp.gen.resblock_kernel_sizes)
        self.num_upsamples = len(hp.gen.upsample_rates)
        # speaker adapter; the embedding width (hp.vits.spk_dim) must match
        # the speaker encoder in use
        self.adapter = SpeakerAdapter(hp.vits.spk_dim, hp.gen.upsample_input)
        # pre conv
        self.conv_pre = weight_norm(
            Conv1d(hp.gen.upsample_input, hp.gen.upsample_initial_channel, 7, 1, padding=3))
        # nsf
        self.f0_upsamp = torch.nn.Upsample(
            scale_factor=np.prod(hp.gen.upsample_rates))
        self.m_source = SourceModuleHnNSF()
        self.noise_convs = nn.ModuleList()
        # transposed conv-based upsamplers; these do not apply anti-aliasing
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(hp.gen.upsample_rates, hp.gen.upsample_kernel_sizes)):
            # print(f'ups: {i} {k}, {u}, {(k - u) // 2}')
            # base
            self.ups.append(nn.ModuleList([
                weight_norm(ConvTranspose1d(hp.gen.upsample_initial_channel // (2 ** i),
                                            hp.gen.upsample_initial_channel // (
                                                2 ** (i + 1)),
                                            k, u, padding=(k - u) // 2))
            ]))
            # nsf
            if i + 1 < len(hp.gen.upsample_rates):
                stride_f0 = np.prod(hp.gen.upsample_rates[i + 1:])
                stride_f0 = int(stride_f0)
                self.noise_convs.append(
                    Conv1d(
                        1,
                        hp.gen.upsample_initial_channel // (2 ** (i + 1)),
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(
                    Conv1d(1, hp.gen.upsample_initial_channel //
                           (2 ** (i + 1)), kernel_size=1)
                )
        # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = hp.gen.upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(hp.gen.resblock_kernel_sizes, hp.gen.resblock_dilation_sizes):
                self.resblocks.append(AMPBlock(hp, ch, k, d))
        # post conv
        activation_post = SnakeBeta(ch, alpha_logscale=True)
        self.activation_post = Activation1d(activation=activation_post)
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        # weight initialization
        for i in range(len(self.ups)):
            self.ups[i].apply(init_weights)
        self.conv_post.apply(init_weights)
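
    # Added note: stage i halves the channel count, mapping
    # upsample_initial_channel // 2**i -> upsample_initial_channel // 2**(i + 1).
    # Each noise_conv downsamples the sample-rate harmonic source by
    # prod(upsample_rates[i + 1:]) so its length matches the feature map
    # leaving stage i; the last stage needs no downsampling (kernel_size=1).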
    def forward(self, spk, x, f0):
        # adapter
        x = self.adapter(x, spk)
        # nsf
        f0 = f0[:, None]
        f0 = self.f0_upsamp(f0).transpose(1, 2)
        har_source = self.m_source(f0)
        har_source = har_source.transpose(1, 2)
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            # upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # nsf
            x_source = self.noise_convs[i](har_source)
            x = x + x_source
            # AMP blocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        # post conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x
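
    # Added note: forward() returns a waveform of shape
    # [batch, 1, frames * prod(upsample_rates)], bounded to (-1, 1) by tanh.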
    def remove_weight_norm(self):
        for l in self.ups:
            for l_i in l:
                remove_weight_norm(l_i)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)

    def eval(self, inference=False):
        super(Generator, self).eval()
        # do not remove weight norm during validation inside the training loop
        if inference:
            self.remove_weight_norm()

    def pitch2source(self, f0):
        f0 = f0[:, None]
        f0 = self.f0_upsamp(f0).transpose(1, 2)  # [1, len, 1]
        har_source = self.m_source(f0)
        har_source = har_source.transpose(1, 2)  # [1, 1, len]
        return har_source

    def source2wav(self, audio):
        # convert a float waveform in (-1, 1) to 16-bit PCM samples
        MAX_WAV_VALUE = 32768.0
        audio = audio.squeeze()
        audio = MAX_WAV_VALUE * audio
        audio = audio.clamp(min=-MAX_WAV_VALUE, max=MAX_WAV_VALUE - 1)
        audio = audio.short()
        return audio.cpu().detach().numpy()
    def inference(self, spk, x, har_source):
        # adapter
        x = self.adapter(x, spk)
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            # upsampling
            for i_up in range(len(self.ups[i])):
                x = self.ups[i][i_up](x)
            # nsf
            x_source = self.noise_convs[i](har_source)
            x = x + x_source
            # AMP blocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        # post conv
        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x
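

# A minimal end-to-end inference sketch (added; not from the original file).
# It assumes a configured, pretrained `model`, a speaker embedding `spk` of
# shape [1, hp.vits.spk_dim], content features `feat` of shape
# [1, hp.gen.upsample_input, T], and a frame-level f0 contour of shape [1, T].
def _synthesize(model: Generator, spk: torch.Tensor, feat: torch.Tensor, f0: torch.Tensor):
    model.eval(inference=True)          # also strips weight norm for inference
    with torch.no_grad():
        har = model.pitch2source(f0)    # harmonic source at the sample rate
        wav = model.inference(spk, feat, har)
    return model.source2wav(wav)        # 16-bit PCM numpy array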