File size: 4,622 Bytes
6faeba1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Code taken from https://github.com/tuanh123789/AdaSpeech/blob/main/model/adaspeech_modules.py
By https://github.com/tuanh123789
No license specified

Implemented as outlined in AdaSpeech https://arxiv.org/pdf/2103.00993.pdf
Used in this toolkit similar to how it is done in AdaSpeech 4 https://arxiv.org/pdf/2204.00436.pdf

"""

import torch
from torch import nn


class ConditionalLayerNorm(nn.Module):
    """
    Layer normalization whose gain and offset are predicted from a speaker embedding.

    The input is standardized along one axis (zero mean, unit variance) and then
    multiplied by a scale and shifted by a bias. Both are produced by small
    Linear -> Tanh -> Linear networks that read the speaker embedding.

    Args:
        hidden_dim (int): Size of the normalized feature axis.
        speaker_embedding_dim (int): Size of the speaker embedding vector.
        dim (int): Axis of ``x`` that holds the features. It is swapped with the
            last axis before normalizing and swapped back afterwards.
        eps (float): Constant added to the variance before taking the square
            root, so that constant feature vectors do not divide by zero.

    Raises:
        TypeError: If ``hidden_dim`` is not an ``int``.

    NOTE: normalization divides by ``sqrt(var + eps)``. Weights that were trained
    against a variant dividing by the raw variance still load (the parameter
    layout is the same) but will see differently scaled activations.
    """

    def __init__(self,
                 hidden_dim,
                 speaker_embedding_dim,
                 dim=-1,
                 eps=1e-5):
        super(ConditionalLayerNorm, self).__init__()
        if not isinstance(hidden_dim, int):
            # Fail early with a clear message instead of an AttributeError on
            # the never-assigned ``normal_shape`` further down.
            raise TypeError(f"hidden_dim must be an int, got {type(hidden_dim).__name__}")
        self.dim = dim
        self.eps = eps
        self.normal_shape = hidden_dim
        self.speaker_embedding_dim = speaker_embedding_dim
        self.W_scale = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape),
                                     nn.Tanh(),
                                     nn.Linear(self.normal_shape, self.normal_shape))
        self.W_bias = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape),
                                    nn.Tanh(),
                                    nn.Linear(self.normal_shape, self.normal_shape))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize the projections so the predicted scale is 1 and the bias is 0.

        All weights are zeroed; the scale branch gets biases of 1 and the bias
        branch gets biases of 0. A freshly built module therefore behaves like
        a plain layer norm without affine parameters.
        """
        torch.nn.init.constant_(self.W_scale[0].weight, 0.0)
        torch.nn.init.constant_(self.W_scale[2].weight, 0.0)
        torch.nn.init.constant_(self.W_scale[0].bias, 1.0)
        torch.nn.init.constant_(self.W_scale[2].bias, 1.0)
        torch.nn.init.constant_(self.W_bias[0].weight, 0.0)
        torch.nn.init.constant_(self.W_bias[2].weight, 0.0)
        torch.nn.init.constant_(self.W_bias[0].bias, 0.0)
        torch.nn.init.constant_(self.W_bias[2].bias, 0.0)

    def forward(self, x, speaker_embedding):
        """
        Normalize ``x`` and apply the speaker dependent scale and bias.

        Args:
            x: Tensor to normalize. The scale and bias get a new axis inserted at
                position 1 before broadcasting, so this presumably is
                (batch, time, hidden_dim) once the feature axis is last -- TODO confirm.
            speaker_embedding: Conditioning tensor whose last axis has size
                ``speaker_embedding_dim``; presumably (batch, speaker_embedding_dim).

        Returns:
            Tensor with the same layout as ``x``.
        """
        if self.dim != -1:
            x = x.transpose(-1, self.dim)

        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        # biased variance estimate, the same one torch's layer_norm uses
        var = (centered ** 2).mean(dim=-1, keepdim=True)
        # divide by the standard deviation; eps keeps constant frames finite
        normalized = centered / torch.sqrt(var + self.eps)

        scale = self.W_scale(speaker_embedding)
        bias = self.W_bias(speaker_embedding)

        y = scale.unsqueeze(1) * normalized + bias.unsqueeze(1)

        if self.dim != -1:
            y = y.transpose(-1, self.dim)

        return y


