# ONNXServies / VitsModelSplit / discriminator.py — VITS HiFi-GAN discriminator modules.
from torch import nn
import torch
from .vits_config import VitsPreTrainedModel
#.............................................
class VitsHifiGanDiscriminatorScaleResidualBlock(nn.Module):
    """Single-scale HiFi-GAN discriminator operating on the raw 1D waveform.

    A stack of strided, grouped ``Conv1d`` layers followed by a 1-channel
    output convolution. ``forward`` returns the flattened logits together
    with the feature maps after every convolution (used for the
    feature-matching loss).
    """

    def __init__(self, discriminator_scale_channels, leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope

        first_in, first_out = discriminator_scale_channels[:2]
        layers = [nn.Conv1d(first_in, first_out, 15, 1, padding=7)]

        # Grouped, stride-4 downsampling convs; group count grows 4x per layer.
        group_count = 4
        for c_in, c_out in zip(discriminator_scale_channels[1:-1], discriminator_scale_channels[2:]):
            layers.append(nn.Conv1d(c_in, c_out, 41, 4, groups=group_count, padding=20))
            group_count *= 4

        last_channels = discriminator_scale_channels[-1]
        layers.append(nn.Conv1d(last_channels, last_channels, 41, 4, groups=group_count, padding=20))
        layers.append(nn.Conv1d(last_channels, last_channels, 5, 1, padding=2))

        self.convs = nn.ModuleList(layers)
        self.final_conv = nn.Conv1d(last_channels, 1, 3, 1, padding=1)

    def apply_weight_norm(self):
        """Wrap every convolution (including the head) with weight normalization."""
        for layer in [*self.convs, self.final_conv]:
            nn.utils.weight_norm(layer)

    def remove_weight_norm(self):
        """Undo :meth:`apply_weight_norm` on every convolution."""
        for layer in [*self.convs, self.final_conv]:
            nn.utils.remove_weight_norm(layer)

    def forward(self, hidden_states):
        """Run the discriminator.

        Args:
            hidden_states: waveform tensor of shape ``(batch, channels, time)``.

        Returns:
            Tuple of (flattened logits of shape ``(batch, -1)``, list of
            intermediate feature maps, one per convolution).
        """
        fmap = []
        for layer in self.convs:
            hidden_states = nn.functional.leaky_relu(layer(hidden_states), self.leaky_relu_slope)
            fmap.append(hidden_states)
        hidden_states = self.final_conv(hidden_states)
        fmap.append(hidden_states)
        return torch.flatten(hidden_states, 1, -1), fmap
#.............................................................................................
class VitsHifiGanDiscriminatorPeriodResidualBlock(nn.Module):
    """Period discriminator: folds the 1D waveform into 2D and convolves it.

    The time axis is reshaped to ``(time // period, period)`` so each
    ``Conv2d`` (kernel ``(kernel_size, 1)``) looks at every ``period``-th
    sample. ``forward`` returns flattened logits and per-layer feature maps.
    """

    def __init__(self, discriminator_period_channels, period, kernel_size=5, stride=3, leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope
        self.period = period

        pad = (self.get_padding(kernel_size, 1), 0)
        self.convs = nn.ModuleList()
        channel_pairs = zip(discriminator_period_channels[:-1], discriminator_period_channels[1:])
        for c_in, c_out in channel_pairs:
            self.convs.append(nn.Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1), padding=pad))

        last_channels = discriminator_period_channels[-1]
        self.convs.append(nn.Conv2d(last_channels, last_channels, (kernel_size, 1), 1, padding=pad))
        self.final_conv = nn.Conv2d(last_channels, 1, (3, 1), 1, padding=(1, 0))

    def get_padding(self, kernel_size, dilation=1):
        """Padding that keeps the output length equal to the input for stride 1."""
        return (kernel_size * dilation - dilation) // 2

    def apply_weight_norm(self):
        """Wrap every convolution (including the head) with weight normalization."""
        for layer in [*self.convs, self.final_conv]:
            nn.utils.weight_norm(layer)

    def remove_weight_norm(self):
        """Undo :meth:`apply_weight_norm` on every convolution."""
        for layer in [*self.convs, self.final_conv]:
            nn.utils.remove_weight_norm(layer)

    def forward(self, hidden_states):
        """Run the discriminator.

        Args:
            hidden_states: waveform tensor of shape ``(batch, channels, time)``.

        Returns:
            Tuple of (flattened logits of shape ``(batch, -1)``, list of
            intermediate feature maps, one per convolution).
        """
        fmap = []

        # Fold 1D -> 2D: reflect-pad so the length is divisible by the period.
        batch_size, channels, length = hidden_states.shape
        remainder = length % self.period
        if remainder != 0:
            pad_amount = self.period - remainder
            hidden_states = nn.functional.pad(hidden_states, (0, pad_amount), "reflect")
            length += pad_amount
        hidden_states = hidden_states.view(batch_size, channels, length // self.period, self.period)

        for layer in self.convs:
            hidden_states = nn.functional.leaky_relu(layer(hidden_states), self.leaky_relu_slope)
            fmap.append(hidden_states)
        hidden_states = self.final_conv(hidden_states)
        fmap.append(hidden_states)
        return torch.flatten(hidden_states, 1, -1), fmap
#.............................................................................................
class VitsDiscriminator(VitsPreTrainedModel):
    """Combined VITS discriminator.

    Holds one optional scale discriminator (when
    ``config.discriminator_scale_channels`` is set) followed by one period
    discriminator per entry of ``config.discriminator_periods``.
    """

    def __init__(self, config):
        super().__init__(config)
        modules = []
        if config.discriminator_scale_channels is not None:
            modules.append(
                VitsHifiGanDiscriminatorScaleResidualBlock(
                    config.discriminator_scale_channels, config.leaky_relu_slope
                )
            )
        modules.extend(
            VitsHifiGanDiscriminatorPeriodResidualBlock(
                config.discriminator_period_channels,
                period,
                config.discriminator_kernel_size,
                config.discriminator_stride,
                config.leaky_relu_slope,
            )
            for period in config.discriminator_periods
        )
        self.discriminators = nn.ModuleList(modules)

    def forward(self, hidden_states):
        """Apply every sub-discriminator to the same waveform.

        Args:
            hidden_states: waveform tensor of shape ``(batch, channels, time)``.

        Returns:
            Tuple of (list of per-discriminator logits, list of per-discriminator
            feature-map lists), in sub-discriminator order.
        """
        outputs, fmaps = [], []
        for discriminator in self.discriminators:
            score, fmap = discriminator(hidden_states)
            outputs.append(score)
            fmaps.append(fmap)
        return outputs, fmaps

    def apply_weight_norm(self):
        """Apply weight normalization to every sub-discriminator."""
        for discriminator in self.discriminators:
            discriminator.apply_weight_norm()

    def remove_weight_norm(self):
        """Remove weight normalization from every sub-discriminator."""
        for discriminator in self.discriminators:
            discriminator.remove_weight_norm()
#.............................................................................................