|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
|
|
class GST(nn.Module):
    """Global Style Token module: reference encoder followed by style-token attention.

    The reference encoder summarises the input into a prosody encoding, which
    is then used as the query of an attention over a bank of style tokens.

    See https://arxiv.org/pdf/1803.09017"""

    def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim=None):
        """Build the two sub-modules.

        Args:
            num_mel: forwarded to ``ReferenceEncoder`` as its mel-bin count.
            num_heads: number of attention heads of the style token layer.
            num_style_tokens: number of learnable style tokens.
            gst_embedding_dim: embedding size shared by encoder and token layer.
            embedded_speaker_dim: optional width of the speaker embedding that
                ``forward`` appends to the encoder output; passed to the token
                layer so it can size its query projection accordingly.
        """
        super().__init__()
        self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim)
        self.style_token_layer = StyleTokenLayer(
            num_heads,
            num_style_tokens,
            gst_embedding_dim,
            embedded_speaker_dim,
        )

    def forward(self, inputs, speaker_embedding=None):
        """Return the style embedding for ``inputs``.

        Args:
            inputs: reference input handed unchanged to the reference encoder.
            speaker_embedding: optional tensor concatenated to the encoder
                output along the last dimension before the attention step.
        """
        reference = self.encoder(inputs)

        if speaker_embedding is None:
            query = reference
        else:
            # Condition the attention query on the speaker as well.
            query = torch.cat([reference, speaker_embedding], dim=-1)

        return self.style_token_layer(query)
|
|
|
|
|
class ReferenceEncoder(nn.Module):
    """Compress a mel spectrogram into one fixed-size prosody vector.

    Six stride-2 ``Conv2d -> BatchNorm2d -> ReLU`` stages downsample the
    spectrogram; a GRU then reads the result frame by frame and its final
    hidden state is returned.

    inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
    outputs: [batch_size, embedding_dim // 2]
    """

    def __init__(self, num_mel, embedding_dim):
        """Create the convolution stack and the GRU.

        Args:
            num_mel: number of mel bins per spectrogram frame.
            embedding_dim: the GRU hidden size is ``embedding_dim // 2``.
        """
        super().__init__()
        self.num_mel = num_mel

        # Channel count before the first conv and after each of the six convs.
        channels = [1, 32, 32, 64, 64, 128, 128]
        n_stages = len(channels) - 1

        self.convs = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
                for c_in, c_out in zip(channels[:-1], channels[1:])
            ]
        )
        self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=c_out) for c_out in channels[1:]])

        # The GRU consumes, per frame, all channels times the remaining mel height.
        mel_height = self.calculate_post_conv_height(num_mel, 3, 2, 1, n_stages)
        self.recurrence = nn.GRU(
            input_size=channels[-1] * mel_height,
            hidden_size=embedding_dim // 2,
            batch_first=True,
        )

    def forward(self, inputs):
        """Encode ``inputs`` and return the final GRU hidden state."""
        n_batch = inputs.size(0)

        # Add a singleton channel axis: [batch, 1, frames, num_mel].
        feats = inputs.view(n_batch, 1, -1, self.num_mel)
        for conv, bn in zip(self.convs, self.bns):
            feats = F.relu(bn(conv(feats)))

        # [batch, channels, frames', mels'] -> [batch, frames', channels * mels'].
        feats = feats.transpose(1, 2)
        n_frames = feats.size(1)
        feats = feats.contiguous().view(n_batch, n_frames, -1)

        self.recurrence.flatten_parameters()
        _, hidden = self.recurrence(feats)

        # hidden is [1, batch, hidden_size]; drop the leading layer axis.
        return hidden.squeeze(0)

    @staticmethod
    def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
        """Return ``height`` after ``n_convs`` convolutions sharing kernel/stride/pad."""
        for _ in range(n_convs):
            height = (height + 2 * pad - kernel_size) // stride + 1
        return height
|
|
|
|
|
class StyleTokenLayer(nn.Module):
    """NN Module attending to style tokens based on prosody encodings.

    A bank of ``num_style_tokens`` learnable vectors, each of size
    ``gst_embedding_dim // num_heads``, serves as keys and values of a
    multi-head attention whose single query is the prosody encoding.

    Args:
        num_heads: number of attention heads.
        num_style_tokens: number of learnable style tokens.
        gst_embedding_dim: ``num_units`` of the attention; the base query size
            is ``gst_embedding_dim // 2``.
        d_vector_dim: optional extra width added to the query size. ``None``
            or ``0`` leaves the query size unchanged.
    """

    def __init__(self, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None):
        super().__init__()

        self.query_dim = gst_embedding_dim // 2

        if d_vector_dim:
            self.query_dim += d_vector_dim

        self.key_dim = gst_embedding_dim // num_heads
        # torch.empty follows the default dtype like the nn.Linear layers
        # inside the attention do; the legacy torch.FloatTensor constructor
        # always allocated float32 and so mismatched them under a non-default
        # dtype. The values are overwritten by the normal init right below.
        self.style_tokens = nn.Parameter(torch.empty(num_style_tokens, self.key_dim))
        nn.init.normal_(self.style_tokens, mean=0, std=0.5)
        self.attention = MultiHeadAttention(
            query_dim=self.query_dim, key_dim=self.key_dim, num_units=gst_embedding_dim, num_heads=num_heads
        )

    def forward(self, inputs):
        """Attend over the style tokens with ``inputs`` as the query.

        Args:
            inputs: prosody encodings, one row per batch item; a length-1
                query axis is inserted at dim 1 before attention.

        Returns:
            The attention output for the single query position.
        """
        batch_size = inputs.size(0)
        prosody_encoding = inputs.unsqueeze(1)

        # Squash the tokens with tanh and share the same bank across the batch
        # (expand creates a view, not a copy).
        tokens = torch.tanh(self.style_tokens).unsqueeze(0).expand(batch_size, -1, -1)

        return self.attention(prosody_encoding, tokens)
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention without an output projection.

    Values are projected from the same tensor as the keys.

    input:
        query --- [N, T_q, query_dim]
        key --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]
    """

    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        # Bias-free projections into the shared ``num_units`` space.
        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def _to_heads(self, projected, head_width):
        """Cut the feature axis into ``head_width`` chunks and stack them on a new leading axis."""
        return torch.stack(torch.split(projected, head_width, dim=2), dim=0)

    def forward(self, query, key):
        """Attend from ``query`` over ``key`` and return the concatenated head outputs."""
        head_width = self.num_units // self.num_heads

        # Each becomes [heads, N, T, head_width].
        q_heads = self._to_heads(self.W_query(query), head_width)
        k_heads = self._to_heads(self.W_key(key), head_width)
        v_heads = self._to_heads(self.W_value(key), head_width)

        # Attention weights per head: [heads, N, T_q, T_k].
        # The scale uses the input key size (``key_dim``), not ``head_width``.
        logits = torch.matmul(q_heads, k_heads.transpose(2, 3)) / (self.key_dim**0.5)
        weights = F.softmax(logits, dim=3)

        # Weighted values, then heads laid side by side on the feature axis.
        attended = torch.matmul(weights, v_heads)
        return torch.cat(torch.unbind(attended, dim=0), dim=-1)
|
|