from typing import Optional

import torch
from torch import nn
from torch.nn.utils import parametrize

from .vits_config import VitsConfig

# .............................................


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
    """Gated activation: ``tanh`` of the leading channels times ``sigmoid`` of the rest.

    The two inputs are summed and the sum is split along dim 1 at
    ``num_channels``; the first part goes through ``tanh``, the second through
    ``sigmoid``, and the two results are multiplied element-wise.

    NOTE: the parameters are intentionally left unannotated. TorchScript treats
    unannotated arguments as tensors, and the caller in this file passes
    ``num_channels`` as an element of an ``IntTensor``.
    """
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :num_channels, :])
    s_act = torch.sigmoid(in_act[:, num_channels:, :])
    acts = t_act * s_act
    return acts


# .............................................


def _weight_norm_fn():
    """Return the weight-norm wrapper to use, preferring the parametrization API."""
    if hasattr(nn.utils.parametrizations, "weight_norm"):
        return nn.utils.parametrizations.weight_norm
    return nn.utils.weight_norm


def _has_weight_norm(layer):
    """Tell whether ``layer.weight`` is currently re-parametrized.

    Covers both APIs: the parametrization API marks ``weight`` as parametrized,
    while the older hook-based API registers a ``weight_g`` parameter.
    """
    return parametrize.is_parametrized(layer, "weight") or hasattr(layer, "weight_g")


def _add_weight_norm(layer):
    """Apply weight norm to ``layer`` in place unless it already has it."""
    if not _has_weight_norm(layer):
        _weight_norm_fn()(layer, name="weight")
    return layer


def _strip_weight_norm(layer):
    """Remove weight norm from ``layer`` in place, whichever API installed it."""
    if parametrize.is_parametrized(layer, "weight"):
        # `torch.nn.utils.remove_weight_norm` only understands the hook-based
        # variant and raises ValueError on parametrized modules.
        parametrize.remove_parametrizations(layer, "weight")
    elif hasattr(layer, "weight_g"):
        torch.nn.utils.remove_weight_norm(layer)
    return layer


