import torch
import torch.nn as nn
import torch.nn.functional as F


class GST(nn.Module):
    """Global Style Token module for factorizing prosody in speech.

    See https://arxiv.org/pdf/1803.09017

    Args:
        num_mel: number of mel channels in the input spectrogram.
        num_heads: number of attention heads used over the style tokens.
        num_style_tokens: number of learnable style tokens.
        gst_embedding_dim: width of the returned style embedding.
        speaker_embedding_dim: optional width of a speaker embedding that
            ``forward`` concatenates to the reference encoding.
    """

    def __init__(self, num_mel, num_heads, num_style_tokens,
                 gst_embedding_dim, speaker_embedding_dim=None):
        super().__init__()
        self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim)
        self.style_token_layer = StyleTokenLayer(
            num_heads, num_style_tokens, gst_embedding_dim,
            speaker_embedding_dim)

    def forward(self, inputs, speaker_embedding=None):
        """Compute style embeddings for a batch of mel spectrograms.

        Args:
            inputs: mel spectrograms ``[batch_size, num_frames, num_mel]``.
            speaker_embedding: optional tensor concatenated to the reference
                encoding along the last axis, i.e.
                ``[batch_size, speaker_embedding_dim]``.

        Returns:
            Style embedding ``[batch_size, 1, gst_embedding_dim]``.
        """
        enc_out = self.encoder(inputs)
        # enc_out: [batch_size, gst_embedding_dim // 2]
        if speaker_embedding is not None:
            enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
        return self.style_token_layer(enc_out)


class ReferenceEncoder(nn.Module):
    """NN module creating a fixed-size prosody embedding from a spectrogram.

    Six 3x3, stride-2 Conv2d + BatchNorm + ReLU layers downsample both the
    frame and mel axes; a GRU then runs over the downsampled frames and its
    final hidden state is returned.

    inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
    outputs: [batch_size, embedding_dim // 2]
    """

    def __init__(self, num_mel, embedding_dim):
        super().__init__()
        self.num_mel = num_mel
        filters = [1] + [32, 32, 64, 64, 128, 128]
        num_layers = len(filters) - 1
        self.convs = nn.ModuleList([
            nn.Conv2d(
                in_channels=filters[i],
                out_channels=filters[i + 1],
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1),
            )
            for i in range(num_layers)
        ])
        self.bns = nn.ModuleList([
            nn.BatchNorm2d(num_features=filter_size)
            for filter_size in filters[1:]
        ])
        # Mel-axis size left after the conv stack; it must use the same
        # kernel/stride/padding as the Conv2d layers above.
        post_conv_height = self.calculate_post_conv_height(
            num_mel, 3, 2, 1, num_layers)
        # hidden_size is embedding_dim // 2, so that is the output width.
        self.recurrence = nn.GRU(
            input_size=filters[-1] * post_conv_height,
            hidden_size=embedding_dim // 2,
            batch_first=True)

    def forward(self, inputs):
        """Encode ``inputs`` into ``[batch_size, embedding_dim // 2]``."""
        batch_size = inputs.size(0)
        # reshape (rather than view) also accepts non-viewable inputs.
        x = inputs.reshape(batch_size, 1, -1, self.num_mel)
        # x: [batch_size, 1, num_frames, num_mel]
        for conv, bn in zip(self.convs, self.bns):
            x = F.relu(bn(conv(x)))
        # x: [batch_size, 128, post_conv_width, post_conv_height]
        x = x.transpose(1, 2)
        # x: [batch_size, post_conv_width, 128, post_conv_height]
        post_conv_width = x.size(1)
        x = x.contiguous().view(batch_size, post_conv_width, -1)
        # x: [batch_size, post_conv_width, 128 * post_conv_height]
        # Compact GRU weights so cuDNN can use its fast path (no-op on CPU).
        self.recurrence.flatten_parameters()
        _, out = self.recurrence(x)
        # out: final hidden state [1, batch_size, embedding_dim // 2]
        return out.squeeze(0)

    @staticmethod
    def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
        """Height of spec after n convolutions with fixed kernel/stride/pad."""
        for _ in range(n_convs):
            height = (height - kernel_size + 2 * pad) // stride + 1
        return height


class StyleTokenLayer(nn.Module):
    """NN module attending to style tokens based on prosody encodings.

    Args:
        num_heads: number of attention heads.
        num_style_tokens: number of learnable style tokens.
        embedding_dim: width of the output style embedding.
        speaker_embedding_dim: optional extra query width for a speaker
            embedding concatenated to the prosody encoding.
    """

    def __init__(self, num_heads, num_style_tokens, embedding_dim,
                 speaker_embedding_dim=None):
        super().__init__()
        # In GST the query is the ReferenceEncoder output (embedding_dim // 2),
        # optionally extended by a concatenated speaker embedding.
        self.query_dim = embedding_dim // 2
        if speaker_embedding_dim:
            self.query_dim += speaker_embedding_dim
        self.key_dim = embedding_dim // num_heads
        # Allocated uninitialized, then filled by normal_ below.
        self.style_tokens = nn.Parameter(
            torch.empty(num_style_tokens, self.key_dim))
        nn.init.normal_(self.style_tokens, mean=0, std=0.5)
        self.attention = MultiHeadAttention(
            query_dim=self.query_dim,
            key_dim=self.key_dim,
            num_units=embedding_dim,
            num_heads=num_heads)

    def forward(self, inputs):
        """Attend from ``inputs`` [batch_size, query_dim] to the style tokens.

        Returns:
            ``[batch_size, 1, embedding_dim]``.
        """
        batch_size = inputs.size(0)
        prosody_encoding = inputs.unsqueeze(1)
        # prosody_encoding: [batch_size, 1, query_dim]
        tokens = torch.tanh(self.style_tokens).unsqueeze(0).expand(
            batch_size, -1, -1)
        # tokens: [batch_size, num_style_tokens, key_dim]
        return self.attention(prosody_encoding, tokens)


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    input:
        query --- [N, T_q, query_dim]
        key   --- [N, T_k, key_dim]   (also projected to form the values)
    output:
        out   --- [N, T_q, num_units]

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``num_units`` (the per-head split would be uneven).
    """

    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        if num_heads <= 0 or num_units % num_heads:
            raise ValueError(
                f"num_units ({num_units}) must be divisible by a positive "
                f"num_heads ({num_heads})")
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.W_query = nn.Linear(
            in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(
            in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(
            in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key):
        split_size = self.num_units // self.num_heads
        queries = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)  # [N, T_k, num_units]

        queries = torch.stack(
            torch.split(queries, split_size, dim=2), dim=0)  # [h, N, T_q, d]
        keys = torch.stack(
            torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, d]
        values = torch.stack(
            torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, d]

        # score = softmax(QK^T / sqrt(d_k)), where d_k is the per-head
        # projected width (split_size). The raw key input width (key_dim)
        # equals it only when key_dim == num_units // num_heads.
        scores = torch.matmul(queries, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (split_size ** 0.5)
        scores = F.softmax(scores, dim=3)

        # out = score * V, then concatenate the heads on the feature axis.
        out = torch.matmul(scores, values)  # [h, N, T_q, d]
        out = torch.cat(
            torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]
        return out