import math from typing import Optional import numpy as np import torch from torch import nn from .vits_config import VitsConfig #............................................. # Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock class HifiGanResidualBlock(nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1): super().__init__() self.leaky_relu_slope = leaky_relu_slope self.convs1 = nn.ModuleList( [ nn.Conv1d( channels, channels, kernel_size, stride=1, dilation=dilation[i], padding=self.get_padding(kernel_size, dilation[i]), ) for i in range(len(dilation)) ] ) self.convs2 = nn.ModuleList( [ nn.Conv1d( channels, channels, kernel_size, stride=1, dilation=1, padding=self.get_padding(kernel_size, 1), ) for _ in range(len(dilation)) ] ) def get_padding(self, kernel_size, dilation=1): return (kernel_size * dilation - dilation) // 2 def apply_weight_norm(self): for layer in self.convs1: nn.utils.weight_norm(layer) for layer in self.convs2: nn.utils.weight_norm(layer) def remove_weight_norm(self): for layer in self.convs1: nn.utils.remove_weight_norm(layer) for layer in self.convs2: nn.utils.remove_weight_norm(layer) def forward(self, hidden_states): for conv1, conv2 in zip(self.convs1, self.convs2): residual = hidden_states hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) hidden_states = conv1(hidden_states) hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) hidden_states = conv2(hidden_states) hidden_states = hidden_states + residual return hidden_states #............................................................................................. class VitsHifiGan(nn.Module): def __init__(self, config: VitsConfig): super().__init__() self.config = config self.num_kernels = len(config.resblock_kernel_sizes) self.num_upsamples = len(config.upsample_rates) self.conv_pre = nn.Conv1d( config.flow_size, config.upsample_initial_channel, kernel_size=7, stride=1, padding=3, ) self.upsampler = nn.ModuleList() for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): self.upsampler.append( nn.ConvTranspose1d( config.upsample_initial_channel // (2**i), config.upsample_initial_channel // (2 ** (i + 1)), kernel_size=kernel_size, stride=upsample_rate, padding=(kernel_size - upsample_rate) // 2, ) ) self.resblocks = nn.ModuleList() for i in range(len(self.upsampler)): channels = config.upsample_initial_channel // (2 ** (i + 1)) for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope)) self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False) if config.speaker_embedding_size != 0: self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1) def resize_speaker_embedding(self, speaker_embedding_size): self.config.speaker_embedding_size = speaker_embedding_size self.cond = nn.Conv1d(speaker_embedding_size, self.config.upsample_initial_channel, 1) nn.init.kaiming_normal_(self.cond.weight) if self.cond.bias is not None: k = math.sqrt(self.cond.groups / (self.cond.in_channels * self.cond.kernel_size[0])) nn.init.uniform_(self.cond.bias, a=-k, b=k) def apply_weight_norm(self): for layer in self.upsampler: nn.utils.weight_norm(layer) for layer in self.resblocks: layer.apply_weight_norm() def remove_weight_norm(self): for layer in self.upsampler: nn.utils.remove_weight_norm(layer) for layer in self.resblocks: layer.remove_weight_norm() def forward( self, spectrogram: torch.FloatTensor, global_conditioning: Optional[torch.FloatTensor] = None ) -> torch.FloatTensor: r""" Converts a spectrogram into a speech waveform. Args: spectrogram (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`): Tensor containing the spectrograms. global_conditioning (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_size, 1)`, *optional*): Tensor containing speaker embeddings, for multispeaker models. Returns: `torch.FloatTensor`: Tensor of shape shape `(batch_size, 1, num_frames)` containing the speech waveform. """ hidden_states = self.conv_pre(spectrogram) if global_conditioning is not None: hidden_states = hidden_states + self.cond(global_conditioning) for i in range(self.num_upsamples): hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope) hidden_states = self.upsampler[i](hidden_states) res_state = self.resblocks[i * self.num_kernels](hidden_states) for j in range(1, self.num_kernels): res_state += self.resblocks[i * self.num_kernels + j](hidden_states) hidden_states = res_state / self.num_kernels hidden_states = nn.functional.leaky_relu(hidden_states) hidden_states = self.conv_post(hidden_states) waveform = torch.tanh(hidden_states) return waveform #.............................................................................................