voice-xtts2 / TTS /tts /tf /models /tacotron2.py
antoniomae1234's picture
changes in flenema
2493d72 verified
import tensorflow as tf
from tensorflow import keras
from TTS.tts.tf.layers.tacotron2 import Encoder, Decoder, Postnet
from TTS.tts.tf.utils.tf_utils import shape_list
#pylint: disable=too-many-ancestors, abstract-method
class Tacotron2(keras.models.Model):
    """Tacotron2 model assembled from an embedding, Encoder, Decoder and Postnet.

    The forward pass embeds the input character ids, encodes them, runs the
    decoder over the encoder output, and adds the postnet output to the
    decoder frames as a residual.

    Args:
        num_chars: Size of the input vocabulary; used as the ``Embedding``
            input dimension.
        num_speakers: Stored on the instance; not otherwise used by this class.
        r: Forwarded to ``Decoder`` as its second positional argument
            (presumably the reduction factor -- confirm against ``Decoder``).
        postnet_output_dim: Output dimension passed to ``Postnet``.
        decoder_output_dim: Output dimension passed to ``Decoder``.
        attn_type, attn_win, attn_norm, attn_K, prenet_type, prenet_dropout,
        forward_attn, trans_agent, forward_attn_mask, location_attn,
        separate_stopnet: Forwarded to ``Decoder`` unchanged (some under a
            ``use_``-prefixed keyword).
        bidirectional_decoder: Stored on the instance; not otherwise used by
            this class.
        enable_tflite: Stored on the instance and forwarded to ``Decoder``.
    """

    def __init__(self,
                 num_chars,
                 num_speakers,
                 r,
                 postnet_output_dim=80,
                 decoder_output_dim=80,
                 attn_type='original',
                 attn_win=False,
                 attn_norm="softmax",
                 attn_K=4,
                 prenet_type="original",
                 prenet_dropout=True,
                 forward_attn=False,
                 trans_agent=False,
                 forward_attn_mask=False,
                 location_attn=True,
                 separate_stopnet=True,
                 bidirectional_decoder=False,
                 enable_tflite=False):
        super().__init__()
        self.r = r
        self.decoder_output_dim = decoder_output_dim
        self.postnet_output_dim = postnet_output_dim
        self.bidirectional_decoder = bidirectional_decoder
        self.num_speakers = num_speakers
        self.speaker_embed_dim = 256
        self.enable_tflite = enable_tflite

        self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
        self.encoder = Encoder(512, name='encoder')
        # TODO: most of the decoder args have no use at the moment
        self.decoder = Decoder(decoder_output_dim,
                               r,
                               attn_type=attn_type,
                               use_attn_win=attn_win,
                               attn_norm=attn_norm,
                               prenet_type=prenet_type,
                               prenet_dropout=prenet_dropout,
                               use_forward_attn=forward_attn,
                               use_trans_agent=trans_agent,
                               use_forward_attn_mask=forward_attn_mask,
                               use_location_attn=location_attn,
                               attn_K=attn_K,
                               separate_stopnet=separate_stopnet,
                               speaker_emb_dim=self.speaker_embed_dim,
                               name='decoder',
                               enable_tflite=enable_tflite)
        self.postnet = Postnet(postnet_output_dim, 5, name='postnet')

    @tf.function(experimental_relax_shapes=True)
    def call(self, characters, text_lengths=None, frames=None, training=None):
        """Dispatch to the training or the inference forward pass.

        Args:
            characters: Input character ids.
            text_lengths: Passed to ``training``; ignored for inference.
            frames: Passed to ``training``; ignored for inference.
            training: Truthy selects ``training``. Any falsy value, including
                the default ``None``, selects ``inference``.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``.
        """
        if training:
            return self.training(characters, text_lengths, frames)
        # `None` is falsy, so it lands here as well; `build_inference` calls
        # the model without passing a `training` argument.
        return self.inference(characters)

    def training(self, characters, text_lengths, frames):
        """Run the teacher-forced forward pass with every layer in training mode.

        Args:
            characters: Input character ids; must be rank 2 because its shape
                is unpacked into two values.
            text_lengths: Forwarded to the decoder.
            frames: Forwarded to the decoder.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``
            where ``output_frames`` is ``decoder_frames`` plus the postnet
            output.
        """
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=True)
        encoder_output = self.encoder(embedding_vectors, training=True)
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, frames, text_lengths, training=True)
        postnet_frames = self.postnet(decoder_frames, training=True)
        output_frames = decoder_frames + postnet_frames
        return decoder_frames, output_frames, attentions, stop_tokens

    def _run_inference(self, characters):
        """Shared inference forward pass with every layer in inference mode."""
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=False)
        encoder_output = self.encoder(embedding_vectors, training=False)
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, training=False)
        postnet_frames = self.postnet(decoder_frames, training=False)
        # The postnet output is added to the decoder frames as a residual.
        output_frames = decoder_frames + postnet_frames
        return decoder_frames, output_frames, attentions, stop_tokens

    def inference(self, characters):
        """Run the inference forward pass.

        Args:
            characters: Input character ids; must be rank 2 because its shape
                is unpacked into two values.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``.
        """
        return self._run_inference(characters)

    @tf.function(
        experimental_relax_shapes=True,
        input_signature=[
            tf.TensorSpec([1, None], dtype=tf.int32),
        ],)
    def inference_tflite(self, characters):
        """Run the inference forward pass under a fixed input signature.

        The signature pins the input to an ``int32`` tensor with batch size 1
        and a variable-length second dimension.

        Args:
            characters: Input character ids matching the signature above.

        Returns:
            Tuple ``(decoder_frames, output_frames, attentions, stop_tokens)``.
        """
        return self._run_inference(characters)

    def build_inference(self):
        """Call the model once on a random ``[1, 4]`` int32 input.

        The ids are drawn from ``[0, 10)``. No ``training`` argument is
        passed to the model call.
        """
        # TODO: issue https://github.com/PyCQA/pylint/issues/3613
        input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32)  #pylint: disable=unexpected-keyword-arg
        self(input_ids)