File size: 4,829 Bytes
af7ac2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from math import sqrt

import torch
from torch import nn

from Encoder import Encoder
from Decoder import Decoder
from Postnet import Postnet
from GST import GST

from utils import to_gpu, get_mask_from_lengths
from fp16_optimizer import fp32_to_fp16, fp16_to_fp32

# Fix the global RNG seed at import time so weight initialization (e.g. the
# embedding uniform_ below) is reproducible across runs.
torch.manual_seed(1234)


class tacotron_2(nn.Module):
    """Tacotron-2 text-to-mel model conditioned on a Global Style Token (GST) module.

    The model embeds a character sequence, encodes it, adds a style embedding
    produced by the GST module to every encoder state, and decodes mel
    spectrogram frames plus stop-gate energies; a postnet adds a residual
    refinement on top of the decoder mels.
    """

    def __init__(self, tacotron_hyperparams):
        super(tacotron_2, self).__init__()
        hp = tacotron_hyperparams
        self.mask_padding = hp['mask_padding']
        self.fp16_run = hp['fp16_run']
        self.n_mel_channels = hp['n_mel_channels']
        self.n_frames_per_step = hp['number_frames_step']
        self.embedding = nn.Embedding(
            hp['n_symbols'], hp['symbols_embedding_length'])
        # Xavier-uniform-style init for the symbol embedding table:
        # uniform in [-sqrt(3)*std, sqrt(3)*std] with std = sqrt(2/(fan_in+fan_out)).
        embed_std = sqrt(2.0 / (hp['n_symbols'] + hp['symbols_embedding_length']))
        bound = sqrt(3.0) * embed_std
        self.embedding.weight.data.uniform_(-bound, bound)
        self.encoder = Encoder(hp)
        self.decoder = Decoder(hp)
        self.postnet = Postnet(hp)
        self.gst = GST(hp)

    def parse_batch(self, batch):
        """Move one training batch to the GPU and split it into (inputs, targets).

        The batch carries an extra padded prosody-feature tensor used to train
        the GST tokens.
        """
        (text_padded, input_lengths, mel_padded, gate_padded,
         output_lengths, prosody_padded) = batch
        # Longest input text in the batch, extracted as a plain Python int.
        max_len = int(torch.max(input_lengths.data).item())
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()
        prosody_padded = to_gpu(prosody_padded).float()

        model_inputs = (text_padded, input_lengths, mel_padded, max_len,
                        output_lengths, prosody_padded)
        targets = (mel_padded, gate_padded)
        return (model_inputs, targets)

    def parse_input(self, inputs):
        """Cast inputs to half precision when running in fp16 mode; no-op otherwise."""
        if self.fp16_run:
            return fp32_to_fp16(inputs)
        return inputs

    def parse_output(self, outputs, output_lengths=None):
        """Mask padded output frames in place and cast back to fp32 if needed.

        Mel frames beyond each sequence's true length are zeroed and the
        corresponding gate energies are pushed to a large positive value (1e3)
        so the stop gate saturates on padding.
        """
        if self.mask_padding and output_lengths is not None:
            # True where a frame is padding; replicate across mel channels,
            # then reorder to (batch, n_mel_channels, time).
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)            # decoder mels
            outputs[1].data.masked_fill_(mask, 0.0)            # postnet mels
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)   # gate energies

        if self.fp16_run:
            outputs = fp16_to_fp32(outputs)
        return outputs

    def forward(self, inputs):
        """Teacher-forced training pass; returns [mels, postnet mels, gates, alignments, gst_scores]."""
        (text_inputs, input_lengths, targets, max_len,
         output_lengths, gst_prosody_padded) = self.parse_input(inputs)
        input_lengths, output_lengths = input_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
        encoder_outputs = self.encoder(embedded_inputs, input_lengths)

        # Condition every encoder state on the GST style embedding
        # (broadcast over the time axis before the additive combination).
        gst_style_embedding, gst_scores = self.gst(gst_prosody_padded, output_lengths)
        encoder_outputs = encoder_outputs + gst_style_embedding.expand_as(encoder_outputs)

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, targets, memory_lengths=input_lengths)
        # Postnet predicts a residual on top of the decoder mels.
        mel_outputs_postnet = mel_outputs + self.postnet(mel_outputs)

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments, gst_scores],
            output_lengths)

    def inference(self, inputs, gst_scores):
        """Synthesis pass: gst_scores (a torch tensor) selects the style tokens directly."""
        inputs = self.parse_input(inputs)
        embedded = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded)

        # Build the style embedding from externally supplied token scores.
        style_embedding = self.gst.inference(gst_scores)
        encoder_outputs = encoder_outputs + style_embedding.expand_as(encoder_outputs)

        mel_outputs, gate_outputs, alignments = self.decoder.inference(
            encoder_outputs)
        mel_outputs_postnet = mel_outputs + self.postnet(mel_outputs)

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])