# Source: tacotron2-gst-en / Decoder.py
# Uploaded by mireiafarrus ("tacotron2 and hifigan upload", commit af7ac2b, 17.6 kB).
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from nn_layers import linear_module, location_layer
from utils import get_mask_from_lengths
# Seed PyTorch's global random number generator. This runs as a side effect of
# importing the module, before any layer in this file is constructed.
torch.manual_seed(1234)
class AttentionNet(nn.Module):
    """Additive attention that also looks at previous attention weights.

    The score for every encoder position is ``v(tanh(q + loc + mem))`` where
    ``q`` is the projected query, ``loc`` the projected location features and
    ``mem`` the projected encoder outputs. Scores are normalised with a
    softmax over the encoder positions.
    """

    # Typical sizes noted by the original author: 1024, 512, 128, 32, 31
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        """
        PARAMS
        ------
        attention_rnn_dim: input width of the query projection
        embedding_dim: input width of the memory (encoder output) projection
        attention_dim: width of the shared hidden space the scores are computed in
        attention_location_n_filters: forwarded to ``location_layer``
        attention_location_kernel_size: forwarded to ``location_layer``
        """
        super().__init__()
        # Projects the query into the attention hidden space.
        self.query_layer = linear_module(attention_rnn_dim, attention_dim,
                                         bias=False, w_init_gain='tanh')
        # Projects the encoder outputs into the attention hidden space.
        self.memory_layer = linear_module(embedding_dim, attention_dim,
                                          bias=False, w_init_gain='tanh')
        # Collapses the hidden representation to one scalar score per position.
        self.v = linear_module(attention_dim, 1, bias=False)
        # Turns previous/cumulative attention weights into location features.
        self.location_layer = location_layer(attention_location_n_filters,
                                             attention_location_kernel_size,
                                             attention_dim)
        # Masked positions get -inf so the softmax assigns them zero weight.
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """Compute the unnormalised attention score of every encoder position.

        PARAMS
        ------
        query: attention rnn hidden state (batch, attention_rnn_dim)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """
        # Insert a time axis so the query broadcasts over all encoder positions.
        processed_query = self.query_layer(query.unsqueeze(1))
        # NOTE(review): assumes location_layer returns (B, max_time, attention_dim)
        # so that the three terms can be summed - confirm in nn_layers.
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))
        # Drop the trailing singleton produced by the 1-unit projection.
        return energies.squeeze(-1)

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """Return the attention context vector and the attention weights.

        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: boolean mask, True at positions that must receive zero weight
            (padding); None disables masking

        RETURNS
        -------
        attention_context: weighted sum of ``memory`` over the encoder positions
        attention_weights: softmax-normalised scores (batch, max_time)
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            # Out-of-place fill keeps the operation visible to autograd instead
            # of mutating the tensor behind its back through ``.data``.
            alignment = alignment.masked_fill(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        # (B, 1, max_time) x (B, max_time, embedding_dim) -> (B, 1, embedding_dim)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)
        return attention_context, attention_weights
class Prenet(nn.Module):
    """Stack of bias-free linear layers, each followed by ReLU and dropout.

    Dropout (p=0.5) is applied on every call regardless of the module's
    train/eval mode, because ``training=True`` is passed explicitly.
    """

    def __init__(self, in_dim, sizes):
        """
        PARAMS
        ------
        in_dim: width of the input features
        sizes: list with the output width of every layer, in order
        """
        super().__init__()
        # Layer k consumes the output of layer k-1; the first consumes in_dim
        # (e.g. in_dim=80, sizes=[256, 256] -> input widths [80, 256]).
        input_widths = [in_dim] + sizes[:-1]
        stack = []
        for fan_in, fan_out in zip(input_widths, sizes):
            stack.append(linear_module(fan_in, fan_out, bias=False))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        """Run ``x`` through every layer and return the final activations."""
        hidden = x
        for layer in self.layers:
            activated = F.relu(layer(hidden))
            hidden = F.dropout(activated, p=0.5, training=True)
        return hidden
class Decoder(nn.Module):
    """Attention-based recurrent decoder that emits mel frames step by step.

    One decoding step runs, in order: the attention LSTM cell, the attention
    layer, the decoder LSTM cell and two linear heads (mel-frame projection and
    stop-gate energy). ``forward`` drives the loop with teacher forcing;
    ``inference`` feeds every predicted frame back in as the next input.

    The recurrent state is stored on the instance between steps, so a single
    instance must not decode two sequences concurrently.
    """

    def __init__(self, tacotron_hyperparams):
        """
        PARAMS
        ------
        tacotron_hyperparams: mapping read with the keys 'n_mel_channels',
            'number_frames_step', 'encoder_embedding_dim', 'attention_rnn_dim',
            'decoder_rnn_dim', 'prenet_dim', 'max_decoder_steps',
            'gate_threshold', 'p_attention_dropout', 'p_decoder_dropout',
            'attention_dim', 'attention_location_n_filters' and
            'attention_location_kernel_size'
        """
        super().__init__()
        hp = tacotron_hyperparams
        self.n_mel_channels = hp['n_mel_channels']
        self.n_frames_per_step = hp['number_frames_step']
        self.encoder_embedding_dim = hp['encoder_embedding_dim']
        self.attention_rnn_dim = hp['attention_rnn_dim']
        self.decoder_rnn_dim = hp['decoder_rnn_dim']
        self.prenet_dim = hp['prenet_dim']
        # Upper bound on the number of steps taken by ``inference``.
        self.max_decoder_steps = hp['max_decoder_steps']
        # ``inference`` stops once sigmoid(gate energy) exceeds this value.
        self.gate_threshold = hp['gate_threshold']
        self.p_attention_dropout = hp['p_attention_dropout']
        self.p_decoder_dropout = hp['p_decoder_dropout']

        # Width of one decoder step: all mel channels of every frame in the step.
        frame_dim = hp['n_mel_channels'] * hp['number_frames_step']
        # Width of [rnn hidden state ; attention context], fed to both heads.
        head_in_dim = hp['decoder_rnn_dim'] + hp['encoder_embedding_dim']

        # Two-layer prenet applied to the previous frame(s).
        self.prenet = Prenet(frame_dim, [hp['prenet_dim'], hp['prenet_dim']])
        # Input: prenet output + attention context; hidden: attention_rnn_dim.
        self.attention_rnn = nn.LSTMCell(
            hp['prenet_dim'] + hp['encoder_embedding_dim'],
            hp['attention_rnn_dim'])
        # Produces the attention context and the attention weights.
        self.attention_layer = AttentionNet(
            hp['attention_rnn_dim'], hp['encoder_embedding_dim'],
            hp['attention_dim'], hp['attention_location_n_filters'],
            hp['attention_location_kernel_size'])
        # Input: attention rnn hidden state + attention context;
        # hidden: decoder_rnn_dim. The original passed ``1`` positionally for
        # ``bias``; spelled out here as the equivalent ``bias=True``.
        self.decoder_rnn = nn.LSTMCell(
            hp['attention_rnn_dim'] + hp['encoder_embedding_dim'],
            hp['decoder_rnn_dim'], bias=True)
        # Projects [decoder hidden ; context] onto the mel frame(s) of one step.
        self.linear_projection = linear_module(head_in_dim, frame_dim)
        # Projects [decoder hidden ; context] onto a single stop-gate energy.
        self.gate_layer = linear_module(
            head_in_dim, 1, bias=True, w_init_gain='sigmoid')

    def get_go_frame(self, memory):
        """ Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: encoder outputs; only the batch size, dtype and device are used

        RETURNS
        -------
        decoder_input: all zeros frames (B, n_mel_channels * n_frames_per_step)
        """
        return memory.new_zeros(
            memory.size(0), self.n_mel_channels * self.n_frames_per_step)

    def initialize_decoder_states(self, memory, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        # All recurrent state starts at zero, on memory's dtype and device.
        self.attention_hidden = memory.new_zeros(B, self.attention_rnn_dim)
        self.attention_cell = memory.new_zeros(B, self.attention_rnn_dim)

        self.decoder_hidden = memory.new_zeros(B, self.decoder_rnn_dim)
        self.decoder_cell = memory.new_zeros(B, self.decoder_rnn_dim)

        self.attention_weights = memory.new_zeros(B, MAX_TIME)
        self.attention_weights_cum = memory.new_zeros(B, MAX_TIME)
        self.attention_context = memory.new_zeros(B, self.encoder_embedding_dim)

        self.memory = memory
        # The memory projection does not change between steps: compute it once.
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
        PARAMS
        ------
        decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
            (B, n_mel_channels, T_out)

        RETURNS
        -------
        inputs: processed decoder inputs
            (T_out / n_frames_per_step, B, n_mel_channels * n_frames_per_step)
        """
        # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
        decoder_inputs = decoder_inputs.transpose(1, 2)
        # Group n_frames_per_step consecutive frames into one decoder step.
        # reshape (not view) because the transpose above leaves the tensor
        # non-contiguous, which view cannot regroup when n_frames_per_step > 1.
        decoder_inputs = decoder_inputs.reshape(
            decoder_inputs.size(0),
            decoder_inputs.size(1) // self.n_frames_per_step, -1)
        # (B, steps, step_dim) -> (steps, B, step_dim)
        decoder_inputs = decoder_inputs.transpose(0, 1)
        return decoder_inputs

    def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
        """ Prepares decoder outputs for output
        PARAMS
        ------
        mel_outputs: list with one mel tensor per decoder step
        gate_outputs: list with one gate output energy tensor per decoder step
        alignments: list with one attention weight tensor per decoder step

        RETURNS
        -------
        mel_outputs: (B, n_mel_channels, T_out)
        gate_outputs: gate output energies, batch first
        alignments: attention weights, batch first
        """
        # (steps, B, T_in) -> (B, steps, T_in)
        alignments = torch.stack(alignments).transpose(0, 1)
        # (steps, B, ...) -> (B, steps, ...)
        gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
        gate_outputs = gate_outputs.contiguous()
        # (steps, B, step_dim) -> (B, steps, step_dim)
        mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
        # decouple frames per step
        mel_outputs = mel_outputs.view(
            mel_outputs.size(0), -1, self.n_mel_channels)
        # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
        mel_outputs = mel_outputs.transpose(1, 2)

        return mel_outputs, gate_outputs, alignments

    def decode(self, decoder_input):
        """ Decoder step using stored states, attention and memory
        PARAMS
        ------
        decoder_input: previous mel output after it went through the prenet

        RETURNS
        -------
        mel_output: projected mel frame(s) for this step
        gate_output: gate output energies
        attention_weights: attention weights used in this step
        """
        # Attention rnn input: prenet features joined with the previous context.
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.attention_hidden, self.attention_cell = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
        self.attention_hidden = F.dropout(
            self.attention_hidden, self.p_attention_dropout, self.training)
        self.attention_cell = F.dropout(
            self.attention_cell, self.p_attention_dropout, self.training)

        # The attention layer sees the previous and the accumulated weights.
        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(
            self.attention_hidden, self.memory, self.processed_memory,
            attention_weights_cat, self.mask)
        # Accumulate the new weights for the next step.
        self.attention_weights_cum += self.attention_weights

        # Decoder rnn input: attention rnn hidden state joined with the new context.
        decoder_input = torch.cat(
            (self.attention_hidden, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
        self.decoder_hidden = F.dropout(
            self.decoder_hidden, self.p_decoder_dropout, self.training)
        self.decoder_cell = F.dropout(
            self.decoder_cell, self.p_decoder_dropout, self.training)

        # Both heads read the decoder hidden state joined with the context.
        decoder_hidden_attention_context = torch.cat(
            (self.decoder_hidden, self.attention_context), dim=1)
        decoder_output = self.linear_projection(decoder_hidden_attention_context)
        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
        return decoder_output, gate_prediction, self.attention_weights

    def forward(self, memory, decoder_inputs, memory_lengths):
        """ Decoder forward pass for training
        PARAMS
        ------
        memory: Encoder outputs
        decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
        memory_lengths: Encoder output lengths for attention masking.

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder
        alignments: sequence of attention weights from the decoder
        """
        go_frame = self.get_go_frame(memory).unsqueeze(0)
        decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
        # Prepend the all-zero go frame so step t is conditioned on frame t-1.
        decoder_inputs = torch.cat((go_frame, decoder_inputs), dim=0)
        decoder_inputs = self.prenet(decoder_inputs)

        # presumably get_mask_from_lengths is True at valid positions, so the
        # inverted mask marks padding - verify against utils.
        self.initialize_decoder_states(
            memory, mask=~get_mask_from_lengths(memory_lengths))

        mel_outputs, gate_outputs, alignments = [], [], []
        # Teacher forcing: every step is fed the reference frame, not the
        # prediction. The last reference frame is never used as an input.
        n_steps = decoder_inputs.size(0) - 1
        for step in range(n_steps):
            mel_output, gate_output, attention_weights = self.decode(
                decoder_inputs[step])
            mel_outputs.append(mel_output)
            # squeeze(1), not squeeze(): a bare squeeze also drops the batch
            # axis when B == 1 and breaks the transpose in parse_decoder_outputs.
            gate_outputs.append(gate_output.squeeze(1))
            alignments.append(attention_weights)

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments

    def inference(self, memory):
        """ Decoder inference
        PARAMS
        ------
        memory: Encoder outputs

        RETURNS
        -------
        mel_outputs: mel outputs from the decoder
        gate_outputs: gate outputs from the decoder (keeps the trailing
            singleton axis of the gate layer, unlike ``forward``)
        alignments: sequence of attention weights from the decoder
        """
        decoder_input = self.get_go_frame(memory)

        self.initialize_decoder_states(memory, mask=None)

        mel_outputs, gate_outputs, alignments = [], [], []
        while True:
            decoder_input = self.prenet(decoder_input)
            mel_output, gate_output, alignment = self.decode(decoder_input)

            mel_outputs.append(mel_output)
            gate_outputs.append(gate_output)
            alignments.append(alignment)

            # The comparison result is used as a Python bool, so gate_output
            # must hold a single element (one utterance per call).
            if torch.sigmoid(gate_output.detach()) > self.gate_threshold:
                break
            elif len(mel_outputs) == self.max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break

            # Feed the prediction back in as the next step's input.
            decoder_input = mel_output

        mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
            mel_outputs, gate_outputs, alignments)

        return mel_outputs, gate_outputs, alignments