class VitsWaveNet(torch.nn.Module):
    """Stack of gated, dilated 1-D convolutions with residual and skip paths.

    Every convolution is wrapped in weight normalisation. When
    ``config.speaker_embedding_size`` is non-zero, an extra 1x1 convolution
    (``cond_layer``) projects a conditioning tensor to one ``2 * hidden_size``
    slice per layer.
    """

    def __init__(self, config: VitsConfig, num_layers: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_layers = num_layers
        self.speaker_embedding_size = config.speaker_embedding_size

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.dropout = nn.Dropout(config.wavenet_dropout)

        weight_norm = _weight_norm_fn()

        if config.speaker_embedding_size != 0:
            cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
            self.cond_layer = weight_norm(cond_layer, name="weight")

        for i in range(num_layers):
            dilation = config.wavenet_dilation_rate**i
            # With an odd kernel size this padding keeps the sequence length unchanged.
            padding = (config.wavenet_kernel_size * dilation - dilation) // 2
            in_layer = torch.nn.Conv1d(
                in_channels=config.hidden_size,
                out_channels=2 * config.hidden_size,
                kernel_size=config.wavenet_kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # The last layer has no successor to feed, so it only emits the
            # skip half (hidden_size channels) instead of residual + skip.
            if i < num_layers - 1:
                res_skip_channels = 2 * config.hidden_size
            else:
                res_skip_channels = config.hidden_size

            res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
            res_skip_layer = weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, inputs, padding_mask, global_conditioning=None):
        """Run the stack and return the masked sum of all skip outputs.

        Args:
            inputs: 3-D tensor with ``hidden_size`` channels on dim 1.
            padding_mask: tensor multiplied into the activations; presumably
                1 for valid positions and 0 for padding (confirm with callers).
            global_conditioning: optional tensor fed through ``cond_layer``.
                Passing it requires the module to have been built with a
                non-zero speaker embedding size.

        Returns:
            Tensor shaped like ``inputs``.
        """
        outputs = torch.zeros_like(inputs)
        # Passed as a tensor element because the scripted helper takes tensors.
        num_channels_tensor = torch.IntTensor([self.hidden_size])

        if global_conditioning is not None:
            global_conditioning = self.cond_layer(global_conditioning)

        for i in range(self.num_layers):
            hidden_states = self.in_layers[i](inputs)

            if global_conditioning is not None:
                # Each layer owns a contiguous 2 * hidden_size channel slice.
                cond_offset = i * 2 * self.hidden_size
                global_states = global_conditioning[:, cond_offset : cond_offset + 2 * self.hidden_size, :]
            else:
                global_states = torch.zeros_like(hidden_states)

            acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
            acts = self.dropout(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.num_layers - 1:
                # First half updates the running input, second half is the skip.
                res_acts = res_skip_acts[:, : self.hidden_size, :]
                inputs = (inputs + res_acts) * padding_mask
                outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
            else:
                outputs = outputs + res_skip_acts

        return outputs * padding_mask

    def _weight_normed_layers(self):
        """List every layer that carries (or may carry) weight normalisation."""
        layers = []
        if self.speaker_embedding_size != 0:
            layers.append(self.cond_layer)
        layers.extend(self.in_layers)
        layers.extend(self.res_skip_layers)
        return layers

    def remove_weight_norm(self):
        """Strip weight norm from all layers.

        Handles both the parametrization API and the hook-based API, and skips
        layers that are not normalised, so repeated calls are harmless.
        """
        for layer in self._weight_normed_layers():
            _strip_weight_norm(layer)

    def apply_weight_norm(self):
        """Apply weight norm to all layers that do not already have it.

        ``__init__`` already normalises every layer, so on a freshly built
        module this is a no-op; it is meant to undo ``remove_weight_norm``.
        """
        for layer in self._weight_normed_layers():
            _add_weight_norm(layer)


# .............................................................................................
class VitsResidualCouplingLayer(nn.Module):
    """Mean-only coupling layer: half the channels predict a shift for the other half.

    The first ``half_channels`` channels pass through unchanged and drive a
    ``conv_pre -> VitsWaveNet -> conv_post`` network whose output is used as
    the mean that is added to (or, in reverse, subtracted from) the second half.
    """

    def __init__(self, config: VitsConfig):
        super().__init__()
        self.half_channels = config.flow_size // 2

        self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
        self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
        self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        """Apply the coupling transform, or its inverse when ``reverse`` is true.

        Args:
            inputs: tensor with ``2 * half_channels`` channels on dim 1.
            padding_mask: tensor multiplied into the activations; presumably
                1 for valid positions and 0 for padding (confirm with callers).
            global_conditioning: optional conditioning forwarded to the WaveNet.
            reverse: run the inverse transform instead of the forward one.

        Returns:
            ``(outputs, log_determinant)`` in the forward direction and
            ``(outputs, None)`` in reverse.
        """
        first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
        hidden_states = self.conv_pre(first_half) * padding_mask
        hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning)
        mean = self.conv_post(hidden_states) * padding_mask
        # The scale is fixed: log-stddev is all zeros, so exp() is 1 and the
        # log-determinant computed below is zero.
        log_stddev = torch.zeros_like(mean)

        if not reverse:
            second_half = mean + second_half * torch.exp(log_stddev) * padding_mask
            outputs = torch.cat([first_half, second_half], dim=1)
            log_determinant = torch.sum(log_stddev, [1, 2])
            return outputs, log_determinant
        else:
            second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask
            outputs = torch.cat([first_half, second_half], dim=1)
            return outputs, None

    @staticmethod
    def _has_weight_norm(layer):
        """Tell whether ``layer.weight`` is re-parametrized by either weight-norm API."""
        from torch.nn.utils import parametrize

        # Parametrization API marks `weight` as parametrized; the hook-based
        # API registers a `weight_g` parameter instead.
        return parametrize.is_parametrized(layer, "weight") or hasattr(layer, "weight_g")

    @classmethod
    def _add_weight_norm(cls, layer):
        """Apply weight norm to ``layer`` in place unless it already has it."""
        if cls._has_weight_norm(layer):
            return
        # Same preference as the WaveNet layers: parametrization API if present.
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm
        else:
            weight_norm = nn.utils.weight_norm
        weight_norm(layer, name="weight")

    @staticmethod
    def _strip_weight_norm(layer):
        """Remove weight norm from ``layer`` in place; do nothing if it has none."""
        from torch.nn.utils import parametrize

        if parametrize.is_parametrized(layer, "weight"):
            parametrize.remove_parametrizations(layer, "weight")
        elif hasattr(layer, "weight_g"):
            nn.utils.remove_weight_norm(layer)

    def apply_weight_norm(self):
        """Apply weight norm to ``conv_pre``, the WaveNet and ``conv_post``.

        The two convolutions are skipped when already normalised, so calling
        this twice does not try to normalise them a second time.
        """
        self._add_weight_norm(self.conv_pre)
        self.wavenet.apply_weight_norm()
        self._add_weight_norm(self.conv_post)

    def remove_weight_norm(self):
        """Remove weight norm from ``conv_pre``, the WaveNet and ``conv_post``.

        ``__init__`` builds the two convolutions without weight norm, so they
        are skipped when there is nothing to remove instead of raising.
        """
        self._strip_weight_norm(self.conv_pre)
        self.wavenet.remove_weight_norm()
        self._strip_weight_norm(self.conv_post)


# .............................................................................................
class VitsResidualCouplingBlock(nn.Module):
    """Chain of ``VitsResidualCouplingLayer`` flows with a channel flip between them."""

    def __init__(self, config: VitsConfig):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(config.prior_encoder_num_flows):
            self.flows.append(VitsResidualCouplingLayer(config))

    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
        """Run all flows in order, or in reverse order when ``reverse`` is true.

        In the forward direction each flow is followed by a flip along dim 1;
        in reverse the flip comes first and the flows are visited last-to-first,
        undoing the forward pass step by step. The per-flow log-determinant is
        discarded.
        """
        if not reverse:
            for flow in self.flows:
                inputs, _ = flow(inputs, padding_mask, global_conditioning)
                inputs = torch.flip(inputs, [1])
        else:
            for flow in reversed(self.flows):
                inputs = torch.flip(inputs, [1])
                inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True)
        return inputs

    def apply_weight_norm(self):
        """Apply weight norm on every flow."""
        for flow in self.flows:
            flow.apply_weight_norm()

    def remove_weight_norm(self):
        """Remove weight norm from every flow."""
        for flow in self.flows:
            flow.remove_weight_norm()

    def resize_speaker_embeddings(self, speaker_embedding_size: Optional[int] = None):
        """Replace each flow's conditioning layer with a freshly initialised one.

        Args:
            speaker_embedding_size: number of input channels of the new
                conditioning convolution. Required despite the ``None`` default.

        Raises:
            TypeError: if ``speaker_embedding_size`` is ``None``. The check runs
                before any flow is touched, so a bad call leaves the module
                unchanged.
        """
        if speaker_embedding_size is None:
            raise TypeError("`speaker_embedding_size` must be an integer, got None.")

        # Same preference as VitsWaveNet.__init__, so the new layer's parameter
        # names match those of a WaveNet built with conditioning from the start.
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm
        else:
            weight_norm = nn.utils.weight_norm

        for flow in self.flows:
            wavenet = flow.wavenet
            cond_layer = torch.nn.Conv1d(speaker_embedding_size, 2 * wavenet.hidden_size * wavenet.num_layers, 1)

            # Keep the new layer on the same device/dtype as the rest of the WaveNet.
            reference = next(wavenet.parameters(), None)
            if reference is not None:
                cond_layer = cond_layer.to(device=reference.device, dtype=reference.dtype)

            wavenet.cond_layer = weight_norm(cond_layer, name="weight")
            wavenet.speaker_embedding_size = speaker_embedding_size