class SequentialWrappableConditionalLayerNorm(nn.Module):
    """
    Speaker conditioned layer normalization that takes a single packed argument.

    ``forward`` receives one tuple ``(x, speaker_embedding)`` instead of two
    positional arguments and returns only the normalized tensor. ``x`` is
    standardized along its last axis, then multiplied by a scale and shifted by
    a bias that two Linear -> Tanh -> Linear networks predict from the speaker
    embedding.

    Args:
        hidden_dim (int): Size of the normalized (last) feature axis.
        speaker_embedding_dim (int): Size of the speaker embedding vector.
        eps (float): Constant added to the variance before taking the square
            root, so that constant feature vectors do not divide by zero.

    Raises:
        TypeError: If ``hidden_dim`` is not an ``int``.

    NOTE: normalization divides by ``sqrt(var + eps)``. Weights that were trained
    against a variant dividing by the raw variance still load (the parameter
    layout is the same) but will see differently scaled activations.
    """

    def __init__(self,
                 hidden_dim,
                 speaker_embedding_dim,
                 eps=1e-5):
        super(SequentialWrappableConditionalLayerNorm, self).__init__()
        if not isinstance(hidden_dim, int):
            # Fail early with a clear message instead of an AttributeError on
            # the never-assigned ``normal_shape`` further down.
            raise TypeError(f"hidden_dim must be an int, got {type(hidden_dim).__name__}")
        self.eps = eps
        self.normal_shape = hidden_dim
        self.speaker_embedding_dim = speaker_embedding_dim
        self.W_scale = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape),
                                     nn.Tanh(),
                                     nn.Linear(self.normal_shape, self.normal_shape))
        self.W_bias = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape),
                                    nn.Tanh(),
                                    nn.Linear(self.normal_shape, self.normal_shape))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize the projections so the predicted scale is 1 and the bias is 0.

        All weights are zeroed; the scale branch gets biases of 1 and the bias
        branch gets biases of 0. A freshly built module therefore behaves like
        a plain layer norm without affine parameters.
        """
        torch.nn.init.constant_(self.W_scale[0].weight, 0.0)
        torch.nn.init.constant_(self.W_scale[2].weight, 0.0)
        torch.nn.init.constant_(self.W_scale[0].bias, 1.0)
        torch.nn.init.constant_(self.W_scale[2].bias, 1.0)
        torch.nn.init.constant_(self.W_bias[0].weight, 0.0)
        torch.nn.init.constant_(self.W_bias[2].weight, 0.0)
        torch.nn.init.constant_(self.W_bias[0].bias, 0.0)
        torch.nn.init.constant_(self.W_bias[2].bias, 0.0)

    def forward(self, packed_input):
        """
        Normalize the first element of ``packed_input`` conditioned on the second.

        Args:
            packed_input: Two-element sequence ``(x, speaker_embedding)``. The
                scale and bias get a new axis inserted at position 1 before
                broadcasting, so ``x`` presumably is (batch, time, hidden_dim) and
                ``speaker_embedding`` (batch, speaker_embedding_dim) -- TODO confirm.

        Returns:
            The normalized tensor only; the speaker embedding is not passed on.
        """
        x, speaker_embedding = packed_input

        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        # biased variance estimate, the same one torch's layer_norm uses
        var = (centered ** 2).mean(dim=-1, keepdim=True)
        # divide by the standard deviation; eps keeps constant frames finite
        normalized = centered / torch.sqrt(var + self.eps)

        scale = self.W_scale(speaker_embedding)
        bias = self.W_bias(speaker_embedding)

        y = scale.unsqueeze(1) * normalized + bias.unsqueeze(1)

        return y


class AdaIN1d(nn.Module):
    """
    Adaptive instance normalization driven by a style vector.

    The input is instance-normalized without learned affine parameters. A single
    linear layer maps the L2-normalized style vector to ``2 * num_features``
    values, which are split into a gain (``gamma``) and an offset (``beta``).
    The result is ``(1 + gamma) * norm(x) + beta``.

    MIT Licensed

    Copyright (c) 2022 Aaron (Yinghao) Li
    https://github.com/yl4579/StyleTTS/blob/main/models.py
    """

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        # one projection yields gamma and beta side by side
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        """
        Apply style-conditioned instance normalization.

        Args:
            x: Tensor whose axes 1 and 2 are swapped before the instance norm and
                swapped back afterwards; presumably (batch, time, num_features)
                -- TODO confirm against callers.
            s: Style tensor, viewed as (batch, 2 * num_features, 1) after the
                projection, so presumably (batch, style_dim).

        Returns:
            Tensor in the same layout as ``x``.
        """
        style = torch.nn.functional.normalize(s)
        affine = self.fc(style)
        affine = affine.view(affine.size(0), affine.size(1), 1)
        gamma, beta = torch.chunk(affine, chunks=2, dim=1)

        # InstanceNorm1d gets the tensor with axes 1 and 2 swapped
        normalized = self.norm(x.transpose(1, 2)).transpose(1, 2)

        gain = 1 + gamma.transpose(1, 2)
        offset = beta.transpose(1, 2)
        return gain * normalized + offset