import torch
from torch import nn

from .vits_config import VitsPreTrainedModel

#.............................................................................................
class VitsHifiGanDiscriminatorScaleResidualBlock(nn.Module):
    """Multi-scale discriminator branch: a stack of grouped 1D convolutions applied to the raw waveform."""

    def __init__(self, discriminator_scale_channels, leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope

        in_channels, out_channels = discriminator_scale_channels[:2]
        self.convs = nn.ModuleList([nn.Conv1d(in_channels, out_channels, 15, 1, padding=7)])

        # Each intermediate conv downsamples by 4 and quadruples the group count.
        groups = 4
        for in_channels, out_channels in zip(discriminator_scale_channels[1:-1], discriminator_scale_channels[2:]):
            self.convs.append(nn.Conv1d(in_channels, out_channels, 41, 4, groups=groups, padding=20))
            groups = groups * 4

        channel_size = discriminator_scale_channels[-1]
        self.convs.append(nn.Conv1d(channel_size, channel_size, 41, 4, groups=groups, padding=20))
        self.convs.append(nn.Conv1d(channel_size, channel_size, 5, 1, padding=2))
        self.final_conv = nn.Conv1d(channel_size, 1, 3, 1, padding=1)

    def apply_weight_norm(self):
        for layer in self.convs:
            nn.utils.weight_norm(layer)
        nn.utils.weight_norm(self.final_conv)

    def remove_weight_norm(self):
        for layer in self.convs:
            nn.utils.remove_weight_norm(layer)
        nn.utils.remove_weight_norm(self.final_conv)

    def forward(self, hidden_states):
        fmap = []

        for conv in self.convs:
            hidden_states = conv(hidden_states)
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            fmap.append(hidden_states)

        # Final projection to a single channel, flattened into per-clip logits.
        hidden_states = self.final_conv(hidden_states)
        fmap.append(hidden_states)
        hidden_states = torch.flatten(hidden_states, 1, -1)

        return hidden_states, fmap

#.............................................................................................
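# A minimal shape-check sketch for the scale branch. The channel list below is
# an illustrative assumption chosen to satisfy the grouped-conv divisibility
# constraints, not a value read from any released config.
def _demo_scale_discriminator():
    block = VitsHifiGanDiscriminatorScaleResidualBlock(
        discriminator_scale_channels=[1, 16, 64, 256, 1024]
    )
    waveform = torch.randn(2, 1, 8192)  # (batch, channels, samples)
    score, fmap = block(waveform)
    # `score` holds flattened per-clip logits; `fmap` holds one feature map per
    # conv layer plus the final projection, useful for feature-matching losses.
    print(score.shape, [f.shape for f in fmap])

#.............................................................................................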
class VitsHifiGanDiscriminatorPeriodResidualBlock(nn.Module):
    """Multi-period discriminator branch: folds the waveform into (length // period, period) frames and applies 2D convolutions."""

    def __init__(self, discriminator_period_channels, period, kernel_size=5, stride=3, leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope
        self.period = period

        self.convs = nn.ModuleList()
        for in_channels, out_channels in zip(discriminator_period_channels[:-1], discriminator_period_channels[1:]):
            self.convs.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    (kernel_size, 1),
                    (stride, 1),
                    padding=(self.get_padding(kernel_size, 1), 0),
                )
            )

        channel_size = discriminator_period_channels[-1]
        self.convs.append(
            nn.Conv2d(channel_size, channel_size, (kernel_size, 1), 1, padding=(self.get_padding(kernel_size, 1), 0))
        )
        self.final_conv = nn.Conv2d(channel_size, 1, (3, 1), 1, padding=(1, 0))

    def get_padding(self, kernel_size, dilation=1):
        # "Same" padding for an odd kernel at the given dilation.
        return (kernel_size * dilation - dilation) // 2

    def apply_weight_norm(self):
        for layer in self.convs:
            nn.utils.weight_norm(layer)
        nn.utils.weight_norm(self.final_conv)

    def remove_weight_norm(self):
        for layer in self.convs:
            nn.utils.remove_weight_norm(layer)
        nn.utils.remove_weight_norm(self.final_conv)

    def forward(self, hidden_states):
        fmap = []

        # Fold the 1D waveform into a 2D grid of shape (length // period, period),
        # reflect-padding first so the length is divisible by the period.
        batch_size, channels, length = hidden_states.shape
        if length % self.period != 0:
            n_pad = self.period - (length % self.period)
            hidden_states = nn.functional.pad(hidden_states, (0, n_pad), "reflect")
            length = length + n_pad
        hidden_states = hidden_states.view(batch_size, channels, length // self.period, self.period)

        for conv in self.convs:
            hidden_states = conv(hidden_states)
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            fmap.append(hidden_states)

        hidden_states = self.final_conv(hidden_states)
        fmap.append(hidden_states)
        hidden_states = torch.flatten(hidden_states, 1, -1)

        return hidden_states, fmap

#.............................................................................................
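# A minimal shape-check sketch for a period branch; the period and channel
# sizes are illustrative assumptions. An input length of 8192 is not divisible
# by 3, so this also exercises the reflect-padding path in forward().
def _demo_period_discriminator():
    block = VitsHifiGanDiscriminatorPeriodResidualBlock(
        discriminator_period_channels=[1, 32, 128, 512, 1024], period=3
    )
    waveform = torch.randn(2, 1, 8192)
    score, fmap = block(waveform)
    print(score.shape, [f.shape for f in fmap])

#.............................................................................................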
class VitsDiscriminator(VitsPreTrainedModel):
    """Ensemble of one optional multi-scale branch plus one multi-period branch per configured period."""

    def __init__(self, config):
        super().__init__(config)

        if config.discriminator_scale_channels is not None:
            self.discriminators = nn.ModuleList(
                [VitsHifiGanDiscriminatorScaleResidualBlock(config.discriminator_scale_channels, config.leaky_relu_slope)]
            )
        else:
            self.discriminators = nn.ModuleList([])

        self.discriminators.extend(
            [
                VitsHifiGanDiscriminatorPeriodResidualBlock(
                    config.discriminator_period_channels,
                    period,
                    config.discriminator_kernel_size,
                    config.discriminator_stride,
                    config.leaky_relu_slope,
                )
                for period in config.discriminator_periods
            ]
        )

    def forward(self, hidden_states):
        fmaps = []
        discriminated_hidden_states_list = []

        for discriminator in self.discriminators:
            discriminated_hidden_states, fmap = discriminator(hidden_states)
            fmaps.append(fmap)
            discriminated_hidden_states_list.append(discriminated_hidden_states)

        return discriminated_hidden_states_list, fmaps

    def apply_weight_norm(self):
        for disc in self.discriminators:
            disc.apply_weight_norm()

    def remove_weight_norm(self):
        for disc in self.discriminators:
            disc.remove_weight_norm()

#.............................................................................................
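# An end-to-end sketch mimicking VitsDiscriminator.forward without a config
# object (VitsPreTrainedModel expects a full transformers config), assembling
# the branches directly; the channel lists and periods here are assumptions.
def _demo_combined_discriminator():
    discriminators = nn.ModuleList(
        [VitsHifiGanDiscriminatorScaleResidualBlock([1, 16, 64, 256, 1024])]
        + [
            VitsHifiGanDiscriminatorPeriodResidualBlock([1, 32, 128, 512, 1024], period)
            for period in (2, 3, 5, 7, 11)
        ]
    )
    waveform = torch.randn(1, 1, 4096)
    scores, fmaps = zip(*(discriminator(waveform) for discriminator in discriminators))
    # One score tensor and one feature-map list per sub-discriminator.
    print([score.shape for score in scores], len(fmaps))

#.............................................................................................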