ntt123 commited on
Commit
52b79f2
1 Parent(s): 4bae968
Files changed (10) hide show
  1. app.py +309 -0
  2. attentions.py +329 -0
  3. commons.py +162 -0
  4. config.json +72 -0
  5. flow.py +120 -0
  6. models.py +489 -0
  7. modules.py +356 -0
  8. packages.txt +1 -0
  9. phone_set.json +1 -0
  10. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch # isort:skip
2
+
3
+ torch.manual_seed(42)
4
+ import json
5
+ import re
6
+ import unicodedata
7
+ from types import SimpleNamespace
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ import regex
12
+
13
+ from models import DurationNet, SynthesizerTrn
14
+
15
+ title = "LightSpeed: Vietnamese Feale Voice TTS"
16
+ description = "Vietnam Feale Voice TTS."
17
+ config_file = "config.json"
18
+ duration_model_path = "duration_model.pth"
19
+ lightspeed_model_path = "gen_210k.pth"
20
+ phone_set_file = "phone_set.json"
21
+ device = "cuda" if torch.cuda.is_available() else "cpu"
22
+ with open(config_file, "rb") as f:
23
+ hps = json.load(f, object_hook=lambda x: SimpleNamespace(**x))
24
+
25
+ # load phone set json file
26
+ with open(phone_set_file, "r") as f:
27
+ phone_set = json.load(f)
28
+
29
+ assert phone_set[0][1:-1] == "SEP"
30
+ assert "sil" in phone_set
31
+ sil_idx = phone_set.index("sil")
32
+
33
+ vietnamese_characters = [
34
+ "a",
35
+ "à",
36
+ "á",
37
+ "ả",
38
+ "ã",
39
+ "ạ",
40
+ "ă",
41
+ "ằ",
42
+ "ắ",
43
+ "ẳ",
44
+ "ẵ",
45
+ "ặ",
46
+ "â",
47
+ "ầ",
48
+ "ấ",
49
+ "ẩ",
50
+ "ẫ",
51
+ "ậ",
52
+ "e",
53
+ "è",
54
+ "é",
55
+ "ẻ",
56
+ "ẽ",
57
+ "ẹ",
58
+ "ê",
59
+ "ề",
60
+ "ế",
61
+ "ể",
62
+ "ễ",
63
+ "ệ",
64
+ "i",
65
+ "ì",
66
+ "í",
67
+ "ỉ",
68
+ "ĩ",
69
+ "ị",
70
+ "o",
71
+ "ò",
72
+ "ó",
73
+ "ỏ",
74
+ "õ",
75
+ "ọ",
76
+ "ô",
77
+ "ồ",
78
+ "ố",
79
+ "ổ",
80
+ "ỗ",
81
+ "ộ",
82
+ "ơ",
83
+ "ờ",
84
+ "ớ",
85
+ "ở",
86
+ "ỡ",
87
+ "ợ",
88
+ "u",
89
+ "ù",
90
+ "ú",
91
+ "ủ",
92
+ "ũ",
93
+ "ụ",
94
+ "ư",
95
+ "ừ",
96
+ "ứ",
97
+ "ử",
98
+ "ữ",
99
+ "ự",
100
+ "y",
101
+ "ỳ",
102
+ "ý",
103
+ "ỷ",
104
+ "ỹ",
105
+ "ỵ",
106
+ "b",
107
+ "c",
108
+ "d",
109
+ "đ",
110
+ "g",
111
+ "h",
112
+ "k",
113
+ "l",
114
+ "m",
115
+ "n",
116
+ "p",
117
+ "q",
118
+ "r",
119
+ "s",
120
+ "t",
121
+ "v",
122
+ "x",
123
+ ]
124
+ alphabet = "".join(vietnamese_characters)
125
+ space_re = regex.compile(r"\s+")
126
+ number_re = regex.compile("([0-9]+)")
127
+ digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]
128
+ num_re = regex.compile(r"([0-9.,]*[0-9])")
129
+ keep_text_and_num_re = regex.compile(rf"[^\s{alphabet}.,0-9]")
130
+ keep_text_re = regex.compile(rf"[^\s{alphabet}]")
131
+
132
+
133
+ def read_number(num: str) -> str:
134
+ if len(num) == 1:
135
+ return digits[int(num)]
136
+ elif len(num) == 2 and num.isdigit():
137
+ n = int(num)
138
+ end = digits[n % 10]
139
+ if n == 10:
140
+ return "mười"
141
+ if n % 10 == 5:
142
+ end = "lăm"
143
+ if n % 10 == 0:
144
+ return digits[n // 10] + " mươi"
145
+ elif n < 20:
146
+ return "mười " + end
147
+ else:
148
+ if n % 10 == 1:
149
+ end = "mốt"
150
+ return digits[n // 10] + " mươi " + end
151
+ elif len(num) == 3 and num.isdigit():
152
+ n = int(num)
153
+ if n % 100 == 0:
154
+ return digits[n // 100] + " trăm"
155
+ elif num[1] == "0":
156
+ return digits[n // 100] + " trăm lẻ " + digits[n % 100]
157
+ else:
158
+ return digits[n // 100] + " trăm " + read_number(num[1:])
159
+ elif len(num) >= 4 and len(num) <= 6 and num.isdigit():
160
+ n = int(num)
161
+ n1 = n // 1000
162
+ return read_number(str(n1)) + " ngàn " + read_number(num[-3:])
163
+ elif "," in num:
164
+ n1, n2 = num.split(",")
165
+ return read_number(n1) + " phẩy " + read_number(n2)
166
+ elif "." in num:
167
+ parts = num.split(".")
168
+ if len(parts) == 2:
169
+ if parts[1] == "000":
170
+ return read_number(parts[0]) + " ngàn"
171
+ elif parts[1].startswith("00"):
172
+ end = digits[int(parts[1][2:])]
173
+ return read_number(parts[0]) + " ngàn lẻ " + end
174
+ else:
175
+ return read_number(parts[0]) + " ngàn " + read_number(parts[1])
176
+ elif len(parts) == 3:
177
+ return (
178
+ read_number(parts[0])
179
+ + " triệu "
180
+ + read_number(parts[1])
181
+ + " ngàn "
182
+ + read_number(parts[2])
183
+ )
184
+ return num
185
+
186
+
187
+ def text_to_phone_idx(text):
188
+ # lowercase
189
+ text = text.lower()
190
+ # unicode normalize
191
+ text = unicodedata.normalize("NFKC", text)
192
+ text = text.replace(".", " . ")
193
+ text = text.replace(",", " , ")
194
+ text = text.replace(";", " ; ")
195
+ text = text.replace(":", " : ")
196
+ text = text.replace("!", " ! ")
197
+ text = text.replace("?", " ? ")
198
+ text = text.replace("(", " ( ")
199
+
200
+ text = num_re.sub(r" \1 ", text)
201
+ words = text.split()
202
+ words = [read_number(w) if num_re.fullmatch(w) else w for w in words]
203
+ text = " ".join(words)
204
+
205
+ # remove redundant spaces
206
+ text = re.sub(r"\s+", " ", text)
207
+ # remove leading and trailing spaces
208
+ text = text.strip()
209
+ # convert words to phone indices
210
+ tokens = []
211
+ for c in text:
212
+ # if c is "," or ".", add <sil> phone
213
+ if c in ":,.!?;(":
214
+ tokens.append(sil_idx)
215
+ elif c in phone_set:
216
+ tokens.append(phone_set.index(c))
217
+ elif c == " ":
218
+ # add <sep> phone
219
+ tokens.append(0)
220
+ if tokens[0] != sil_idx:
221
+ # insert <sil> phone at the beginning
222
+ tokens = [sil_idx, 0] + tokens
223
+ if tokens[-1] != sil_idx:
224
+ tokens = tokens + [0, sil_idx]
225
+ return tokens
226
+
227
+
228
+ def text_to_speech(text):
229
+ # prevent too long text
230
+ if len(text) > 500:
231
+ text = text[:500]
232
+
233
+ phone_idx = text_to_phone_idx(text)
234
+ batch = {
235
+ "phone_idx": np.array([phone_idx]),
236
+ "phone_length": np.array([len(phone_idx)]),
237
+ }
238
+
239
+ # predict phoneme duration
240
+ duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
241
+ duration_net.load_state_dict(torch.load(duration_model_path, map_location=device))
242
+ duration_net = duration_net.eval()
243
+ phone_length = torch.from_numpy(batch["phone_length"].copy()).long().to(device)
244
+ phone_idx = torch.from_numpy(batch["phone_idx"].copy()).long().to(device)
245
+ with torch.inference_mode():
246
+ phone_duration = duration_net(phone_idx, phone_length)[:, :, 0] * 1000
247
+ phone_duration = torch.where(
248
+ phone_idx == sil_idx, torch.clamp_min(phone_duration, 200), phone_duration
249
+ )
250
+ phone_duration = torch.where(phone_idx == 0, 0, phone_duration)
251
+
252
+ generator = SynthesizerTrn(
253
+ hps.data.vocab_size,
254
+ hps.data.filter_length // 2 + 1,
255
+ hps.train.segment_size // hps.data.hop_length,
256
+ **vars(hps.model),
257
+ ).to(device)
258
+ del generator.enc_q
259
+ ckpt = torch.load(lightspeed_model_path, map_location=device)
260
+ params = {}
261
+ for k, v in ckpt["net_g"].items():
262
+ k = k[7:] if k.startswith("module.") else k
263
+ params[k] = v
264
+ generator.load_state_dict(params, strict=False)
265
+ del ckpt, params
266
+ generator = generator.eval()
267
+ # mininum 1 frame for each phone
268
+ # phone_duration = torch.clamp_min(phone_duration, hps.data.hop_length * 1000 / hps.data.sampling_rate)
269
+ # phone_duration = torch.where(phone_idx == 0, 0, phone_duration)
270
+ end_time = torch.cumsum(phone_duration, dim=-1)
271
+ start_time = end_time - phone_duration
272
+ start_frame = start_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
273
+ end_frame = end_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
274
+ spec_length = end_frame.max(dim=-1).values
275
+ pos = torch.arange(0, spec_length.item(), device=device)
276
+ attn = torch.logical_and(
277
+ pos[None, :, None] >= start_frame[:, None, :],
278
+ pos[None, :, None] < end_frame[:, None, :],
279
+ ).float()
280
+ with torch.inference_mode():
281
+ y_hat = generator.infer(
282
+ phone_idx, phone_length, spec_length, attn, max_len=None, noise_scale=0.0
283
+ )[0]
284
+ wave = y_hat[0, 0].data.cpu().numpy()
285
+ return (wave * (2**15)).astype(np.int16)
286
+
287
+
288
+ def speak(text):
289
+ y = text_to_speech(text)
290
+ return hps.data.sampling_rate, y
291
+
292
+
293
+ gr.Interface(
294
+ fn=speak,
295
+ inputs="text",
296
+ outputs="audio",
297
+ title=title,
298
+ examples=[
299
+ "Trăm năm trong cõi người ta, chữ tài chữ mệnh khéo là ghét nhau.",
300
+ "Đoạn trường tân thanh, thường được biết đến với cái tên đơn giản là Truyện Kiều, là một truyện thơ của đại thi hào Nguyễn Du",
301
+ "Lục Vân Tiên quê ở huyện Đông Thành, khôi ngô tuấn tú, tài kiêm văn võ. Nghe tin triều đình mở khoa thi, Vân Tiên từ giã thầy xuống núi đua tài.",
302
+ "Lê Quý Đôn, tên thuở nhỏ là Lê Danh Phương, là vị quan thời Lê trung hưng, cũng là nhà thơ và được mệnh danh là nhà bác học lớn của Việt Nam trong thời phong kiến",
303
+ "Tất cả mọi người đều sinh ra có quyền bình đẳng. Tạo hóa cho họ những quyền không ai có thể xâm phạm được; trong những quyền ấy, có quyền được sống, quyền tự do và quyền mưu cầu hạnh phúc.",
304
+ ],
305
+ description=description,
306
+ theme="default",
307
+ allow_screenshot=False,
308
+ allow_flagging="never",
309
+ ).launch(debug=False)
attentions.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ import commons
8
+ from modules import LayerNorm
9
+
10
+
11
+ class Encoder(nn.Module):
12
+ def __init__(
13
+ self,
14
+ hidden_channels,
15
+ filter_channels,
16
+ n_heads,
17
+ n_layers,
18
+ kernel_size=1,
19
+ p_dropout=0.0,
20
+ window_size=4,
21
+ **kwargs
22
+ ):
23
+ super().__init__()
24
+ self.hidden_channels = hidden_channels
25
+ self.filter_channels = filter_channels
26
+ self.n_heads = n_heads
27
+ self.n_layers = n_layers
28
+ self.kernel_size = kernel_size
29
+ self.p_dropout = p_dropout
30
+ self.window_size = window_size
31
+
32
+ self.drop = nn.Dropout(p_dropout)
33
+ self.attn_layers = nn.ModuleList()
34
+ self.norm_layers_1 = nn.ModuleList()
35
+ self.ffn_layers = nn.ModuleList()
36
+ self.norm_layers_2 = nn.ModuleList()
37
+ for i in range(self.n_layers):
38
+ self.attn_layers.append(
39
+ MultiHeadAttention(
40
+ hidden_channels,
41
+ hidden_channels,
42
+ n_heads,
43
+ p_dropout=p_dropout,
44
+ window_size=window_size,
45
+ )
46
+ )
47
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
48
+ self.ffn_layers.append(
49
+ FFN(
50
+ hidden_channels,
51
+ hidden_channels,
52
+ filter_channels,
53
+ kernel_size,
54
+ p_dropout=p_dropout,
55
+ )
56
+ )
57
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
58
+
59
+ def forward(self, x, x_mask):
60
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
61
+ x = x * x_mask
62
+ for i in range(self.n_layers):
63
+ y = self.attn_layers[i](x, x, attn_mask)
64
+ y = self.drop(y)
65
+ x = self.norm_layers_1[i](x + y)
66
+
67
+ y = self.ffn_layers[i](x, x_mask)
68
+ y = self.drop(y)
69
+ x = self.norm_layers_2[i](x + y)
70
+ x = x * x_mask
71
+ return x
72
+
73
+
74
+ class MultiHeadAttention(nn.Module):
75
+ def __init__(
76
+ self,
77
+ channels,
78
+ out_channels,
79
+ n_heads,
80
+ p_dropout=0.0,
81
+ window_size=None,
82
+ heads_share=True,
83
+ block_length=None,
84
+ proximal_bias=False,
85
+ proximal_init=False,
86
+ ):
87
+ super().__init__()
88
+ assert channels % n_heads == 0
89
+
90
+ self.channels = channels
91
+ self.out_channels = out_channels
92
+ self.n_heads = n_heads
93
+ self.p_dropout = p_dropout
94
+ self.window_size = window_size
95
+ self.heads_share = heads_share
96
+ self.block_length = block_length
97
+ self.proximal_bias = proximal_bias
98
+ self.proximal_init = proximal_init
99
+ # self.attn = None
100
+
101
+ self.k_channels = channels // n_heads
102
+ self.conv_q = nn.Conv1d(channels, channels, 1)
103
+ self.conv_k = nn.Conv1d(channels, channels, 1)
104
+ self.conv_v = nn.Conv1d(channels, channels, 1)
105
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
106
+ self.drop = nn.Dropout(p_dropout)
107
+
108
+ if window_size is not None:
109
+ n_heads_rel = 1 if heads_share else n_heads
110
+ rel_stddev = self.k_channels**-0.5
111
+ self.emb_rel_k = nn.Parameter(
112
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
113
+ * rel_stddev
114
+ )
115
+ self.emb_rel_v = nn.Parameter(
116
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
117
+ * rel_stddev
118
+ )
119
+
120
+ nn.init.xavier_uniform_(self.conv_q.weight)
121
+ nn.init.xavier_uniform_(self.conv_k.weight)
122
+ nn.init.xavier_uniform_(self.conv_v.weight)
123
+ if proximal_init:
124
+ with torch.no_grad():
125
+ self.conv_k.weight.copy_(self.conv_q.weight)
126
+ self.conv_k.bias.copy_(self.conv_q.bias)
127
+
128
+ def forward(self, x, c, attn_mask=None):
129
+ q = self.conv_q(x)
130
+ k = self.conv_k(c)
131
+ v = self.conv_v(c)
132
+
133
+ x, _ = self.attention(q, k, v, mask=attn_mask)
134
+
135
+ x = self.conv_o(x)
136
+ return x
137
+
138
+ def attention(self, query, key, value, mask=None):
139
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
140
+ b, d, t_s, t_t = (*key.size(), query.size(2))
141
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
142
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
143
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
144
+
145
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
146
+ if self.window_size is not None:
147
+ assert (
148
+ t_s == t_t
149
+ ), "Relative attention is only available for self-attention."
150
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
151
+ rel_logits = self._matmul_with_relative_keys(
152
+ query / math.sqrt(self.k_channels), key_relative_embeddings
153
+ )
154
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
155
+ scores = scores + scores_local
156
+ if self.proximal_bias:
157
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
158
+ scores = scores + self._attention_bias_proximal(t_s).to(
159
+ device=scores.device, dtype=scores.dtype
160
+ )
161
+ if mask is not None:
162
+ scores = scores.masked_fill(mask == 0, -1e4)
163
+ if self.block_length is not None:
164
+ assert (
165
+ t_s == t_t
166
+ ), "Local attention is only available for self-attention."
167
+ block_mask = (
168
+ torch.ones_like(scores)
169
+ .triu(-self.block_length)
170
+ .tril(self.block_length)
171
+ )
172
+ scores = scores.masked_fill(block_mask == 0, -1e4)
173
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
174
+ p_attn = self.drop(p_attn)
175
+ output = torch.matmul(p_attn, value)
176
+ if self.window_size is not None:
177
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
178
+ value_relative_embeddings = self._get_relative_embeddings(
179
+ self.emb_rel_v, t_s
180
+ )
181
+ output = output + self._matmul_with_relative_values(
182
+ relative_weights, value_relative_embeddings
183
+ )
184
+ output = (
185
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
186
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
187
+ return output, p_attn
188
+
189
+ def _matmul_with_relative_values(self, x, y):
190
+ """
191
+ x: [b, h, l, m]
192
+ y: [h or 1, m, d]
193
+ ret: [b, h, l, d]
194
+ """
195
+ ret = torch.matmul(x, y.unsqueeze(0))
196
+ return ret
197
+
198
+ def _matmul_with_relative_keys(self, x, y):
199
+ """
200
+ x: [b, h, l, d]
201
+ y: [h or 1, m, d]
202
+ ret: [b, h, l, m]
203
+ """
204
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
205
+ return ret
206
+
207
+ def _get_relative_embeddings(self, relative_embeddings, length):
208
+ max_relative_position = 2 * self.window_size + 1
209
+ # Pad first before slice to avoid using cond ops.
210
+ pad_length = max(length - (self.window_size + 1), 0)
211
+ slice_start_position = max((self.window_size + 1) - length, 0)
212
+ slice_end_position = slice_start_position + 2 * length - 1
213
+ if pad_length > 0:
214
+ padded_relative_embeddings = F.pad(
215
+ relative_embeddings,
216
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
217
+ )
218
+ else:
219
+ padded_relative_embeddings = relative_embeddings
220
+ used_relative_embeddings = padded_relative_embeddings[
221
+ :, slice_start_position:slice_end_position
222
+ ]
223
+ return used_relative_embeddings
224
+
225
+ def _relative_position_to_absolute_position(self, x):
226
+ """
227
+ x: [b, h, l, 2*l-1]
228
+ ret: [b, h, l, l]
229
+ """
230
+ batch, heads, length, _ = x.size()
231
+ # Concat columns of pad to shift from relative to absolute indexing.
232
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
233
+
234
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
235
+ x_flat = x.view([batch, heads, length * 2 * length])
236
+ x_flat = F.pad(
237
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
238
+ )
239
+
240
+ # Reshape and slice out the padded elements.
241
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
242
+ :, :, :length, length - 1 :
243
+ ]
244
+ return x_final
245
+
246
+ def _absolute_position_to_relative_position(self, x):
247
+ """
248
+ x: [b, h, l, l]
249
+ ret: [b, h, l, 2*l-1]
250
+ """
251
+ batch, heads, length, _ = x.shape
252
+ # padd along column
253
+ x = F.pad(
254
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
255
+ )
256
+ x_flat = x.view([batch, heads, length * length + length * (length - 1)])
257
+ # add 0's in the beginning that will skew the elements after reshape
258
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
259
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
260
+ return x_final
261
+
262
+ def _attention_bias_proximal(self, length):
263
+ """Bias for self-attention to encourage attention to close positions.
264
+ Args:
265
+ length: an integer scalar.
266
+ Returns:
267
+ a Tensor with shape [1, 1, length, length]
268
+ """
269
+ r = torch.arange(length, dtype=torch.float32)
270
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
271
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
272
+
273
+
274
+ class FFN(nn.Module):
275
+ def __init__(
276
+ self,
277
+ in_channels,
278
+ out_channels,
279
+ filter_channels,
280
+ kernel_size,
281
+ p_dropout=0.0,
282
+ activation=None,
283
+ causal=False,
284
+ ):
285
+ super().__init__()
286
+ self.in_channels = in_channels
287
+ self.out_channels = out_channels
288
+ self.filter_channels = filter_channels
289
+ self.kernel_size = kernel_size
290
+ self.p_dropout = p_dropout
291
+ self.activation = activation
292
+ self.causal = causal
293
+
294
+ if causal:
295
+ self.padding = self._causal_padding
296
+ else:
297
+ self.padding = self._same_padding
298
+
299
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
300
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
301
+ self.drop = nn.Dropout(p_dropout)
302
+
303
+ def forward(self, x, x_mask):
304
+ x = self.conv_1(self.padding(x * x_mask))
305
+ if self.activation == "gelu":
306
+ x = x * torch.sigmoid(1.702 * x)
307
+ else:
308
+ x = torch.relu(x)
309
+ x = self.drop(x)
310
+ x = self.conv_2(self.padding(x * x_mask))
311
+ return x * x_mask
312
+
313
+ def _causal_padding(self, x):
314
+ if self.kernel_size == 1:
315
+ return x
316
+ pad_l = self.kernel_size - 1
317
+ pad_r = 0
318
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
319
+ x = F.pad(x, commons.convert_pad_shape(padding))
320
+ return x
321
+
322
+ def _same_padding(self, x):
323
+ if self.kernel_size == 1:
324
+ return x
325
+ pad_l = (self.kernel_size - 1) // 2
326
+ pad_r = self.kernel_size // 2
327
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
328
+ x = F.pad(x, commons.convert_pad_shape(padding))
329
+ return x
commons.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def init_weights(m, mean=0.0, std=0.01):
8
+ classname = m.__class__.__name__
9
+ if classname.find("Conv") != -1:
10
+ m.weight.data.normal_(mean, std)
11
+
12
+
13
+ def get_padding(kernel_size, dilation=1):
14
+ return int((kernel_size * dilation - dilation) / 2)
15
+
16
+
17
+ def convert_pad_shape(pad_shape):
18
+ l = pad_shape[::-1]
19
+ pad_shape = [item for sublist in l for item in sublist]
20
+ return pad_shape
21
+
22
+
23
+ def intersperse(lst, item):
24
+ result = [item] * (len(lst) * 2 + 1)
25
+ result[1::2] = lst
26
+ return result
27
+
28
+
29
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
30
+ """KL(P||Q)"""
31
+ kl = (logs_q - logs_p) - 0.5
32
+ kl += (
33
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
34
+ )
35
+ return kl
36
+
37
+
38
+ def rand_gumbel(shape):
39
+ """Sample from the Gumbel distribution, protect from overflows."""
40
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
41
+ return -torch.log(-torch.log(uniform_samples))
42
+
43
+
44
+ def rand_gumbel_like(x):
45
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
46
+ return g
47
+
48
+
49
+ def slice_segments(x, ids_str, segment_size=4):
50
+ ret = torch.zeros_like(x[:, :, :segment_size])
51
+ for i in range(x.size(0)):
52
+ idx_str = ids_str[i]
53
+ idx_end = idx_str + segment_size
54
+ ret[i] = x[i, :, idx_str:idx_end]
55
+ return ret
56
+
57
+
58
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
59
+ b, d, t = x.size()
60
+ if x_lengths is None:
61
+ x_lengths = t
62
+ ids_str_max = x_lengths - segment_size + 1
63
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
64
+ ret = slice_segments(x, ids_str, segment_size)
65
+ return ret, ids_str
66
+
67
+
68
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
69
+ position = torch.arange(length, dtype=torch.float)
70
+ num_timescales = channels // 2
71
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
72
+ num_timescales - 1
73
+ )
74
+ inv_timescales = min_timescale * torch.exp(
75
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
76
+ )
77
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
78
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
79
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
80
+ signal = signal.view(1, channels, length)
81
+ return signal
82
+
83
+
84
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
85
+ b, channels, length = x.size()
86
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
87
+ return x + signal.to(dtype=x.dtype, device=x.device)
88
+
89
+
90
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
91
+ b, channels, length = x.size()
92
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
93
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
94
+
95
+
96
+ def subsequent_mask(length):
97
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
98
+ return mask
99
+
100
+
101
+ @torch.jit.script
102
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
103
+ n_channels_int = n_channels[0]
104
+ in_act = input_a + input_b
105
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
106
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
107
+ acts = t_act * s_act
108
+ return acts
109
+
110
+
111
+ def convert_pad_shape(pad_shape):
112
+ l = pad_shape[::-1]
113
+ pad_shape = [item for sublist in l for item in sublist]
114
+ return pad_shape
115
+
116
+
117
+ def shift_1d(x):
118
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
119
+ return x
120
+
121
+
122
+ def sequence_mask(length, max_length=None):
123
+ if max_length is None:
124
+ max_length = length.max()
125
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
126
+ return x.unsqueeze(0) < length.unsqueeze(1)
127
+
128
+
129
+ def generate_path(duration, mask):
130
+ """
131
+ duration: [b, 1, t_x]
132
+ mask: [b, 1, t_y, t_x]
133
+ """
134
+ device = duration.device
135
+
136
+ b, _, t_y, t_x = mask.shape
137
+ cum_duration = torch.cumsum(duration, -1)
138
+
139
+ cum_duration_flat = cum_duration.view(b * t_x)
140
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
141
+ path = path.view(b, t_x, t_y)
142
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
143
+ path = path.unsqueeze(1).transpose(2, 3) * mask
144
+ return path
145
+
146
+
147
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
148
+ if isinstance(parameters, torch.Tensor):
149
+ parameters = [parameters]
150
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
151
+ norm_type = float(norm_type)
152
+ if clip_value is not None:
153
+ clip_value = float(clip_value)
154
+
155
+ total_norm = 0
156
+ for p in parameters:
157
+ param_norm = p.grad.data.norm(norm_type)
158
+ total_norm += param_norm.item() ** norm_type
159
+ if clip_value is not None:
160
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
161
+ total_norm = total_norm ** (1.0 / norm_type)
162
+ return total_norm
config.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "learning_rate": 2e-4,
4
+ "betas": [
5
+ 0.8,
6
+ 0.99
7
+ ],
8
+ "eps": 1e-9,
9
+ "lr_decay": 0.999875,
10
+ "segment_size": 8192,
11
+ "c_mel": 45,
12
+ "c_kl": 1.0
13
+ },
14
+ "data": {
15
+ "vocab_size": 256,
16
+ "max_wav_value": 32768.0,
17
+ "sampling_rate": 16000,
18
+ "filter_length": 1024,
19
+ "hop_length": 256,
20
+ "win_length": 1024,
21
+ "n_mel_channels": 80,
22
+ "mel_fmin": 0.0,
23
+ "mel_fmax": null
24
+ },
25
+ "model": {
26
+ "inter_channels": 192,
27
+ "hidden_channels": 192,
28
+ "filter_channels": 768,
29
+ "n_heads": 2,
30
+ "n_layers": 6,
31
+ "kernel_size": 3,
32
+ "p_dropout": 0.1,
33
+ "resblock": "1",
34
+ "resblock_kernel_sizes": [
35
+ 3,
36
+ 7,
37
+ 11
38
+ ],
39
+ "resblock_dilation_sizes": [
40
+ [
41
+ 1,
42
+ 3,
43
+ 5
44
+ ],
45
+ [
46
+ 1,
47
+ 3,
48
+ 5
49
+ ],
50
+ [
51
+ 1,
52
+ 3,
53
+ 5
54
+ ]
55
+ ],
56
+ "upsample_rates": [
57
+ 8,
58
+ 8,
59
+ 2,
60
+ 2
61
+ ],
62
+ "upsample_initial_channel": 512,
63
+ "upsample_kernel_sizes": [
64
+ 16,
65
+ 16,
66
+ 4,
67
+ 4
68
+ ],
69
+ "n_layers_q": 3,
70
+ "use_spectral_norm": false
71
+ }
72
+ }
flow.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from modules import WN
5
+
6
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
7
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
8
+ DEFAULT_MIN_DERIVATIVE = 1e-3
9
+
10
+
11
+ class ResidualCouplingLayer(nn.Module):
12
+ def __init__(
13
+ self,
14
+ channels,
15
+ hidden_channels,
16
+ kernel_size,
17
+ dilation_rate,
18
+ n_layers,
19
+ p_dropout=0,
20
+ gin_channels=0,
21
+ mean_only=False,
22
+ ):
23
+ assert channels % 2 == 0, "channels should be divisible by 2"
24
+ super().__init__()
25
+ self.channels = channels
26
+ self.hidden_channels = hidden_channels
27
+ self.kernel_size = kernel_size
28
+ self.dilation_rate = dilation_rate
29
+ self.n_layers = n_layers
30
+ self.half_channels = channels // 2
31
+ self.mean_only = mean_only
32
+
33
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
34
+ self.enc = WN(
35
+ hidden_channels,
36
+ kernel_size,
37
+ dilation_rate,
38
+ n_layers,
39
+ p_dropout=p_dropout,
40
+ gin_channels=gin_channels,
41
+ )
42
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
43
+ self.post.weight.data.zero_()
44
+ self.post.bias.data.zero_()
45
+
46
+ def forward(self, x, x_mask, g=None, reverse=False):
47
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
48
+ h = self.pre(x0) * x_mask
49
+ h = self.enc(h, x_mask, g=g)
50
+ stats = self.post(h) * x_mask
51
+ if not self.mean_only:
52
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
53
+ else:
54
+ m = stats
55
+ logs = torch.zeros_like(m)
56
+
57
+ if not reverse:
58
+ x1 = m + x1 * torch.exp(logs) * x_mask
59
+ x = torch.cat([x0, x1], 1)
60
+ logdet = torch.sum(logs, [1, 2])
61
+ return x, logdet
62
+ else:
63
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
64
+ x = torch.cat([x0, x1], 1)
65
+ return x
66
+
67
+
68
+ class Flip(nn.Module):
69
+ def forward(self, x, *args, reverse=False, **kwargs):
70
+ x = torch.flip(x, [1])
71
+ if not reverse:
72
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
73
+ return x, logdet
74
+ else:
75
+ return x
76
+
77
+
78
+ class ResidualCouplingBlock(nn.Module):
79
+ def __init__(
80
+ self,
81
+ channels,
82
+ hidden_channels,
83
+ kernel_size,
84
+ dilation_rate,
85
+ n_layers,
86
+ n_flows=4,
87
+ gin_channels=0,
88
+ ):
89
+ super().__init__()
90
+ self.channels = channels
91
+ self.hidden_channels = hidden_channels
92
+ self.kernel_size = kernel_size
93
+ self.dilation_rate = dilation_rate
94
+ self.n_layers = n_layers
95
+ self.n_flows = n_flows
96
+ self.gin_channels = gin_channels
97
+
98
+ self.flows = nn.ModuleList()
99
+ for i in range(n_flows):
100
+ self.flows.append(
101
+ ResidualCouplingLayer(
102
+ channels,
103
+ hidden_channels,
104
+ kernel_size,
105
+ dilation_rate,
106
+ n_layers,
107
+ gin_channels=gin_channels,
108
+ mean_only=True,
109
+ )
110
+ )
111
+ self.flows.append(Flip())
112
+
113
+ def forward(self, x, x_mask, g=None, reverse=False):
114
+ if not reverse:
115
+ for flow in self.flows:
116
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
117
+ else:
118
+ for flow in reversed(self.flows):
119
+ x = flow(x, x_mask, g=g, reverse=reverse)
120
+ return x
models.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import Conv1d, Conv2d, ConvTranspose1d
6
+ from torch.nn import functional as F
7
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
8
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
9
+
10
+ import attentions
11
+ import commons
12
+ import modules
13
+ from commons import get_padding, init_weights
14
+ from flow import ResidualCouplingBlock
15
+
16
+
17
+ class PriorEncoder(nn.Module):
18
+ def __init__(
19
+ self,
20
+ n_vocab,
21
+ out_channels,
22
+ hidden_channels,
23
+ filter_channels,
24
+ n_heads,
25
+ n_layers,
26
+ kernel_size,
27
+ p_dropout,
28
+ ):
29
+ super().__init__()
30
+ self.n_vocab = n_vocab
31
+ self.out_channels = out_channels
32
+ self.hidden_channels = hidden_channels
33
+ self.filter_channels = filter_channels
34
+ self.n_heads = n_heads
35
+ self.n_layers = n_layers
36
+ self.kernel_size = kernel_size
37
+ self.p_dropout = p_dropout
38
+
39
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
40
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
41
+ self.pre_attn_encoder = attentions.Encoder(
42
+ hidden_channels,
43
+ filter_channels,
44
+ n_heads,
45
+ n_layers // 2,
46
+ kernel_size,
47
+ p_dropout,
48
+ )
49
+ self.post_attn_encoder = attentions.Encoder(
50
+ hidden_channels,
51
+ filter_channels,
52
+ n_heads,
53
+ n_layers - n_layers // 2,
54
+ kernel_size,
55
+ p_dropout,
56
+ )
57
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
58
+
59
+ def forward(self, x, x_lengths, y_lengths, attn):
60
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
61
+ x = torch.transpose(x, 1, -1) # [b, h, t]
62
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
63
+ x.dtype
64
+ )
65
+ x = self.pre_attn_encoder(x * x_mask, x_mask)
66
+ y = torch.einsum("bht,blt->bhl", x, attn)
67
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
68
+ y.dtype
69
+ )
70
+ y = self.post_attn_encoder(y * y_mask, y_mask)
71
+ stats = self.proj(y) * y_mask
72
+
73
+ m, logs = torch.split(stats, self.out_channels, dim=1)
74
+ return y, m, logs, y_mask
75
+
76
+
77
+ class PosteriorEncoder(nn.Module):
78
+ def __init__(
79
+ self,
80
+ in_channels,
81
+ out_channels,
82
+ hidden_channels,
83
+ kernel_size,
84
+ dilation_rate,
85
+ n_layers,
86
+ gin_channels=0,
87
+ ):
88
+ super().__init__()
89
+ self.in_channels = in_channels
90
+ self.out_channels = out_channels
91
+ self.hidden_channels = hidden_channels
92
+ self.kernel_size = kernel_size
93
+ self.dilation_rate = dilation_rate
94
+ self.n_layers = n_layers
95
+ self.gin_channels = gin_channels
96
+
97
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
98
+ self.enc = modules.WN(
99
+ hidden_channels,
100
+ kernel_size,
101
+ dilation_rate,
102
+ n_layers,
103
+ gin_channels=gin_channels,
104
+ )
105
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
106
+
107
+ def forward(self, x, x_lengths, g=None):
108
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
109
+ x.dtype
110
+ )
111
+ x = self.pre(x) * x_mask
112
+ x = self.enc(x, x_mask, g=g)
113
+ stats = self.proj(x) * x_mask
114
+ m, logs = torch.split(stats, self.out_channels, dim=1)
115
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
116
+ return z, m, logs, x_mask
117
+
118
+
119
+ class Generator(torch.nn.Module):
120
+ def __init__(
121
+ self,
122
+ initial_channel,
123
+ resblock,
124
+ resblock_kernel_sizes,
125
+ resblock_dilation_sizes,
126
+ upsample_rates,
127
+ upsample_initial_channel,
128
+ upsample_kernel_sizes,
129
+ gin_channels=0,
130
+ ):
131
+ super(Generator, self).__init__()
132
+ self.num_kernels = len(resblock_kernel_sizes)
133
+ self.num_upsamples = len(upsample_rates)
134
+ self.conv_pre = Conv1d(
135
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
136
+ )
137
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
138
+
139
+ self.ups = nn.ModuleList()
140
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
141
+ self.ups.append(
142
+ weight_norm(
143
+ ConvTranspose1d(
144
+ upsample_initial_channel // (2**i),
145
+ upsample_initial_channel // (2 ** (i + 1)),
146
+ k,
147
+ u,
148
+ padding=(k - u) // 2,
149
+ )
150
+ )
151
+ )
152
+
153
+ self.resblocks = nn.ModuleList()
154
+ for i in range(len(self.ups)):
155
+ ch = upsample_initial_channel // (2 ** (i + 1))
156
+ for j, (k, d) in enumerate(
157
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
158
+ ):
159
+ self.resblocks.append(resblock(ch, k, d))
160
+
161
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
162
+ self.ups.apply(init_weights)
163
+
164
+ if gin_channels != 0:
165
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
166
+
167
+ def forward(self, x, g=None):
168
+ x = self.conv_pre(x)
169
+ if g is not None:
170
+ x = x + self.cond(g)
171
+
172
+ for i in range(self.num_upsamples):
173
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
174
+ x = self.ups[i](x)
175
+ xs = None
176
+ for j in range(self.num_kernels):
177
+ if xs is None:
178
+ xs = self.resblocks[i * self.num_kernels + j](x)
179
+ else:
180
+ xs += self.resblocks[i * self.num_kernels + j](x)
181
+ x = xs / self.num_kernels
182
+ x = F.leaky_relu(x)
183
+ x = self.conv_post(x)
184
+ x = torch.tanh(x)
185
+
186
+ return x
187
+
188
+ def remove_weight_norm(self):
189
+ print("Removing weight norm...")
190
+ for l in self.ups:
191
+ remove_weight_norm(l)
192
+ for l in self.resblocks:
193
+ l.remove_weight_norm()
194
+
195
+
196
+ class DiscriminatorP(torch.nn.Module):
197
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
198
+ super(DiscriminatorP, self).__init__()
199
+ self.period = period
200
+ self.use_spectral_norm = use_spectral_norm
201
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
202
+ self.convs = nn.ModuleList(
203
+ [
204
+ norm_f(
205
+ Conv2d(
206
+ 1,
207
+ 32,
208
+ (kernel_size, 1),
209
+ (stride, 1),
210
+ padding=(get_padding(kernel_size, 1), 0),
211
+ )
212
+ ),
213
+ norm_f(
214
+ Conv2d(
215
+ 32,
216
+ 128,
217
+ (kernel_size, 1),
218
+ (stride, 1),
219
+ padding=(get_padding(kernel_size, 1), 0),
220
+ )
221
+ ),
222
+ norm_f(
223
+ Conv2d(
224
+ 128,
225
+ 512,
226
+ (kernel_size, 1),
227
+ (stride, 1),
228
+ padding=(get_padding(kernel_size, 1), 0),
229
+ )
230
+ ),
231
+ norm_f(
232
+ Conv2d(
233
+ 512,
234
+ 1024,
235
+ (kernel_size, 1),
236
+ (stride, 1),
237
+ padding=(get_padding(kernel_size, 1), 0),
238
+ )
239
+ ),
240
+ norm_f(
241
+ Conv2d(
242
+ 1024,
243
+ 1024,
244
+ (kernel_size, 1),
245
+ 1,
246
+ padding=(get_padding(kernel_size, 1), 0),
247
+ )
248
+ ),
249
+ ]
250
+ )
251
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
252
+
253
+ def forward(self, x):
254
+ fmap = []
255
+
256
+ # 1d to 2d
257
+ b, c, t = x.shape
258
+ if t % self.period != 0: # pad first
259
+ n_pad = self.period - (t % self.period)
260
+ x = F.pad(x, (0, n_pad), "reflect")
261
+ t = t + n_pad
262
+ x = x.view(b, c, t // self.period, self.period)
263
+
264
+ for l in self.convs:
265
+ x = l(x)
266
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
267
+ fmap.append(x)
268
+ x = self.conv_post(x)
269
+ fmap.append(x)
270
+ x = torch.flatten(x, 1, -1)
271
+
272
+ return x, fmap
273
+
274
+
275
+ class DiscriminatorS(torch.nn.Module):
276
+ def __init__(self, use_spectral_norm=False):
277
+ super(DiscriminatorS, self).__init__()
278
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
279
+ self.convs = nn.ModuleList(
280
+ [
281
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
282
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
283
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
284
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
285
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
286
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
287
+ ]
288
+ )
289
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
290
+
291
+ def forward(self, x):
292
+ fmap = []
293
+
294
+ for l in self.convs:
295
+ x = l(x)
296
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
297
+ fmap.append(x)
298
+ x = self.conv_post(x)
299
+ fmap.append(x)
300
+ x = torch.flatten(x, 1, -1)
301
+
302
+ return x, fmap
303
+
304
+
305
+ class MultiPeriodDiscriminator(torch.nn.Module):
306
+ def __init__(self, use_spectral_norm=False):
307
+ super(MultiPeriodDiscriminator, self).__init__()
308
+ periods = [2, 3, 5, 7, 11]
309
+
310
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
311
+ discs = discs + [
312
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
313
+ ]
314
+ self.discriminators = nn.ModuleList(discs)
315
+
316
+ def forward(self, y, y_hat):
317
+ y_d_rs = []
318
+ y_d_gs = []
319
+ fmap_rs = []
320
+ fmap_gs = []
321
+ for i, d in enumerate(self.discriminators):
322
+ y_d_r, fmap_r = d(y)
323
+ y_d_g, fmap_g = d(y_hat)
324
+ y_d_rs.append(y_d_r)
325
+ y_d_gs.append(y_d_g)
326
+ fmap_rs.append(fmap_r)
327
+ fmap_gs.append(fmap_g)
328
+
329
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
330
+
331
+
332
+ class SynthesizerTrn(nn.Module):
333
+ """
334
+ Synthesizer for Training
335
+ """
336
+
337
+ def __init__(
338
+ self,
339
+ n_vocab,
340
+ spec_channels,
341
+ segment_size,
342
+ inter_channels,
343
+ hidden_channels,
344
+ filter_channels,
345
+ n_heads,
346
+ n_layers,
347
+ kernel_size,
348
+ p_dropout,
349
+ resblock,
350
+ resblock_kernel_sizes,
351
+ resblock_dilation_sizes,
352
+ upsample_rates,
353
+ upsample_initial_channel,
354
+ upsample_kernel_sizes,
355
+ n_speakers=0,
356
+ gin_channels=0,
357
+ **kwargs
358
+ ):
359
+ super().__init__()
360
+ self.n_vocab = n_vocab
361
+ self.spec_channels = spec_channels
362
+ self.inter_channels = inter_channels
363
+ self.hidden_channels = hidden_channels
364
+ self.filter_channels = filter_channels
365
+ self.n_heads = n_heads
366
+ self.n_layers = n_layers
367
+ self.kernel_size = kernel_size
368
+ self.p_dropout = p_dropout
369
+ self.resblock = resblock
370
+ self.resblock_kernel_sizes = resblock_kernel_sizes
371
+ self.resblock_dilation_sizes = resblock_dilation_sizes
372
+ self.upsample_rates = upsample_rates
373
+ self.upsample_initial_channel = upsample_initial_channel
374
+ self.upsample_kernel_sizes = upsample_kernel_sizes
375
+ self.segment_size = segment_size
376
+ self.n_speakers = n_speakers
377
+ self.gin_channels = gin_channels
378
+
379
+ self.enc_p = PriorEncoder(
380
+ n_vocab,
381
+ inter_channels,
382
+ hidden_channels,
383
+ filter_channels,
384
+ n_heads,
385
+ n_layers,
386
+ kernel_size,
387
+ p_dropout,
388
+ )
389
+ self.dec = Generator(
390
+ inter_channels,
391
+ resblock,
392
+ resblock_kernel_sizes,
393
+ resblock_dilation_sizes,
394
+ upsample_rates,
395
+ upsample_initial_channel,
396
+ upsample_kernel_sizes,
397
+ gin_channels=gin_channels,
398
+ )
399
+ self.enc_q = PosteriorEncoder(
400
+ spec_channels,
401
+ inter_channels,
402
+ hidden_channels,
403
+ 5,
404
+ 1,
405
+ 16,
406
+ gin_channels=gin_channels,
407
+ )
408
+ self.flow = ResidualCouplingBlock(
409
+ inter_channels, hidden_channels, 5, 2, 4, gin_channels=gin_channels
410
+ )
411
+
412
+ if n_speakers > 1:
413
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
414
+
415
+ def forward(self, x, x_lengths, attn, y, y_lengths, sid=None):
416
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, y_lengths, attn=attn)
417
+ if self.n_speakers > 0:
418
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
419
+ else:
420
+ g = None
421
+
422
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
423
+ z_p = self.flow(z, y_mask, g=g)
424
+
425
+ z_slice, ids_slice = commons.rand_slice_segments(
426
+ z, y_lengths, self.segment_size
427
+ )
428
+ o = self.dec(z_slice, g=g)
429
+ l_length = None
430
+ return (
431
+ o,
432
+ l_length,
433
+ attn,
434
+ ids_slice,
435
+ x_mask,
436
+ y_mask,
437
+ (z, z_p, m_p, logs_p, m_q, logs_q),
438
+ )
439
+
440
+ def infer(
441
+ self,
442
+ x,
443
+ x_lengths,
444
+ y_lengths,
445
+ attn,
446
+ sid=None,
447
+ noise_scale=1,
448
+ max_len=None,
449
+ ):
450
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, y_lengths, attn=attn)
451
+ if self.n_speakers > 0:
452
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
453
+ else:
454
+ g = None
455
+
456
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, attn.shape[1]), 1).to(
457
+ x_mask.dtype
458
+ )
459
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
460
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
461
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
462
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
463
+
464
+
465
+ class DurationNet(torch.nn.Module):
466
+ def __init__(self, vocab_size: int, dim: int, num_layers=2):
467
+ super().__init__()
468
+ self.embed = torch.nn.Embedding(vocab_size, embedding_dim=dim)
469
+ self.rnn = torch.nn.GRU(
470
+ dim,
471
+ dim,
472
+ num_layers=num_layers,
473
+ batch_first=True,
474
+ bidirectional=True,
475
+ dropout=0.2,
476
+ )
477
+ self.proj = torch.nn.Linear(2 * dim, 1)
478
+
479
+ def forward(self, token, lengths):
480
+ x = self.embed(token)
481
+ lengths = lengths.long().cpu()
482
+ x = pack_padded_sequence(
483
+ x, lengths=lengths, batch_first=True, enforce_sorted=False
484
+ )
485
+ x, _ = self.rnn(x)
486
+ x, _ = pad_packed_sequence(x, batch_first=True, total_length=token.shape[1])
487
+ x = self.proj(x)
488
+ x = torch.nn.functional.softplus(x)
489
+ return x
modules.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import Conv1d
4
+ from torch.nn import functional as F
5
+ from torch.nn.utils import remove_weight_norm, weight_norm
6
+
7
+ import commons
8
+ from commons import get_padding, init_weights
9
+
10
+ LRELU_SLOPE = 0.1
11
+
12
+
13
+ class LayerNorm(nn.Module):
14
+ def __init__(self, channels, eps=1e-5):
15
+ super().__init__()
16
+ self.channels = channels
17
+ self.eps = eps
18
+
19
+ self.gamma = nn.Parameter(torch.ones(channels))
20
+ self.beta = nn.Parameter(torch.zeros(channels))
21
+
22
+ def forward(self, x):
23
+ x = x.transpose(1, -1)
24
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
25
+ return x.transpose(1, -1)
26
+
27
+
28
+ class ConvReluNorm(nn.Module):
29
+ def __init__(
30
+ self,
31
+ in_channels,
32
+ hidden_channels,
33
+ out_channels,
34
+ kernel_size,
35
+ n_layers,
36
+ p_dropout,
37
+ ):
38
+ super().__init__()
39
+ self.in_channels = in_channels
40
+ self.hidden_channels = hidden_channels
41
+ self.out_channels = out_channels
42
+ self.kernel_size = kernel_size
43
+ self.n_layers = n_layers
44
+ self.p_dropout = p_dropout
45
+ assert n_layers > 1, "Number of layers should be larger than 0."
46
+
47
+ self.conv_layers = nn.ModuleList()
48
+ self.norm_layers = nn.ModuleList()
49
+ self.conv_layers.append(
50
+ nn.Conv1d(
51
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
52
+ )
53
+ )
54
+ self.norm_layers.append(LayerNorm(hidden_channels))
55
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
56
+ for _ in range(n_layers - 1):
57
+ self.conv_layers.append(
58
+ nn.Conv1d(
59
+ hidden_channels,
60
+ hidden_channels,
61
+ kernel_size,
62
+ padding=kernel_size // 2,
63
+ )
64
+ )
65
+ self.norm_layers.append(LayerNorm(hidden_channels))
66
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
67
+ self.proj.weight.data.zero_()
68
+ self.proj.bias.data.zero_()
69
+
70
+ def forward(self, x, x_mask):
71
+ x_org = x
72
+ for i in range(self.n_layers):
73
+ x = self.conv_layers[i](x * x_mask)
74
+ x = self.norm_layers[i](x)
75
+ x = self.relu_drop(x)
76
+ x = x_org + self.proj(x)
77
+ return x * x_mask
78
+
79
+
80
+ class DDSConv(nn.Module):
81
+ """
82
+ Dialted and Depth-Separable Convolution
83
+ """
84
+
85
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
86
+ super().__init__()
87
+ self.channels = channels
88
+ self.kernel_size = kernel_size
89
+ self.n_layers = n_layers
90
+ self.p_dropout = p_dropout
91
+
92
+ self.drop = nn.Dropout(p_dropout)
93
+ self.convs_sep = nn.ModuleList()
94
+ self.convs_1x1 = nn.ModuleList()
95
+ self.norms_1 = nn.ModuleList()
96
+ self.norms_2 = nn.ModuleList()
97
+ for i in range(n_layers):
98
+ dilation = kernel_size**i
99
+ padding = (kernel_size * dilation - dilation) // 2
100
+ self.convs_sep.append(
101
+ nn.Conv1d(
102
+ channels,
103
+ channels,
104
+ kernel_size,
105
+ groups=channels,
106
+ dilation=dilation,
107
+ padding=padding,
108
+ )
109
+ )
110
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
111
+ self.norms_1.append(LayerNorm(channels))
112
+ self.norms_2.append(LayerNorm(channels))
113
+
114
+ def forward(self, x, x_mask, g=None):
115
+ if g is not None:
116
+ x = x + g
117
+ for i in range(self.n_layers):
118
+ y = self.convs_sep[i](x * x_mask)
119
+ y = self.norms_1[i](y)
120
+ y = F.gelu(y)
121
+ y = self.convs_1x1[i](y)
122
+ y = self.norms_2[i](y)
123
+ y = F.gelu(y)
124
+ y = self.drop(y)
125
+ x = x + y
126
+ return x * x_mask
127
+
128
+
129
+ class WN(torch.nn.Module):
130
+ def __init__(
131
+ self,
132
+ hidden_channels,
133
+ kernel_size,
134
+ dilation_rate,
135
+ n_layers,
136
+ gin_channels=0,
137
+ p_dropout=0,
138
+ ):
139
+ super(WN, self).__init__()
140
+ assert kernel_size % 2 == 1
141
+ self.hidden_channels = hidden_channels
142
+ self.kernel_size = (kernel_size,)
143
+ self.dilation_rate = dilation_rate
144
+ self.n_layers = n_layers
145
+ self.gin_channels = gin_channels
146
+ self.p_dropout = p_dropout
147
+
148
+ self.in_layers = torch.nn.ModuleList()
149
+ self.res_skip_layers = torch.nn.ModuleList()
150
+ self.drop = nn.Dropout(p_dropout)
151
+
152
+ if gin_channels != 0:
153
+ cond_layer = torch.nn.Conv1d(
154
+ gin_channels, 2 * hidden_channels * n_layers, 1
155
+ )
156
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
157
+
158
+ for i in range(n_layers):
159
+ dilation = dilation_rate**i
160
+ padding = int((kernel_size * dilation - dilation) / 2)
161
+ in_layer = torch.nn.Conv1d(
162
+ hidden_channels,
163
+ 2 * hidden_channels,
164
+ kernel_size,
165
+ dilation=dilation,
166
+ padding=padding,
167
+ )
168
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
169
+ self.in_layers.append(in_layer)
170
+
171
+ # last one is not necessary
172
+ if i < n_layers - 1:
173
+ res_skip_channels = 2 * hidden_channels
174
+ else:
175
+ res_skip_channels = hidden_channels
176
+
177
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
178
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
179
+ self.res_skip_layers.append(res_skip_layer)
180
+
181
+ def forward(self, x, x_mask, g=None, **kwargs):
182
+ output = torch.zeros_like(x)
183
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
184
+
185
+ if g is not None:
186
+ g = self.cond_layer(g)
187
+
188
+ for i in range(self.n_layers):
189
+ x_in = self.in_layers[i](x)
190
+ if g is not None:
191
+ cond_offset = i * 2 * self.hidden_channels
192
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
193
+ else:
194
+ g_l = torch.zeros_like(x_in)
195
+
196
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
197
+ acts = self.drop(acts)
198
+
199
+ res_skip_acts = self.res_skip_layers[i](acts)
200
+ if i < self.n_layers - 1:
201
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
202
+ x = (x + res_acts) * x_mask
203
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
204
+ else:
205
+ output = output + res_skip_acts
206
+ return output * x_mask
207
+
208
+ def remove_weight_norm(self):
209
+ if self.gin_channels != 0:
210
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
211
+ for l in self.in_layers:
212
+ torch.nn.utils.remove_weight_norm(l)
213
+ for l in self.res_skip_layers:
214
+ torch.nn.utils.remove_weight_norm(l)
215
+
216
+
217
+ class ResBlock1(torch.nn.Module):
218
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
219
+ super(ResBlock1, self).__init__()
220
+ self.convs1 = nn.ModuleList(
221
+ [
222
+ weight_norm(
223
+ Conv1d(
224
+ channels,
225
+ channels,
226
+ kernel_size,
227
+ 1,
228
+ dilation=dilation[0],
229
+ padding=get_padding(kernel_size, dilation[0]),
230
+ )
231
+ ),
232
+ weight_norm(
233
+ Conv1d(
234
+ channels,
235
+ channels,
236
+ kernel_size,
237
+ 1,
238
+ dilation=dilation[1],
239
+ padding=get_padding(kernel_size, dilation[1]),
240
+ )
241
+ ),
242
+ weight_norm(
243
+ Conv1d(
244
+ channels,
245
+ channels,
246
+ kernel_size,
247
+ 1,
248
+ dilation=dilation[2],
249
+ padding=get_padding(kernel_size, dilation[2]),
250
+ )
251
+ ),
252
+ ]
253
+ )
254
+ self.convs1.apply(init_weights)
255
+
256
+ self.convs2 = nn.ModuleList(
257
+ [
258
+ weight_norm(
259
+ Conv1d(
260
+ channels,
261
+ channels,
262
+ kernel_size,
263
+ 1,
264
+ dilation=1,
265
+ padding=get_padding(kernel_size, 1),
266
+ )
267
+ ),
268
+ weight_norm(
269
+ Conv1d(
270
+ channels,
271
+ channels,
272
+ kernel_size,
273
+ 1,
274
+ dilation=1,
275
+ padding=get_padding(kernel_size, 1),
276
+ )
277
+ ),
278
+ weight_norm(
279
+ Conv1d(
280
+ channels,
281
+ channels,
282
+ kernel_size,
283
+ 1,
284
+ dilation=1,
285
+ padding=get_padding(kernel_size, 1),
286
+ )
287
+ ),
288
+ ]
289
+ )
290
+ self.convs2.apply(init_weights)
291
+
292
+ def forward(self, x, x_mask=None):
293
+ for c1, c2 in zip(self.convs1, self.convs2):
294
+ xt = F.leaky_relu(x, LRELU_SLOPE)
295
+ if x_mask is not None:
296
+ xt = xt * x_mask
297
+ xt = c1(xt)
298
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
299
+ if x_mask is not None:
300
+ xt = xt * x_mask
301
+ xt = c2(xt)
302
+ x = xt + x
303
+ if x_mask is not None:
304
+ x = x * x_mask
305
+ return x
306
+
307
+ def remove_weight_norm(self):
308
+ for l in self.convs1:
309
+ remove_weight_norm(l)
310
+ for l in self.convs2:
311
+ remove_weight_norm(l)
312
+
313
+
314
+ class ResBlock2(torch.nn.Module):
315
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
316
+ super(ResBlock2, self).__init__()
317
+ self.convs = nn.ModuleList(
318
+ [
319
+ weight_norm(
320
+ Conv1d(
321
+ channels,
322
+ channels,
323
+ kernel_size,
324
+ 1,
325
+ dilation=dilation[0],
326
+ padding=get_padding(kernel_size, dilation[0]),
327
+ )
328
+ ),
329
+ weight_norm(
330
+ Conv1d(
331
+ channels,
332
+ channels,
333
+ kernel_size,
334
+ 1,
335
+ dilation=dilation[1],
336
+ padding=get_padding(kernel_size, dilation[1]),
337
+ )
338
+ ),
339
+ ]
340
+ )
341
+ self.convs.apply(init_weights)
342
+
343
+ def forward(self, x, x_mask=None):
344
+ for c in self.convs:
345
+ xt = F.leaky_relu(x, LRELU_SLOPE)
346
+ if x_mask is not None:
347
+ xt = xt * x_mask
348
+ xt = c(xt)
349
+ x = xt + x
350
+ if x_mask is not None:
351
+ x = x * x_mask
352
+ return x
353
+
354
+ def remove_weight_norm(self):
355
+ for l in self.convs:
356
+ remove_weight_norm(l)
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libsndfile1-dev
phone_set.json ADDED
@@ -0,0 +1 @@
 
 
1
+ ["[SEP]", "a", "b", "c", "d", "e", "g", "h", "i", "k", "l", "m", "n", "o", "p", "q", "r", "s", "sil", "spn", "t", "u", "v", "x", "y", "\u00e0", "\u00e1", "\u00e2", "\u00e3", "\u00e8", "\u00e9", "\u00ea", "\u00ec", "\u00ed", "\u00f2", "\u00f3", "\u00f4", "\u00f5", "\u00f9", "\u00fa", "\u00fd", "\u0103", "\u0111", "\u0129", "\u0169", "\u01a1", "\u01b0", "\u1ea1", "\u1ea3", "\u1ea5", "\u1ea7", "\u1ea9", "\u1eab", "\u1ead", "\u1eaf", "\u1eb1", "\u1eb3", "\u1eb5", "\u1eb7", "\u1eb9", "\u1ebb", "\u1ebd", "\u1ebf", "\u1ec1", "\u1ec3", "\u1ec5", "\u1ec7", "\u1ec9", "\u1ecb", "\u1ecd", "\u1ecf", "\u1ed1", "\u1ed3", "\u1ed5", "\u1ed7", "\u1ed9", "\u1edb", "\u1edd", "\u1edf", "\u1ee1", "\u1ee3", "\u1ee5", "\u1ee7", "\u1ee9", "\u1eeb", "\u1eed", "\u1eef", "\u1ef1", "\u1ef3", "\u1ef5", "\u1ef7", "\u1ef9"]
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ regex
3
+ torch