import torch
from torch import nn
from .vits_config import VitsPreTrainedModel
#.............................................
class VitsHifiGanDiscriminatorScaleResidualBlock(nn.Module):
    """HiFi-GAN-style multi-scale discriminator branch operating on the raw 1D waveform."""

    def __init__(self, discriminator_scale_channels, leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope

        # First conv: wide receptive field (kernel 15), no downsampling.
        in_channels, out_channels = discriminator_scale_channels[:2]
        self.convs = nn.ModuleList([nn.Conv1d(in_channels, out_channels, 15, 1, padding=7)])

        # Strided grouped convs: each stage downsamples by 4 and quadruples the group count.
        groups = 4
        for in_channels, out_channels in zip(discriminator_scale_channels[1:-1], discriminator_scale_channels[2:]):
            self.convs.append(nn.Conv1d(in_channels, out_channels, 41, 4, groups=groups, padding=20))
            groups = groups * 4

        channel_size = discriminator_scale_channels[-1]
        self.convs.append(nn.Conv1d(channel_size, channel_size, 41, 4, groups=groups, padding=20))
        self.convs.append(nn.Conv1d(channel_size, channel_size, 5, 1, padding=2))
        self.final_conv = nn.Conv1d(channel_size, 1, 3, 1, padding=1)

    def apply_weight_norm(self):
        # Note: nn.utils.weight_norm is deprecated in recent PyTorch in favor of
        # torch.nn.utils.parametrizations.weight_norm; kept as-is here to match
        # remove_weight_norm below.
        for layer in self.convs:
            nn.utils.weight_norm(layer)
        nn.utils.weight_norm(self.final_conv)

    def remove_weight_norm(self):
        for layer in self.convs:
            nn.utils.remove_weight_norm(layer)
        nn.utils.remove_weight_norm(self.final_conv)

    def forward(self, hidden_states):
        fmap = []
        for conv in self.convs:
            hidden_states = conv(hidden_states)
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            fmap.append(hidden_states)  # keep activations for the feature-matching loss
        hidden_states = self.final_conv(hidden_states)
        fmap.append(hidden_states)
        hidden_states = torch.flatten(hidden_states, 1, -1)
        return hidden_states, fmap
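
# Hedged usage sketch (not part of the original file): exercises the scale branch
# on random audio. The channel list is an illustrative assumption; real values
# come from the model config (config.discriminator_scale_channels).
def _demo_scale_block():
    channels = [1, 16, 64, 256, 1024]  # hypothetical channel progression
    block = VitsHifiGanDiscriminatorScaleResidualBlock(channels)
    waveform = torch.randn(2, 1, 8192)  # (batch, channels, samples)
    scores, fmap = block(waveform)
    # scores: flattened per-sample logits for the adversarial loss;
    # fmap: intermediate activations for the feature-matching loss.
    return scores.shape, [feat.shape for feat in fmap]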
#.............................................................................................
class VitsHifiGanDiscriminatorPeriodResidualBlock(nn.Module):
    """HiFi-GAN-style multi-period discriminator branch: folds the waveform into a
    2D (frames x period) grid so each column shares the same phase offset."""

    def __init__(self, discriminator_period_channels, period, kernel_size=5, stride=3, leaky_relu_slope=0.1):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope
        self.period = period

        self.convs = nn.ModuleList()
        for in_channels, out_channels in zip(discriminator_period_channels[:-1], discriminator_period_channels[1:]):
            self.convs.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    (kernel_size, 1),
                    (stride, 1),
                    padding=(self.get_padding(kernel_size, 1), 0),
                )
            )

        channel_size = discriminator_period_channels[-1]
        self.convs.append(
            nn.Conv2d(channel_size, channel_size, (kernel_size, 1), 1, padding=(self.get_padding(kernel_size, 1), 0))
        )
        self.final_conv = nn.Conv2d(channel_size, 1, (3, 1), 1, padding=(1, 0))

    def get_padding(self, kernel_size, dilation=1):
        # "Same" padding for an odd kernel size at stride 1.
        return (kernel_size * dilation - dilation) // 2

    def apply_weight_norm(self):
        for layer in self.convs:
            nn.utils.weight_norm(layer)
        nn.utils.weight_norm(self.final_conv)

    def remove_weight_norm(self):
        for layer in self.convs:
            nn.utils.remove_weight_norm(layer)
        nn.utils.remove_weight_norm(self.final_conv)

    def forward(self, hidden_states):
        fmap = []

        # Reshape from 1D (batch, channels, length) to 2D
        # (batch, channels, length // period, period), reflection-padding the
        # waveform first so its length is a multiple of the period.
        batch_size, channels, length = hidden_states.shape
        if length % self.period != 0:
            n_pad = self.period - (length % self.period)
            hidden_states = nn.functional.pad(hidden_states, (0, n_pad), "reflect")
            length = length + n_pad
        hidden_states = hidden_states.view(batch_size, channels, length // self.period, self.period)

        for conv in self.convs:
            hidden_states = conv(hidden_states)
            hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
            fmap.append(hidden_states)
        hidden_states = self.final_conv(hidden_states)
        fmap.append(hidden_states)
        hidden_states = torch.flatten(hidden_states, 1, -1)
        return hidden_states, fmap
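
# Hedged usage sketch (not part of the original file): exercises one period branch.
# The channel list and period are illustrative assumptions; real values come from
# config.discriminator_period_channels and config.discriminator_periods.
def _demo_period_block():
    channels = [1, 32, 128, 512, 1024]  # hypothetical channel progression
    block = VitsHifiGanDiscriminatorPeriodResidualBlock(channels, period=3)
    waveform = torch.randn(2, 1, 8000)  # 8000 % 3 != 0, so forward() reflect-pads first
    scores, fmap = block(waveform)
    return scores.shape, [feat.shape for feat in fmap]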
#.............................................................................................
class VitsDiscriminator(VitsPreTrainedModel):
    """Ensemble of one multi-scale branch (optional) plus one multi-period branch per period."""

    def __init__(self, config):
        super().__init__(config)

        if config.discriminator_scale_channels is not None:
            self.discriminators = nn.ModuleList(
                [VitsHifiGanDiscriminatorScaleResidualBlock(config.discriminator_scale_channels, config.leaky_relu_slope)]
            )
        else:
            self.discriminators = nn.ModuleList([])

        self.discriminators.extend(
            [
                VitsHifiGanDiscriminatorPeriodResidualBlock(
                    config.discriminator_period_channels,
                    period,
                    config.discriminator_kernel_size,
                    config.discriminator_stride,
                    config.leaky_relu_slope,
                )
                for period in config.discriminator_periods
            ]
        )

    def forward(self, hidden_states):
        # Run every sub-discriminator on the same waveform; collect one score
        # tensor and one feature-map list per branch.
        fmaps = []
        discriminated_hidden_states_list = []
        for discriminator in self.discriminators:
            discriminated_hidden_states, fmap = discriminator(hidden_states)
            fmaps.append(fmap)
            discriminated_hidden_states_list.append(discriminated_hidden_states)
        return discriminated_hidden_states_list, fmaps

    def apply_weight_norm(self):
        for disc in self.discriminators:
            disc.apply_weight_norm()

    def remove_weight_norm(self):
        for disc in self.discriminators:
            disc.remove_weight_norm()
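
# Hedged end-to-end sketch (not part of the original file). Assumes `config` is a
# VitsConfig-style object carrying the discriminator fields referenced above
# (discriminator_scale_channels, discriminator_period_channels, discriminator_periods,
# discriminator_kernel_size, discriminator_stride, leaky_relu_slope).
def _demo_discriminator(config):
    model = VitsDiscriminator(config)
    waveform = torch.randn(2, 1, 8192)  # (batch, 1, samples)
    scores, fmaps = model(waveform)
    # One score tensor and one feature-map list per sub-discriminator: index 0 is
    # the scale branch (when configured), followed by one entry per period.
    assert len(scores) == len(fmaps) == len(model.discriminators)
    return scores, fmaps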
#.............................................................................................