# ONNXServies / VitsModelSplit / vits_models_only_decoder.py
# Committed by wasmdashai: "Update VitsModelSplit/vits_models_only_decoder.py"
# Commit 3e22085 (verified)
import numpy as np
import torch
from torch import nn
import math
from typing import Any, Callable, Optional, Tuple, Union
from torch.cuda.amp import autocast, GradScaler
from .vits_config import VitsConfig,VitsPreTrainedModel
from .flow import VitsResidualCouplingBlock
from .duration_predictor import VitsDurationPredictor, VitsStochasticDurationPredictor
from .encoder import VitsTextEncoder
from .decoder import VitsHifiGan
from .posterior_encoder import VitsPosteriorEncoder
from .discriminator import VitsDiscriminator
from .vits_output import VitsModelOutput, VitsTrainingOutput
# Name of the configuration class, kept as a string for documentation purposes.
_CONFIG_FOR_DOC = "VitsConfig"
# Shared docstring fragment: generic description of the model class (what it
# inherits from and the `config` constructor parameter). This is a runtime
# string, so its text is kept exactly as-is.
VITS_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`VitsConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared docstring fragment: description of the arguments accepted by the
# model's `forward` method (input ids, attention mask, speaker id and the
# output-control flags). Also a runtime string; text kept exactly as-is.
VITS_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
speaker_id (`int`, *optional*):
Which speaker embedding to use. Only used for multispeaker models.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class Vits_models_only_decoder(VitsPreTrainedModel):
    """VITS text-to-speech model whose forward pass stops before the vocoder.

    The constructor builds the text encoder, the residual-coupling flow, the
    HiFi-GAN decoder, a (stochastic or deterministic) duration predictor and,
    for multi-speaker configurations, a speaker embedding table.

    Unlike a full VITS forward pass, `forward` returns the masked output of
    the reversed flow (called ``spectrogram`` in the code) and never calls
    ``self.decoder``; turning that tensor into a waveform is left to the
    caller.

    Attributes:
        speaking_rate: Copied from ``config.speaking_rate``; durations are
            scaled by ``1 / speaking_rate``.
        noise_scale: Copied from ``config.noise_scale``; scales the random
            noise added to the expanded prior.
        noise_scale_duration: Copied from ``config.noise_scale_duration``;
            passed to the stochastic duration predictor.
    """

    def __init__(self, config: VitsConfig):
        """Instantiate all sub-modules described by ``config``.

        Args:
            config: Model configuration. Reads
                ``use_stochastic_duration_prediction``, ``num_speakers``,
                ``speaker_embedding_size``, ``speaking_rate``, ``noise_scale``
                and ``noise_scale_duration``; each sub-module also receives
                the whole config.
        """
        super().__init__(config)
        self.config = config
        self.text_encoder = VitsTextEncoder(config)
        self.flow = VitsResidualCouplingBlock(config)
        self.decoder = VitsHifiGan(config)

        if config.use_stochastic_duration_prediction:
            self.duration_predictor = VitsStochasticDurationPredictor(config)
        else:
            self.duration_predictor = VitsDurationPredictor(config)

        # The speaker table only exists for multi-speaker configurations.
        if config.num_speakers > 1:
            self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)

        # The posterior encoder is used only for training, so it is not built:
        # self.posterior_encoder = VitsPosteriorEncoder(config)

        # These parameters control the synthesised speech properties
        self.speaking_rate = config.speaking_rate
        self.noise_scale = config.noise_scale
        self.noise_scale_duration = config.noise_scale_duration

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        """Return the text encoder sub-module."""
        return self.text_encoder

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        speaker_id: Optional[Union[int, torch.Tensor]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        r"""Run text encoding, duration prediction and the reversed flow.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Required.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                1 for tokens that are not masked, 0 for padding. When omitted,
                every position is treated as valid.
            speaker_id (`int` or `torch.Tensor`, *optional*):
                Which speaker embedding to use. Only used when
                `config.num_speakers > 1`; ignored otherwise.
            output_attentions (`bool`, *optional*):
                Forwarded to the text encoder. Defaults to
                `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Forwarded to the text encoder. Defaults to
                `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Forwarded to the text encoder and used to decide how its
                output is read (attributes vs. tuple indices). Defaults to
                `config.use_return_dict`. It does not change what this method
                returns.
            labels (`torch.FloatTensor`, *optional*):
                Must be `None`; training is not supported.

        Returns:
            `torch.Tensor`: the output of the reversed flow multiplied by the
            output padding mask. Presumably of shape
            `(batch_size, channels, output_length)`, assuming the flow
            preserves the shape of its input; `output_length` is the largest
            predicted length in the batch.

        Raises:
            NotImplementedError: If `labels` is provided.
            ValueError: If `input_ids` is `None`, or if `speaker_id` is
                outside `[0, config.num_speakers)` on a multi-speaker model.

        Example:

        ```python
        >>> import torch
        >>> with torch.no_grad():
        ...     spectrogram = model(inputs["input_ids"])
        >>> spectrogram.dim()
        3
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            raise NotImplementedError("Training of VITS is not supported yet.")

        # Fail early with a clear message instead of an opaque error from
        # `torch.ones_like(None)` or from inside the text encoder.
        if input_ids is None:
            raise ValueError("`input_ids` must be provided.")

        # (batch, seq) -> (batch, seq, 1) float mask; all ones if no mask given.
        if attention_mask is not None:
            input_padding_mask = attention_mask.unsqueeze(-1).float()
        else:
            input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).float()

        if self.config.num_speakers > 1 and speaker_id is not None:
            if not 0 <= speaker_id < self.config.num_speakers:
                raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
            if isinstance(speaker_id, int):
                speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
            # Trailing singleton axis added for the downstream modules.
            speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
        else:
            speaker_embeddings = None

        text_encoder_output = self.text_encoder(
            input_ids=input_ids,
            padding_mask=input_padding_mask,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # The encoder output is either a tuple or an output object depending
        # on `return_dict`; read the three needed fields accordingly.
        if return_dict:
            hidden_states = text_encoder_output.last_hidden_state
            prior_means = text_encoder_output.prior_means
            prior_log_variances = text_encoder_output.prior_log_variances
        else:
            hidden_states = text_encoder_output[0]
            prior_means = text_encoder_output[1]
            prior_log_variances = text_encoder_output[2]

        # Move the sequence axis last: the mask becomes (batch, 1, seq).
        hidden_states = hidden_states.transpose(1, 2)
        input_padding_mask = input_padding_mask.transpose(1, 2)

        if self.config.use_stochastic_duration_prediction:
            log_duration = self.duration_predictor(
                hidden_states,
                input_padding_mask,
                speaker_embeddings,
                reverse=True,
                noise_scale=self.noise_scale_duration,
            )
        else:
            log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)

        # A higher speaking rate shortens every duration; padded positions
        # are zeroed by the mask before rounding up to whole frames.
        length_scale = 1.0 / self.speaking_rate
        duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
        # Total number of output frames per batch item, at least 1.
        predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()

        # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
        indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
        output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
        output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)

        # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
        attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
        batch_size, _, output_length, input_length = attn_mask.shape
        # Cumulative end frame of each input token, one row per (batch, token).
        cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
        indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
        # valid_indices[b, i, t] == 1 while frame t lies before the end of token i.
        valid_indices = indices.unsqueeze(0) < cum_duration
        valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
        # Subtracting the same tensor shifted by one token leaves a 1 only on
        # the frames that belong to token i itself (a hard, monotonic alignment).
        padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
        attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask

        # Expand prior distribution
        prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
        prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)

        # Sample from the expanded prior; `noise_scale` controls the randomness.
        prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
        latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)

        spectrogram = latents * output_padding_mask

        # The vocoder stage is not run here: a full VITS pass would continue
        # with `self.decoder(spectrogram, speaker_embeddings)` and scale
        # `predicted_lengths` by the product of `config.upsample_rates` to
        # obtain waveform lengths. Callers that need audio must do that step.
        return spectrogram