import numpy as np
import torch
from torch import nn
import math
from typing import Any, Callable, Optional, Tuple, Union
from torch.cuda.amp import autocast, GradScaler

from .vits_config import VitsConfig, VitsPreTrainedModel
from .flow import VitsResidualCouplingBlock
from .duration_predictor import VitsDurationPredictor, VitsStochasticDurationPredictor
from .encoder import VitsTextEncoder
from .decoder import VitsHifiGan
from .posterior_encoder import VitsPosteriorEncoder
from .discriminator import VitsDiscriminator
from .vits_output import VitsModelOutput, VitsTrainingOutput

_CONFIG_FOR_DOC = "VitsConfig"

VITS_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config ([`VitsConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VITS_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        speaker_id (`int`, *optional*):
            Which speaker embedding to use. Only used for multispeaker models.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class Vits_models_only_decoder(VitsPreTrainedModel):
    def __init__(self, config: VitsConfig):
        super().__init__(config)
        self.config = config
        self.text_encoder = VitsTextEncoder(config)
        self.flow = VitsResidualCouplingBlock(config)
        self.decoder = VitsHifiGan(config)

        if config.use_stochastic_duration_prediction:
            self.duration_predictor = VitsStochasticDurationPredictor(config)
        else:
            self.duration_predictor = VitsDurationPredictor(config)

        if config.num_speakers > 1:
            self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)

        # This is used only for training.
        # self.posterior_encoder = VitsPosteriorEncoder(config)

        # These parameters control the synthesised speech properties
        self.speaking_rate = config.speaking_rate
        self.noise_scale = config.noise_scale
        self.noise_scale_duration = config.noise_scale_duration

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.text_encoder

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        speaker_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple[Any], VitsModelOutput]:
        r"""
        labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*):
            Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
            computation.

        Returns:

        Example:

        ```python
        >>> from transformers import VitsTokenizer, VitsModel, set_seed
        >>> import torch

        >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
        >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng")

        >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")

        >>> set_seed(555)  # make deterministic

        >>> with torch.no_grad():
        ...     outputs = model(inputs["input_ids"])
        >>> outputs.waveform.shape
        torch.Size([1, 45824])
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            raise NotImplementedError("Training of VITS is not supported yet.")

        if attention_mask is not None:
            input_padding_mask = attention_mask.unsqueeze(-1).float()
        else:
            input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).float()

        if self.config.num_speakers > 1 and speaker_id is not None:
            if not 0 <= speaker_id < self.config.num_speakers:
                raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
            if isinstance(speaker_id, int):
                speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
            speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
        else:
            speaker_embeddings = None
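
        # When a speaker is selected, the embedding is reshaped to (batch, speaker_embedding_size, 1)
        # so it broadcasts along the time axis of the modules it conditions.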
        text_encoder_output = self.text_encoder(
            input_ids=input_ids,
            padding_mask=input_padding_mask,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
        hidden_states = hidden_states.transpose(1, 2)
        input_padding_mask = input_padding_mask.transpose(1, 2)
        prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
        prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances
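        # `prior_means` and `prior_log_variances` parameterise a diagonal Gaussian prior over the
        # latent space, one mean/log-variance pair per input token.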

        if self.config.use_stochastic_duration_prediction:
            log_duration = self.duration_predictor(
                hidden_states,
                input_padding_mask,
                speaker_embeddings,
                reverse=True,
                noise_scale=self.noise_scale_duration,
            )
        else:
            log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
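
        # `log_duration` holds one log-duration per input token; below it is converted into an integer
        # number of output frames. A `speaking_rate` above 1 shrinks the durations and speeds up speech.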
        length_scale = 1.0 / self.speaking_rate
        duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
        predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()

        # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
        indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
        output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
        output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)

        # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
        attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
        batch_size, _, output_length, input_length = attn_mask.shape
        cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
        indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
        valid_indices = indices.unsqueeze(0) < cum_duration
        valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
        padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
        attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
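        # `attn` is a hard monotonic alignment: each output frame is assigned to exactly one input
        # token, and each token covers a contiguous run of frames equal to its predicted duration.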

        # Expand prior distribution
        prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
        prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
        prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
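        # The prior is sampled as mean + noise * exp(log_variance) * noise_scale, so lowering
        # `noise_scale` makes synthesis more deterministic.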

        latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
        spectrogram = latents * output_padding_mask
        return spectrogram
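        # Note: in this stripped-down variant the returned "spectrogram" is the masked output of the
        # inverted flow, i.e. the latents that VitsHifiGan would normally consume; the vocoder call
        # that turns them into a waveform is kept below for reference but commented out.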
        # waveform = self.decoder(spectrogram, speaker_embeddings)
        # waveform = waveform.squeeze(1)
        # sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)
        # if not return_dict:
        #     outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
        #     return outputs
        # return VitsModelOutput(
        #     waveform=waveform,
        #     sequence_lengths=sequence_lengths,
        #     spectrogram=spectrogram,
        #     hidden_states=text_encoder_output.hidden_states,
        #     attentions=text_encoder_output.attentions,
        # )
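

# Minimal usage sketch (not part of the model code). It assumes this module is imported as part of its
# package and that standard MMS-TTS checkpoints (e.g. "facebook/mms-tts-eng") load into this
# decoder-stripped wrapper via `from_pretrained`; both points are assumptions, not guarantees.
#
#     from transformers import VitsTokenizer, set_seed
#     import torch
#
#     tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
#     model = Vits_models_only_decoder.from_pretrained("facebook/mms-tts-eng")
#     inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")
#
#     set_seed(555)  # make sampling deterministic
#     with torch.no_grad():
#         latents = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
#
#     # `latents` holds the masked flow output, roughly (batch, flow_size, num_frames); a HiFi-GAN
#     # vocoder would be needed to turn it into a waveform.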