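# NOTE(review): the three lines below look like web-page residue captured with
# the file (avatar caption, commit message, commit hash); they are not code.
# Rongjiehuang's picture
# update
# 222619b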
from data_gen.tts.emotion.params_model import *
from data_gen.tts.emotion.params_data import *
from torch.nn.utils import clip_grad_norm_
from scipy.optimize import brentq
from torch import nn
import numpy as np
import torch
class EmotionEncoder(nn.Module):
    """
    LSTM encoder that turns batches of mel-spectrogram utterances into embeddings.

    The network is a multi-layer LSTM followed by a linear projection and a ReLU.
    ``forward`` returns the projected, L2-normalized embedding; ``inference``
    returns the raw final hidden state of the last LSTM layer.

    The layer sizes (``mel_n_channels``, ``model_hidden_size``, ``model_num_layers``,
    ``model_embedding_size``) are module-level names that are not defined in this
    class; presumably they come from the wildcard parameter imports at the top of
    the file.

    ``similarity_weight``, ``similarity_bias`` and ``loss_fn`` are created here but
    are not used by any method defined in this class.
    """

    # Lower bound applied to the L2 norm before dividing, so that an embedding
    # whose components were all zeroed by the ReLU yields zeros instead of NaN.
    _NORM_EPS = 1e-12

    def __init__(self, device, loss_device):
        """
        Builds the network layers and the similarity/loss helpers.

        :param device: device the LSTM, linear and ReLU modules are moved to
        :param loss_device: device the similarity scaling parameters and the loss
        function are placed on. It is also stored as ``self.loss_device``.
        """
        super().__init__()
        self.loss_device = loss_device

        # Network definition
        self.lstm = nn.LSTM(input_size=mel_n_channels,
                            hidden_size=model_hidden_size,
                            num_layers=model_num_layers,
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values).
        # The tensor is moved to loss_device BEFORE being wrapped in nn.Parameter:
        # calling .to() on a Parameter returns a plain non-leaf tensor whenever a
        # copy is made, which would leave these values unregistered (absent from
        # parameters() / state_dict()) and without a .grad.
        self.similarity_weight = nn.Parameter(torch.tensor([10.]).to(loss_device))
        self.similarity_bias = nn.Parameter(torch.tensor([-5.]).to(loss_device))

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        """
        Post-processes gradients in place; meant to be called between
        ``backward()`` and the optimizer step.

        The gradients of the two similarity scaling parameters are multiplied by
        0.01, then the global L2 norm of all parameter gradients is clipped to 3.
        """
        # Gradient scale. A parameter that did not take part in the last backward
        # pass has grad None; skip it instead of failing on `None *= 0.01`.
        for param in (self.similarity_weight, self.similarity_bias):
            if param.grad is not None:
                param.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial LSTM state, forwarded unchanged to ``nn.LSTM``, which
        expects a tuple ``(h_0, c_0)`` of tensors of shape (num_layers, batch_size,
        hidden_size). Defaults to zeros if None.
        :return: the L2-normalized embeddings as a tensor of shape (batch_size,
        embedding_size). A row whose projection is entirely zero after the ReLU is
        returned as zeros rather than NaN.
        """
        # Only the final hidden state is needed; the per-frame outputs and the
        # final cell state are discarded.
        _, (hidden, _) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it. The norm is clamped from below so that an all-zero row
        # does not divide by zero; any other row is divided by its exact norm.
        norms = torch.norm(embeds_raw, dim=1, keepdim=True).clamp_min(self._NORM_EPS)
        return embeds_raw / norms

    def inference(self, utterances, hidden_init=None):
        """
        Returns the raw final hidden state of the last LSTM layer for a batch of
        utterance spectrograms. Unlike ``forward``, no linear projection, ReLU or
        L2 normalization is applied.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial LSTM state, forwarded unchanged to ``nn.LSTM``, which
        expects a tuple ``(h_0, c_0)`` of tensors of shape (num_layers, batch_size,
        hidden_size). Defaults to zeros if None.
        :return: the last layer's final hidden state as a tensor of shape (batch_size,
        hidden_size)
        """
        # Only the final hidden state is needed; the per-frame outputs and the
        # final cell state are discarded.
        _, (hidden, _) = self.lstm(utterances, hidden_init)
        return hidden[-1]