import torch

from Modules.EmbeddingModel.GST import GSTStyleEncoder
from Modules.EmbeddingModel.StyleTTSEncoder import StyleEncoder as StyleTTSEncoder


class StyleEmbedding(torch.nn.Module):
    """
    The style embedding should provide information about the speaker and their speaking style.

    The feedback signal for the module comes from the TTS objective, so it doesn't have a dedicated train loop.
    The train loop does, however, supply additional supervision in the form of a Barlow Twins objective.

    See the git history for other approaches to style embedding, such as a Swin transformer
    and a simple LSTM baseline. GST turned out to work best.
    """

    def __init__(self, embedding_dim=16, style_tts_encoder=False):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.use_gst = not style_tts_encoder
        if style_tts_encoder:
            self.style_encoder = StyleTTSEncoder(style_dim=embedding_dim)
        else:
            self.style_encoder = GSTStyleEncoder(gst_token_dim=embedding_dim)

    def forward(self,
                batch_of_feature_sequences,
                batch_of_feature_sequence_lengths):
        """
        Args:
            batch_of_feature_sequences: tensor of shape (b, l, 128), where b is the batch axis,
                                        l is the number of time-steps (which may include padding
                                        for most elements in the batch), and there are
                                        128 features per time-step
            batch_of_feature_sequence_lengths: tensor of shape (b, 1) that indicates the true length
                                               of every element in the batch, since they are all
                                               padded to the length of the longest element in the batch
        Returns:
            batch of n-dimensional embeddings (b, n)
        """

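        # every sequence is first cut to its true length and then tiled until it
        # reaches a fixed minimum length, so the style encoder always sees
        # inputs of the same size regardless of how long the utterance is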
        minimum_sequence_length = 512
        specs = list()
        for index, spec_length in enumerate(batch_of_feature_sequence_lengths):
            spec = batch_of_feature_sequences[index][:spec_length]
            # double the length at least once, then check
            spec = spec.repeat((2, 1))
            current_spec_length = len(spec)
            while current_spec_length < minimum_sequence_length:
                # make it longer
                spec = spec.repeat((2, 1))
                current_spec_length = len(spec)
            specs.append(spec[:minimum_sequence_length])

        spec_batch = torch.stack(specs, dim=0)
        return self.style_encoder(speech=spec_batch)


if __name__ == '__main__':
    style_emb = StyleEmbedding(style_tts_encoder=False)
    print(f"GST parameter count: {sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)}")

    seq_length = 398
    print(style_emb(torch.randn(5, seq_length, 512),
                    torch.tensor([seq_length, seq_length, seq_length, seq_length, seq_length])).shape)

    style_emb = StyleEmbedding(style_tts_encoder=True)
    print(f"StyleTTS encoder parameter count: {sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)}")

    seq_length = 398
    print(style_emb(torch.randn(5, seq_length, 512),
                    torch.tensor([seq_length, seq_length, seq_length, seq_length, seq_length])).shape)
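
    # --- illustrative sketch, not part of the actual training code -------------
    # The class docstring mentions that the train loop supervises the embedding
    # with a Barlow Twins objective. The helper below is a minimal, hypothetical
    # sketch of such a loss applied to two embeddings of the same style; the
    # function name and the off-diagonal weight are assumptions, not values
    # taken from this repository.
    def barlow_twins_loss_sketch(view_a, view_b, off_diagonal_weight=0.005):
        # flatten to (batch, features) and normalize each feature over the batch
        view_a = view_a.reshape(view_a.size(0), -1)
        view_b = view_b.reshape(view_b.size(0), -1)
        view_a = (view_a - view_a.mean(0)) / (view_a.std(0) + 1e-6)
        view_b = (view_b - view_b.mean(0)) / (view_b.std(0) + 1e-6)
        # cross-correlation matrix between the two views
        cross_correlation = (view_a.T @ view_b) / view_a.size(0)
        # the diagonal should be close to 1, the off-diagonal close to 0
        on_diagonal = (torch.diagonal(cross_correlation) - 1).pow(2).sum()
        off_diagonal = (cross_correlation - torch.diag(torch.diagonal(cross_correlation))).pow(2).sum()
        return on_diagonal + off_diagonal_weight * off_diagonal

    embedding_view_1 = style_emb(torch.randn(5, seq_length, 512),
                                 torch.tensor([seq_length] * 5))
    embedding_view_2 = style_emb(torch.randn(5, seq_length, 512),
                                 torch.tensor([seq_length] * 5))
    print(f"Barlow Twins sketch loss: {barlow_twins_loss_sketch(embedding_view_1, embedding_view_2).item():.4f}")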