import torch
from torch import nn
from typing import Optional
from .vits_config import VitsConfig
#.............................................
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
    # WaveNet-style gated activation: tanh over the first half of the channels,
    # sigmoid over the second half, multiplied together.
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :num_channels, :])
    s_act = torch.sigmoid(in_act[:, num_channels:, :])
    acts = t_act * s_act
    return acts
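
# Illustrative check (not part of the original file): a minimal sketch of the fused
# gated activation above, assuming a toy shape of (batch=2, 2 * channels=8, time=5).
#
#   a = torch.randn(2, 8, 5)
#   b = torch.randn(2, 8, 5)
#   n = torch.IntTensor([4])[0]
#   fused = fused_add_tanh_sigmoid_multiply(a, b, n)
#   reference = torch.tanh((a + b)[:, :4]) * torch.sigmoid((a + b)[:, 4:])
#   assert torch.allclose(fused, reference)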
#.............................................
class VitsWaveNet(torch.nn.Module):
    """Dilated-convolution WaveNet stack used as the conditioner inside the VITS coupling flow."""

    def __init__(self, config: VitsConfig, num_layers: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = num_layers
        self.speaker_embedding_size = config.speaker_embedding_size

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.dropout = nn.Dropout(config.wavenet_dropout)

        # Prefer the parametrization-based weight norm when available,
        # falling back to the older nn.utils.weight_norm otherwise.
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm
        else:
            weight_norm = nn.utils.weight_norm

        if config.speaker_embedding_size != 0:
            cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
            self.cond_layer = weight_norm(cond_layer, name="weight")

        for i in range(num_layers):
            dilation = config.wavenet_dilation_rate**i
            padding = (config.wavenet_kernel_size * dilation - dilation) // 2
            in_layer = torch.nn.Conv1d(
                in_channels=config.hidden_size,
                out_channels=2 * config.hidden_size,
                kernel_size=config.wavenet_kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # The last layer only needs the skip half; earlier layers also produce a residual half.
            if i < num_layers - 1:
                res_skip_channels = 2 * config.hidden_size
            else:
                res_skip_channels = config.hidden_size

            res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
            res_skip_layer = weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, inputs, padding_mask, global_conditioning=None):
        outputs = torch.zeros_like(inputs)
        num_channels_tensor = torch.IntTensor([self.hidden_size])

        # Project the speaker embedding once; per-layer slices are taken inside the loop.
        if global_conditioning is not None:
            global_conditioning = self.cond_layer(global_conditioning)

        for i in range(self.num_layers):
            hidden_states = self.in_layers[i](inputs)

            if global_conditioning is not None:
                cond_offset = i * 2 * self.hidden_size
                global_states = global_conditioning[:, cond_offset : cond_offset + 2 * self.hidden_size, :]
            else:
                global_states = torch.zeros_like(hidden_states)

            acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
            acts = self.dropout(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.num_layers - 1:
                # First half is the residual connection, second half is accumulated into the skip output.
                res_acts = res_skip_acts[:, : self.hidden_size, :]
                inputs = (inputs + res_acts) * padding_mask
                outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
            else:
                outputs = outputs + res_skip_acts

        return outputs * padding_mask

    def remove_weight_norm(self):
        if self.speaker_embedding_size != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for layer in self.in_layers:
            torch.nn.utils.remove_weight_norm(layer)
        for layer in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(layer)

    def apply_weight_norm(self):
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm
        else:
            weight_norm = nn.utils.weight_norm

        if self.speaker_embedding_size != 0:
            weight_norm(self.cond_layer)
        for layer in self.in_layers:
            weight_norm(layer)
        for layer in self.res_skip_layers:
            weight_norm(layer)
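
# Illustrative usage sketch (not part of the original file). It drives the WaveNet
# stack with a lightweight stand-in config instead of a real VitsConfig; the values
# below are arbitrary example numbers, only the attribute names are taken from this file.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(hidden_size=192, speaker_embedding_size=0,
#                         wavenet_dropout=0.0, wavenet_kernel_size=5,
#                         wavenet_dilation_rate=1)
#   net = VitsWaveNet(cfg, num_layers=4)
#   hidden = torch.randn(1, 192, 50)    # (batch, hidden_size, time)
#   mask = torch.ones(1, 1, 50)         # 1.0 for valid frames, 0.0 for padding
#   out = net(hidden, mask)             # -> (1, 192, 50)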
#.............................................................................................
class VitsResidualCouplingLayer(nn.Module):
    """Affine coupling layer with a WaveNet conditioner, as used in the VITS flow."""

    def __init__(self, config: VitsConfig):
        super().__init__()
        self.half_channels = config.flow_size // 2

        self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
        self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
        self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        # Split the channels in half: the first half conditions the affine
        # transform that is applied to the second half.
        first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
        hidden_states = self.conv_pre(first_half) * padding_mask
        hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning)
        mean = self.conv_post(hidden_states) * padding_mask
        # Mean-only coupling: the log standard deviation is fixed at zero.
        log_stddev = torch.zeros_like(mean)

        if not reverse:
            second_half = mean + second_half * torch.exp(log_stddev) * padding_mask
            outputs = torch.cat([first_half, second_half], dim=1)
            log_determinant = torch.sum(log_stddev, [1, 2])
            return outputs, log_determinant
        else:
            second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask
            outputs = torch.cat([first_half, second_half], dim=1)
            return outputs, None

    def apply_weight_norm(self):
        nn.utils.weight_norm(self.conv_pre)
        self.wavenet.apply_weight_norm()
        nn.utils.weight_norm(self.conv_post)

    def remove_weight_norm(self):
        nn.utils.remove_weight_norm(self.conv_pre)
        self.wavenet.remove_weight_norm()
        nn.utils.remove_weight_norm(self.conv_post)
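
# Illustrative round-trip sketch (not part of the original file): with log_stddev
# fixed at zero the coupling layer is volume-preserving, so running it forward and
# then in reverse recovers the input on unmasked frames. Here `layer` stands for an
# already-constructed VitsResidualCouplingLayer and `latents` has config.flow_size channels.
#
#   transformed, log_det = layer(latents, mask)             # forward direction
#   recovered, _ = layer(transformed, mask, reverse=True)   # inverse direction
#   # torch.allclose(recovered * mask, latents * mask) should hold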
#.............................................................................................
class VitsResidualCouplingBlock(nn.Module):
    """Stack of residual coupling layers, with a channel flip between consecutive flows."""

    def __init__(self, config: VitsConfig):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(config.prior_encoder_num_flows):
            self.flows.append(VitsResidualCouplingLayer(config))

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                inputs, _ = flow(inputs, padding_mask, global_conditioning)
                inputs = torch.flip(inputs, [1])
        else:
            # Invert the flows in the opposite order, undoing the channel flip first.
            for flow in reversed(self.flows):
                inputs = torch.flip(inputs, [1])
                inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True)
        return inputs

    def apply_weight_norm(self):
        for flow in self.flows:
            flow.apply_weight_norm()

    def remove_weight_norm(self):
        for flow in self.flows:
            flow.remove_weight_norm()

    def resize_speaker_embeddings(self, speaker_embedding_size: Optional[int] = None):
        # Rebuild each WaveNet conditioning layer for the new speaker embedding size.
        for flow in self.flows:
            flow.wavenet.speaker_embedding_size = speaker_embedding_size
            hidden_size = flow.wavenet.hidden_size
            num_layers = flow.wavenet.num_layers
            cond_layer = torch.nn.Conv1d(speaker_embedding_size, 2 * hidden_size * num_layers, 1)
            flow.wavenet.cond_layer = nn.utils.weight_norm(cond_layer, name="weight")
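
#.............................................................................................
if __name__ == "__main__":
    # Smoke-test sketch (not part of the original file): exercises the coupling block
    # in both directions with a stand-in config rather than a real VitsConfig; the
    # numbers are arbitrary examples. Because of the relative import at the top, run
    # this as a module inside its package, not as a standalone script.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        hidden_size=192,
        speaker_embedding_size=0,
        wavenet_dropout=0.0,
        wavenet_kernel_size=5,
        wavenet_dilation_rate=1,
        flow_size=192,
        prior_encoder_num_wavenet_layers=4,
        prior_encoder_num_flows=4,
    )
    block = VitsResidualCouplingBlock(cfg)

    latents = torch.randn(2, cfg.flow_size, 30)   # (batch, flow_size, time)
    mask = torch.ones(2, 1, 30)                   # all frames valid

    transformed = block(latents, mask)                    # forward direction
    recovered = block(transformed, mask, reverse=True)    # inverse direction
    print("max round-trip error:", (recovered - latents).abs().max().item())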