import math

from torch import nn

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel

#.............................................


class VitsConfig(PretrainedConfig):
    model_type = "vits"

    def __init__(
        self,
        vocab_size=38,
        hidden_size=192,
        num_hidden_layers=6,
        num_attention_heads=2,
        window_size=4,
        use_bias=True,
        ffn_dim=768,
        layerdrop=0.1,
        ffn_kernel_size=3,
        flow_size=192,
        spectrogram_bins=513,
        hidden_act="relu",
        hidden_dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_stochastic_duration_prediction=True,
        num_speakers=1,
        speaker_embedding_size=0,
        upsample_initial_channel=512,
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        leaky_relu_slope=0.1,
        depth_separable_channels=2,
        depth_separable_num_layers=3,
        duration_predictor_flow_bins=10,
        duration_predictor_tail_bound=5.0,
        duration_predictor_kernel_size=3,
        duration_predictor_dropout=0.5,
        duration_predictor_num_flows=4,
        duration_predictor_filter_channels=256,
        prior_encoder_num_flows=4,
        prior_encoder_num_wavenet_layers=4,
        posterior_encoder_num_wavenet_layers=16,
        wavenet_kernel_size=5,
        wavenet_dilation_rate=1,
        wavenet_dropout=0.0,
        speaking_rate=1.0,
        noise_scale=0.667,
        noise_scale_duration=0.8,
        sampling_rate=16_000,
        discriminator_kernel_size=5,
        discriminator_stride=3,
        discriminator_periods=[2, 3, 5, 7, 11],
        discriminator_period_channels=[1, 32, 128, 512, 1024],
        discriminator_scale_channels=[1, 16, 64, 256, 1024],
        segment_size=8192,
        hop_length=256,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.window_size = window_size
        self.use_bias = use_bias
        self.ffn_dim = ffn_dim
        self.layerdrop = layerdrop
        self.ffn_kernel_size = ffn_kernel_size
        self.flow_size = flow_size
        self.spectrogram_bins = spectrogram_bins
        self.hidden_act = hidden_act
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_stochastic_duration_prediction = use_stochastic_duration_prediction
        self.num_speakers = num_speakers
        self.speaker_embedding_size = speaker_embedding_size
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.leaky_relu_slope = leaky_relu_slope
        self.depth_separable_channels = depth_separable_channels
        self.depth_separable_num_layers = depth_separable_num_layers
        self.duration_predictor_flow_bins = duration_predictor_flow_bins
        self.duration_predictor_tail_bound = duration_predictor_tail_bound
        self.duration_predictor_kernel_size = duration_predictor_kernel_size
        self.duration_predictor_dropout = duration_predictor_dropout
        self.duration_predictor_num_flows = duration_predictor_num_flows
        self.duration_predictor_filter_channels = duration_predictor_filter_channels
        self.prior_encoder_num_flows = prior_encoder_num_flows
        self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers
        self.posterior_encoder_num_wavenet_layers = posterior_encoder_num_wavenet_layers
        self.wavenet_kernel_size = wavenet_kernel_size
        self.wavenet_dilation_rate = wavenet_dilation_rate
        self.wavenet_dropout = wavenet_dropout
        self.speaking_rate = speaking_rate
        self.noise_scale = noise_scale
        self.noise_scale_duration = noise_scale_duration
        self.sampling_rate = sampling_rate

        # used for training
        self.discriminator_kernel_size = discriminator_kernel_size
        self.discriminator_stride = discriminator_stride
        self.discriminator_periods = discriminator_periods
        self.discriminator_period_channels = discriminator_period_channels
        self.discriminator_scale_channels = discriminator_scale_channels
        self.segment_size = segment_size
        self.hop_length = hop_length

        if len(upsample_kernel_sizes) != len(upsample_rates):
            raise ValueError(
                f"The length of `upsample_kernel_sizes` ({len(upsample_kernel_sizes)}) must match the length of "
                f"`upsample_rates` ({len(upsample_rates)})"
            )

        super().__init__(**kwargs)


#.............................................................................................


class VitsPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = VitsConfig
    base_model_prefix = "vits"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


#.............................................................................................
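

# Illustrative sketch (not part of the original file): how `VitsConfig` is typically
# instantiated, with a couple of fields overridden. The override values below are placeholders
# chosen for the example, not recommended settings.
def _demo_vits_config():
    config = VitsConfig(sampling_rate=22_050, num_speakers=4, speaker_embedding_size=256)
    print(config.hidden_size)  # 192, the default hidden width
    print(config.sampling_rate)  # 22050, as overridden above

    # The constructor rejects mismatched upsampling settings before the config is usable.
    try:
        VitsConfig(upsample_rates=[8, 8, 2], upsample_kernel_sizes=[16, 16, 4, 4])
    except ValueError as err:
        print(err)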
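

# Illustrative sketch (not part of the original file): `_TinyVitsModule` is a hypothetical
# subclass showing how `VitsPreTrainedModel._init_weights` gets applied. `post_init()`,
# inherited from `PreTrainedModel`, walks every submodule and calls `_init_weights` on it, so
# the embedding gets a normal init and the 1D convolution a Kaiming-normal init, as defined above.
class _TinyVitsModule(VitsPreTrainedModel):
    def __init__(self, config: VitsConfig):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.conv = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, padding=1)
        self.post_init()  # triggers the `_init_weights` pass over `embed` and `conv`


def _demo_weight_init():
    module = _TinyVitsModule(VitsConfig())
    # Embedding weights follow N(0, initializer_range**2) per `_init_weights`.
    print(module.embed.weight.std().item())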