Mahiruoshi commited on
Commit
42f7394
·
1 Parent(s): 805bf42

Delete onnx_modules

Browse files
onnx_modules/V200/__init__.py DELETED
File without changes
onnx_modules/V200/attentions_onnx.py DELETED
@@ -1,378 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn
4
- from torch.nn import functional as F
5
-
6
- import commons
7
- import logging
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
-
12
- class LayerNorm(nn.Module):
13
- def __init__(self, channels, eps=1e-5):
14
- super().__init__()
15
- self.channels = channels
16
- self.eps = eps
17
-
18
- self.gamma = nn.Parameter(torch.ones(channels))
19
- self.beta = nn.Parameter(torch.zeros(channels))
20
-
21
- def forward(self, x):
22
- x = x.transpose(1, -1)
23
- x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
24
- return x.transpose(1, -1)
25
-
26
-
27
- @torch.jit.script
28
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
29
- n_channels_int = n_channels[0]
30
- in_act = input_a + input_b
31
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
32
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
33
- acts = t_act * s_act
34
- return acts
35
-
36
-
37
- class Encoder(nn.Module):
38
- def __init__(
39
- self,
40
- hidden_channels,
41
- filter_channels,
42
- n_heads,
43
- n_layers,
44
- kernel_size=1,
45
- p_dropout=0.0,
46
- window_size=4,
47
- isflow=True,
48
- **kwargs
49
- ):
50
- super().__init__()
51
- self.hidden_channels = hidden_channels
52
- self.filter_channels = filter_channels
53
- self.n_heads = n_heads
54
- self.n_layers = n_layers
55
- self.kernel_size = kernel_size
56
- self.p_dropout = p_dropout
57
- self.window_size = window_size
58
- # if isflow:
59
- # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
60
- # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
61
- # self.cond_layer = weight_norm(cond_layer, name='weight')
62
- # self.gin_channels = 256
63
- self.cond_layer_idx = self.n_layers
64
- if "gin_channels" in kwargs:
65
- self.gin_channels = kwargs["gin_channels"]
66
- if self.gin_channels != 0:
67
- self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
68
- # vits2 says 3rd block, so idx is 2 by default
69
- self.cond_layer_idx = (
70
- kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
71
- )
72
- logging.debug(self.gin_channels, self.cond_layer_idx)
73
- assert (
74
- self.cond_layer_idx < self.n_layers
75
- ), "cond_layer_idx should be less than n_layers"
76
- self.drop = nn.Dropout(p_dropout)
77
- self.attn_layers = nn.ModuleList()
78
- self.norm_layers_1 = nn.ModuleList()
79
- self.ffn_layers = nn.ModuleList()
80
- self.norm_layers_2 = nn.ModuleList()
81
- for i in range(self.n_layers):
82
- self.attn_layers.append(
83
- MultiHeadAttention(
84
- hidden_channels,
85
- hidden_channels,
86
- n_heads,
87
- p_dropout=p_dropout,
88
- window_size=window_size,
89
- )
90
- )
91
- self.norm_layers_1.append(LayerNorm(hidden_channels))
92
- self.ffn_layers.append(
93
- FFN(
94
- hidden_channels,
95
- hidden_channels,
96
- filter_channels,
97
- kernel_size,
98
- p_dropout=p_dropout,
99
- )
100
- )
101
- self.norm_layers_2.append(LayerNorm(hidden_channels))
102
-
103
- def forward(self, x, x_mask, g=None):
104
- attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
105
- x = x * x_mask
106
- for i in range(self.n_layers):
107
- if i == self.cond_layer_idx and g is not None:
108
- g = self.spk_emb_linear(g.transpose(1, 2))
109
- g = g.transpose(1, 2)
110
- x = x + g
111
- x = x * x_mask
112
- y = self.attn_layers[i](x, x, attn_mask)
113
- y = self.drop(y)
114
- x = self.norm_layers_1[i](x + y)
115
-
116
- y = self.ffn_layers[i](x, x_mask)
117
- y = self.drop(y)
118
- x = self.norm_layers_2[i](x + y)
119
- x = x * x_mask
120
- return x
121
-
122
-
123
- class MultiHeadAttention(nn.Module):
124
- def __init__(
125
- self,
126
- channels,
127
- out_channels,
128
- n_heads,
129
- p_dropout=0.0,
130
- window_size=None,
131
- heads_share=True,
132
- block_length=None,
133
- proximal_bias=False,
134
- proximal_init=False,
135
- ):
136
- super().__init__()
137
- assert channels % n_heads == 0
138
-
139
- self.channels = channels
140
- self.out_channels = out_channels
141
- self.n_heads = n_heads
142
- self.p_dropout = p_dropout
143
- self.window_size = window_size
144
- self.heads_share = heads_share
145
- self.block_length = block_length
146
- self.proximal_bias = proximal_bias
147
- self.proximal_init = proximal_init
148
- self.attn = None
149
-
150
- self.k_channels = channels // n_heads
151
- self.conv_q = nn.Conv1d(channels, channels, 1)
152
- self.conv_k = nn.Conv1d(channels, channels, 1)
153
- self.conv_v = nn.Conv1d(channels, channels, 1)
154
- self.conv_o = nn.Conv1d(channels, out_channels, 1)
155
- self.drop = nn.Dropout(p_dropout)
156
-
157
- if window_size is not None:
158
- n_heads_rel = 1 if heads_share else n_heads
159
- rel_stddev = self.k_channels**-0.5
160
- self.emb_rel_k = nn.Parameter(
161
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
162
- * rel_stddev
163
- )
164
- self.emb_rel_v = nn.Parameter(
165
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
166
- * rel_stddev
167
- )
168
-
169
- nn.init.xavier_uniform_(self.conv_q.weight)
170
- nn.init.xavier_uniform_(self.conv_k.weight)
171
- nn.init.xavier_uniform_(self.conv_v.weight)
172
- if proximal_init:
173
- with torch.no_grad():
174
- self.conv_k.weight.copy_(self.conv_q.weight)
175
- self.conv_k.bias.copy_(self.conv_q.bias)
176
-
177
- def forward(self, x, c, attn_mask=None):
178
- q = self.conv_q(x)
179
- k = self.conv_k(c)
180
- v = self.conv_v(c)
181
-
182
- x, self.attn = self.attention(q, k, v, mask=attn_mask)
183
-
184
- x = self.conv_o(x)
185
- return x
186
-
187
- def attention(self, query, key, value, mask=None):
188
- # reshape [b, d, t] -> [b, n_h, t, d_k]
189
- b, d, t_s, t_t = (*key.size(), query.size(2))
190
- query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
191
- key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
192
- value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
193
-
194
- scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
195
- if self.window_size is not None:
196
- assert (
197
- t_s == t_t
198
- ), "Relative attention is only available for self-attention."
199
- key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
200
- rel_logits = self._matmul_with_relative_keys(
201
- query / math.sqrt(self.k_channels), key_relative_embeddings
202
- )
203
- scores_local = self._relative_position_to_absolute_position(rel_logits)
204
- scores = scores + scores_local
205
- if self.proximal_bias:
206
- assert t_s == t_t, "Proximal bias is only available for self-attention."
207
- scores = scores + self._attention_bias_proximal(t_s).to(
208
- device=scores.device, dtype=scores.dtype
209
- )
210
- if mask is not None:
211
- scores = scores.masked_fill(mask == 0, -1e4)
212
- if self.block_length is not None:
213
- assert (
214
- t_s == t_t
215
- ), "Local attention is only available for self-attention."
216
- block_mask = (
217
- torch.ones_like(scores)
218
- .triu(-self.block_length)
219
- .tril(self.block_length)
220
- )
221
- scores = scores.masked_fill(block_mask == 0, -1e4)
222
- p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
223
- p_attn = self.drop(p_attn)
224
- output = torch.matmul(p_attn, value)
225
- if self.window_size is not None:
226
- relative_weights = self._absolute_position_to_relative_position(p_attn)
227
- value_relative_embeddings = self._get_relative_embeddings(
228
- self.emb_rel_v, t_s
229
- )
230
- output = output + self._matmul_with_relative_values(
231
- relative_weights, value_relative_embeddings
232
- )
233
- output = (
234
- output.transpose(2, 3).contiguous().view(b, d, t_t)
235
- ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
236
- return output, p_attn
237
-
238
- def _matmul_with_relative_values(self, x, y):
239
- """
240
- x: [b, h, l, m]
241
- y: [h or 1, m, d]
242
- ret: [b, h, l, d]
243
- """
244
- ret = torch.matmul(x, y.unsqueeze(0))
245
- return ret
246
-
247
- def _matmul_with_relative_keys(self, x, y):
248
- """
249
- x: [b, h, l, d]
250
- y: [h or 1, m, d]
251
- ret: [b, h, l, m]
252
- """
253
- ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
254
- return ret
255
-
256
- def _get_relative_embeddings(self, relative_embeddings, length):
257
- max_relative_position = 2 * self.window_size + 1
258
- # Pad first before slice to avoid using cond ops.
259
- pad_length = max(length - (self.window_size + 1), 0)
260
- slice_start_position = max((self.window_size + 1) - length, 0)
261
- slice_end_position = slice_start_position + 2 * length - 1
262
- if pad_length > 0:
263
- padded_relative_embeddings = F.pad(
264
- relative_embeddings,
265
- commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
266
- )
267
- else:
268
- padded_relative_embeddings = relative_embeddings
269
- used_relative_embeddings = padded_relative_embeddings[
270
- :, slice_start_position:slice_end_position
271
- ]
272
- return used_relative_embeddings
273
-
274
- def _relative_position_to_absolute_position(self, x):
275
- """
276
- x: [b, h, l, 2*l-1]
277
- ret: [b, h, l, l]
278
- """
279
- batch, heads, length, _ = x.size()
280
- # Concat columns of pad to shift from relative to absolute indexing.
281
- x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
282
-
283
- # Concat extra elements so to add up to shape (len+1, 2*len-1).
284
- x_flat = x.view([batch, heads, length * 2 * length])
285
- x_flat = F.pad(
286
- x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
287
- )
288
-
289
- # Reshape and slice out the padded elements.
290
- x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
291
- :, :, :length, length - 1 :
292
- ]
293
- return x_final
294
-
295
- def _absolute_position_to_relative_position(self, x):
296
- """
297
- x: [b, h, l, l]
298
- ret: [b, h, l, 2*l-1]
299
- """
300
- batch, heads, length, _ = x.size()
301
- # padd along column
302
- x = F.pad(
303
- x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
304
- )
305
- x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
306
- # add 0's in the beginning that will skew the elements after reshape
307
- x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
308
- x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
309
- return x_final
310
-
311
- def _attention_bias_proximal(self, length):
312
- """Bias for self-attention to encourage attention to close positions.
313
- Args:
314
- length: an integer scalar.
315
- Returns:
316
- a Tensor with shape [1, 1, length, length]
317
- """
318
- r = torch.arange(length, dtype=torch.float32)
319
- diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
320
- return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
321
-
322
-
323
- class FFN(nn.Module):
324
- def __init__(
325
- self,
326
- in_channels,
327
- out_channels,
328
- filter_channels,
329
- kernel_size,
330
- p_dropout=0.0,
331
- activation=None,
332
- causal=False,
333
- ):
334
- super().__init__()
335
- self.in_channels = in_channels
336
- self.out_channels = out_channels
337
- self.filter_channels = filter_channels
338
- self.kernel_size = kernel_size
339
- self.p_dropout = p_dropout
340
- self.activation = activation
341
- self.causal = causal
342
-
343
- if causal:
344
- self.padding = self._causal_padding
345
- else:
346
- self.padding = self._same_padding
347
-
348
- self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
349
- self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
350
- self.drop = nn.Dropout(p_dropout)
351
-
352
- def forward(self, x, x_mask):
353
- x = self.conv_1(self.padding(x * x_mask))
354
- if self.activation == "gelu":
355
- x = x * torch.sigmoid(1.702 * x)
356
- else:
357
- x = torch.relu(x)
358
- x = self.drop(x)
359
- x = self.conv_2(self.padding(x * x_mask))
360
- return x * x_mask
361
-
362
- def _causal_padding(self, x):
363
- if self.kernel_size == 1:
364
- return x
365
- pad_l = self.kernel_size - 1
366
- pad_r = 0
367
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
368
- x = F.pad(x, commons.convert_pad_shape(padding))
369
- return x
370
-
371
- def _same_padding(self, x):
372
- if self.kernel_size == 1:
373
- return x
374
- pad_l = (self.kernel_size - 1) // 2
375
- pad_r = self.kernel_size // 2
376
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
377
- x = F.pad(x, commons.convert_pad_shape(padding))
378
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
onnx_modules/V200/models_onnx.py DELETED
@@ -1,990 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn
4
- from torch.nn import functional as F
5
-
6
- import commons
7
- import modules
8
- from . import attentions_onnx
9
-
10
- from torch.nn import Conv1d, ConvTranspose1d, Conv2d
11
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
- from commons import init_weights, get_padding
13
- from .text import symbols, num_tones, num_languages
14
-
15
-
16
- class DurationDiscriminator(nn.Module): # vits2
17
- def __init__(
18
- self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
19
- ):
20
- super().__init__()
21
-
22
- self.in_channels = in_channels
23
- self.filter_channels = filter_channels
24
- self.kernel_size = kernel_size
25
- self.p_dropout = p_dropout
26
- self.gin_channels = gin_channels
27
-
28
- self.drop = nn.Dropout(p_dropout)
29
- self.conv_1 = nn.Conv1d(
30
- in_channels, filter_channels, kernel_size, padding=kernel_size // 2
31
- )
32
- self.norm_1 = modules.LayerNorm(filter_channels)
33
- self.conv_2 = nn.Conv1d(
34
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
35
- )
36
- self.norm_2 = modules.LayerNorm(filter_channels)
37
- self.dur_proj = nn.Conv1d(1, filter_channels, 1)
38
-
39
- self.pre_out_conv_1 = nn.Conv1d(
40
- 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
41
- )
42
- self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
43
- self.pre_out_conv_2 = nn.Conv1d(
44
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
45
- )
46
- self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
47
-
48
- if gin_channels != 0:
49
- self.cond = nn.Conv1d(gin_channels, in_channels, 1)
50
-
51
- self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
52
-
53
- def forward_probability(self, x, x_mask, dur, g=None):
54
- dur = self.dur_proj(dur)
55
- x = torch.cat([x, dur], dim=1)
56
- x = self.pre_out_conv_1(x * x_mask)
57
- x = torch.relu(x)
58
- x = self.pre_out_norm_1(x)
59
- x = self.drop(x)
60
- x = self.pre_out_conv_2(x * x_mask)
61
- x = torch.relu(x)
62
- x = self.pre_out_norm_2(x)
63
- x = self.drop(x)
64
- x = x * x_mask
65
- x = x.transpose(1, 2)
66
- output_prob = self.output_layer(x)
67
- return output_prob
68
-
69
- def forward(self, x, x_mask, dur_r, dur_hat, g=None):
70
- x = torch.detach(x)
71
- if g is not None:
72
- g = torch.detach(g)
73
- x = x + self.cond(g)
74
- x = self.conv_1(x * x_mask)
75
- x = torch.relu(x)
76
- x = self.norm_1(x)
77
- x = self.drop(x)
78
- x = self.conv_2(x * x_mask)
79
- x = torch.relu(x)
80
- x = self.norm_2(x)
81
- x = self.drop(x)
82
-
83
- output_probs = []
84
- for dur in [dur_r, dur_hat]:
85
- output_prob = self.forward_probability(x, x_mask, dur, g)
86
- output_probs.append(output_prob)
87
-
88
- return output_probs
89
-
90
-
91
- class TransformerCouplingBlock(nn.Module):
92
- def __init__(
93
- self,
94
- channels,
95
- hidden_channels,
96
- filter_channels,
97
- n_heads,
98
- n_layers,
99
- kernel_size,
100
- p_dropout,
101
- n_flows=4,
102
- gin_channels=0,
103
- share_parameter=False,
104
- ):
105
- super().__init__()
106
- self.channels = channels
107
- self.hidden_channels = hidden_channels
108
- self.kernel_size = kernel_size
109
- self.n_layers = n_layers
110
- self.n_flows = n_flows
111
- self.gin_channels = gin_channels
112
-
113
- self.flows = nn.ModuleList()
114
-
115
- self.wn = (
116
- attentions_onnx.FFT(
117
- hidden_channels,
118
- filter_channels,
119
- n_heads,
120
- n_layers,
121
- kernel_size,
122
- p_dropout,
123
- isflow=True,
124
- gin_channels=self.gin_channels,
125
- )
126
- if share_parameter
127
- else None
128
- )
129
-
130
- for i in range(n_flows):
131
- self.flows.append(
132
- modules.TransformerCouplingLayer(
133
- channels,
134
- hidden_channels,
135
- kernel_size,
136
- n_layers,
137
- n_heads,
138
- p_dropout,
139
- filter_channels,
140
- mean_only=True,
141
- wn_sharing_parameter=self.wn,
142
- gin_channels=self.gin_channels,
143
- )
144
- )
145
- self.flows.append(modules.Flip())
146
-
147
- def forward(self, x, x_mask, g=None, reverse=True):
148
- if not reverse:
149
- for flow in self.flows:
150
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
151
- else:
152
- for flow in reversed(self.flows):
153
- x = flow(x, x_mask, g=g, reverse=reverse)
154
- return x
155
-
156
-
157
- class StochasticDurationPredictor(nn.Module):
158
- def __init__(
159
- self,
160
- in_channels,
161
- filter_channels,
162
- kernel_size,
163
- p_dropout,
164
- n_flows=4,
165
- gin_channels=0,
166
- ):
167
- super().__init__()
168
- filter_channels = in_channels # it needs to be removed from future version.
169
- self.in_channels = in_channels
170
- self.filter_channels = filter_channels
171
- self.kernel_size = kernel_size
172
- self.p_dropout = p_dropout
173
- self.n_flows = n_flows
174
- self.gin_channels = gin_channels
175
-
176
- self.log_flow = modules.Log()
177
- self.flows = nn.ModuleList()
178
- self.flows.append(modules.ElementwiseAffine(2))
179
- for i in range(n_flows):
180
- self.flows.append(
181
- modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
182
- )
183
- self.flows.append(modules.Flip())
184
-
185
- self.post_pre = nn.Conv1d(1, filter_channels, 1)
186
- self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
187
- self.post_convs = modules.DDSConv(
188
- filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
189
- )
190
- self.post_flows = nn.ModuleList()
191
- self.post_flows.append(modules.ElementwiseAffine(2))
192
- for i in range(4):
193
- self.post_flows.append(
194
- modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
195
- )
196
- self.post_flows.append(modules.Flip())
197
-
198
- self.pre = nn.Conv1d(in_channels, filter_channels, 1)
199
- self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
200
- self.convs = modules.DDSConv(
201
- filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
202
- )
203
- if gin_channels != 0:
204
- self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
205
-
206
- def forward(self, x, x_mask, z, g=None):
207
- x = torch.detach(x)
208
- x = self.pre(x)
209
- if g is not None:
210
- g = torch.detach(g)
211
- x = x + self.cond(g)
212
- x = self.convs(x, x_mask)
213
- x = self.proj(x) * x_mask
214
-
215
- flows = list(reversed(self.flows))
216
- flows = flows[:-2] + [flows[-1]] # remove a useless vflow
217
- for flow in flows:
218
- z = flow(z, x_mask, g=x, reverse=True)
219
- z0, z1 = torch.split(z, [1, 1], 1)
220
- logw = z0
221
- return logw
222
-
223
-
224
- class DurationPredictor(nn.Module):
225
- def __init__(
226
- self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
227
- ):
228
- super().__init__()
229
-
230
- self.in_channels = in_channels
231
- self.filter_channels = filter_channels
232
- self.kernel_size = kernel_size
233
- self.p_dropout = p_dropout
234
- self.gin_channels = gin_channels
235
-
236
- self.drop = nn.Dropout(p_dropout)
237
- self.conv_1 = nn.Conv1d(
238
- in_channels, filter_channels, kernel_size, padding=kernel_size // 2
239
- )
240
- self.norm_1 = modules.LayerNorm(filter_channels)
241
- self.conv_2 = nn.Conv1d(
242
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
243
- )
244
- self.norm_2 = modules.LayerNorm(filter_channels)
245
- self.proj = nn.Conv1d(filter_channels, 1, 1)
246
-
247
- if gin_channels != 0:
248
- self.cond = nn.Conv1d(gin_channels, in_channels, 1)
249
-
250
- def forward(self, x, x_mask, g=None):
251
- x = torch.detach(x)
252
- if g is not None:
253
- g = torch.detach(g)
254
- x = x + self.cond(g)
255
- x = self.conv_1(x * x_mask)
256
- x = torch.relu(x)
257
- x = self.norm_1(x)
258
- x = self.drop(x)
259
- x = self.conv_2(x * x_mask)
260
- x = torch.relu(x)
261
- x = self.norm_2(x)
262
- x = self.drop(x)
263
- x = self.proj(x * x_mask)
264
- return x * x_mask
265
-
266
-
267
- class TextEncoder(nn.Module):
268
- def __init__(
269
- self,
270
- n_vocab,
271
- out_channels,
272
- hidden_channels,
273
- filter_channels,
274
- n_heads,
275
- n_layers,
276
- kernel_size,
277
- p_dropout,
278
- gin_channels=0,
279
- ):
280
- super().__init__()
281
- self.n_vocab = n_vocab
282
- self.out_channels = out_channels
283
- self.hidden_channels = hidden_channels
284
- self.filter_channels = filter_channels
285
- self.n_heads = n_heads
286
- self.n_layers = n_layers
287
- self.kernel_size = kernel_size
288
- self.p_dropout = p_dropout
289
- self.gin_channels = gin_channels
290
- self.emb = nn.Embedding(len(symbols), hidden_channels)
291
- nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
292
- self.tone_emb = nn.Embedding(num_tones, hidden_channels)
293
- nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
294
- self.language_emb = nn.Embedding(num_languages, hidden_channels)
295
- nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
296
- self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
297
- self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
298
- self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
299
-
300
- self.encoder = attentions_onnx.Encoder(
301
- hidden_channels,
302
- filter_channels,
303
- n_heads,
304
- n_layers,
305
- kernel_size,
306
- p_dropout,
307
- gin_channels=self.gin_channels,
308
- )
309
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
310
-
311
- def forward(self, x, x_lengths, tone, language, bert, ja_bert, en_bert, g=None):
312
- x_mask = torch.ones_like(x).unsqueeze(0)
313
- bert_emb = self.bert_proj(bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
314
- ja_bert_emb = self.ja_bert_proj(ja_bert.transpose(0, 1).unsqueeze(0)).transpose(
315
- 1, 2
316
- )
317
- en_bert_emb = self.en_bert_proj(en_bert.transpose(0, 1).unsqueeze(0)).transpose(
318
- 1, 2
319
- )
320
- x = (
321
- self.emb(x)
322
- + self.tone_emb(tone)
323
- + self.language_emb(language)
324
- + bert_emb
325
- + ja_bert_emb
326
- + en_bert_emb
327
- ) * math.sqrt(
328
- self.hidden_channels
329
- ) # [b, t, h]
330
- x = torch.transpose(x, 1, -1) # [b, h, t]
331
- x_mask = x_mask.to(x.dtype)
332
-
333
- x = self.encoder(x * x_mask, x_mask, g=g)
334
- stats = self.proj(x) * x_mask
335
-
336
- m, logs = torch.split(stats, self.out_channels, dim=1)
337
- return x, m, logs, x_mask
338
-
339
-
340
- class ResidualCouplingBlock(nn.Module):
341
- def __init__(
342
- self,
343
- channels,
344
- hidden_channels,
345
- kernel_size,
346
- dilation_rate,
347
- n_layers,
348
- n_flows=4,
349
- gin_channels=0,
350
- ):
351
- super().__init__()
352
- self.channels = channels
353
- self.hidden_channels = hidden_channels
354
- self.kernel_size = kernel_size
355
- self.dilation_rate = dilation_rate
356
- self.n_layers = n_layers
357
- self.n_flows = n_flows
358
- self.gin_channels = gin_channels
359
-
360
- self.flows = nn.ModuleList()
361
- for i in range(n_flows):
362
- self.flows.append(
363
- modules.ResidualCouplingLayer(
364
- channels,
365
- hidden_channels,
366
- kernel_size,
367
- dilation_rate,
368
- n_layers,
369
- gin_channels=gin_channels,
370
- mean_only=True,
371
- )
372
- )
373
- self.flows.append(modules.Flip())
374
-
375
- def forward(self, x, x_mask, g=None, reverse=True):
376
- if not reverse:
377
- for flow in self.flows:
378
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
379
- else:
380
- for flow in reversed(self.flows):
381
- x = flow(x, x_mask, g=g, reverse=reverse)
382
- return x
383
-
384
-
385
- class PosteriorEncoder(nn.Module):
386
- def __init__(
387
- self,
388
- in_channels,
389
- out_channels,
390
- hidden_channels,
391
- kernel_size,
392
- dilation_rate,
393
- n_layers,
394
- gin_channels=0,
395
- ):
396
- super().__init__()
397
- self.in_channels = in_channels
398
- self.out_channels = out_channels
399
- self.hidden_channels = hidden_channels
400
- self.kernel_size = kernel_size
401
- self.dilation_rate = dilation_rate
402
- self.n_layers = n_layers
403
- self.gin_channels = gin_channels
404
-
405
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
406
- self.enc = modules.WN(
407
- hidden_channels,
408
- kernel_size,
409
- dilation_rate,
410
- n_layers,
411
- gin_channels=gin_channels,
412
- )
413
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
414
-
415
- def forward(self, x, x_lengths, g=None):
416
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
417
- x.dtype
418
- )
419
- x = self.pre(x) * x_mask
420
- x = self.enc(x, x_mask, g=g)
421
- stats = self.proj(x) * x_mask
422
- m, logs = torch.split(stats, self.out_channels, dim=1)
423
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
424
- return z, m, logs, x_mask
425
-
426
-
427
- class Generator(torch.nn.Module):
428
- def __init__(
429
- self,
430
- initial_channel,
431
- resblock,
432
- resblock_kernel_sizes,
433
- resblock_dilation_sizes,
434
- upsample_rates,
435
- upsample_initial_channel,
436
- upsample_kernel_sizes,
437
- gin_channels=0,
438
- ):
439
- super(Generator, self).__init__()
440
- self.num_kernels = len(resblock_kernel_sizes)
441
- self.num_upsamples = len(upsample_rates)
442
- self.conv_pre = Conv1d(
443
- initial_channel, upsample_initial_channel, 7, 1, padding=3
444
- )
445
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
446
-
447
- self.ups = nn.ModuleList()
448
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
449
- self.ups.append(
450
- weight_norm(
451
- ConvTranspose1d(
452
- upsample_initial_channel // (2**i),
453
- upsample_initial_channel // (2 ** (i + 1)),
454
- k,
455
- u,
456
- padding=(k - u) // 2,
457
- )
458
- )
459
- )
460
-
461
- self.resblocks = nn.ModuleList()
462
- for i in range(len(self.ups)):
463
- ch = upsample_initial_channel // (2 ** (i + 1))
464
- for j, (k, d) in enumerate(
465
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
466
- ):
467
- self.resblocks.append(resblock(ch, k, d))
468
-
469
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
470
- self.ups.apply(init_weights)
471
-
472
- if gin_channels != 0:
473
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
474
-
475
- def forward(self, x, g=None):
476
- x = self.conv_pre(x)
477
- if g is not None:
478
- x = x + self.cond(g)
479
-
480
- for i in range(self.num_upsamples):
481
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
482
- x = self.ups[i](x)
483
- xs = None
484
- for j in range(self.num_kernels):
485
- if xs is None:
486
- xs = self.resblocks[i * self.num_kernels + j](x)
487
- else:
488
- xs += self.resblocks[i * self.num_kernels + j](x)
489
- x = xs / self.num_kernels
490
- x = F.leaky_relu(x)
491
- x = self.conv_post(x)
492
- x = torch.tanh(x)
493
-
494
- return x
495
-
496
- def remove_weight_norm(self):
497
- print("Removing weight norm...")
498
- for layer in self.ups:
499
- remove_weight_norm(layer)
500
- for layer in self.resblocks:
501
- layer.remove_weight_norm()
502
-
503
-
504
- class DiscriminatorP(torch.nn.Module):
505
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
506
- super(DiscriminatorP, self).__init__()
507
- self.period = period
508
- self.use_spectral_norm = use_spectral_norm
509
- norm_f = weight_norm if use_spectral_norm is False else spectral_norm
510
- self.convs = nn.ModuleList(
511
- [
512
- norm_f(
513
- Conv2d(
514
- 1,
515
- 32,
516
- (kernel_size, 1),
517
- (stride, 1),
518
- padding=(get_padding(kernel_size, 1), 0),
519
- )
520
- ),
521
- norm_f(
522
- Conv2d(
523
- 32,
524
- 128,
525
- (kernel_size, 1),
526
- (stride, 1),
527
- padding=(get_padding(kernel_size, 1), 0),
528
- )
529
- ),
530
- norm_f(
531
- Conv2d(
532
- 128,
533
- 512,
534
- (kernel_size, 1),
535
- (stride, 1),
536
- padding=(get_padding(kernel_size, 1), 0),
537
- )
538
- ),
539
- norm_f(
540
- Conv2d(
541
- 512,
542
- 1024,
543
- (kernel_size, 1),
544
- (stride, 1),
545
- padding=(get_padding(kernel_size, 1), 0),
546
- )
547
- ),
548
- norm_f(
549
- Conv2d(
550
- 1024,
551
- 1024,
552
- (kernel_size, 1),
553
- 1,
554
- padding=(get_padding(kernel_size, 1), 0),
555
- )
556
- ),
557
- ]
558
- )
559
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
560
-
561
- def forward(self, x):
562
- fmap = []
563
-
564
- # 1d to 2d
565
- b, c, t = x.shape
566
- if t % self.period != 0: # pad first
567
- n_pad = self.period - (t % self.period)
568
- x = F.pad(x, (0, n_pad), "reflect")
569
- t = t + n_pad
570
- x = x.view(b, c, t // self.period, self.period)
571
-
572
- for layer in self.convs:
573
- x = layer(x)
574
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
575
- fmap.append(x)
576
- x = self.conv_post(x)
577
- fmap.append(x)
578
- x = torch.flatten(x, 1, -1)
579
-
580
- return x, fmap
581
-
582
-
583
- class DiscriminatorS(torch.nn.Module):
584
- def __init__(self, use_spectral_norm=False):
585
- super(DiscriminatorS, self).__init__()
586
- norm_f = weight_norm if use_spectral_norm is False else spectral_norm
587
- self.convs = nn.ModuleList(
588
- [
589
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
590
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
591
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
592
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
593
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
594
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
595
- ]
596
- )
597
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
598
-
599
- def forward(self, x):
600
- fmap = []
601
-
602
- for layer in self.convs:
603
- x = layer(x)
604
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
605
- fmap.append(x)
606
- x = self.conv_post(x)
607
- fmap.append(x)
608
- x = torch.flatten(x, 1, -1)
609
-
610
- return x, fmap
611
-
612
-
613
- class MultiPeriodDiscriminator(torch.nn.Module):
614
- def __init__(self, use_spectral_norm=False):
615
- super(MultiPeriodDiscriminator, self).__init__()
616
- periods = [2, 3, 5, 7, 11]
617
-
618
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
619
- discs = discs + [
620
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
621
- ]
622
- self.discriminators = nn.ModuleList(discs)
623
-
624
- def forward(self, y, y_hat):
625
- y_d_rs = []
626
- y_d_gs = []
627
- fmap_rs = []
628
- fmap_gs = []
629
- for i, d in enumerate(self.discriminators):
630
- y_d_r, fmap_r = d(y)
631
- y_d_g, fmap_g = d(y_hat)
632
- y_d_rs.append(y_d_r)
633
- y_d_gs.append(y_d_g)
634
- fmap_rs.append(fmap_r)
635
- fmap_gs.append(fmap_g)
636
-
637
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
638
-
639
-
640
- class ReferenceEncoder(nn.Module):
641
- """
642
- inputs --- [N, Ty/r, n_mels*r] mels
643
- outputs --- [N, ref_enc_gru_size]
644
- """
645
-
646
- def __init__(self, spec_channels, gin_channels=0):
647
- super().__init__()
648
- self.spec_channels = spec_channels
649
- ref_enc_filters = [32, 32, 64, 64, 128, 128]
650
- K = len(ref_enc_filters)
651
- filters = [1] + ref_enc_filters
652
- convs = [
653
- weight_norm(
654
- nn.Conv2d(
655
- in_channels=filters[i],
656
- out_channels=filters[i + 1],
657
- kernel_size=(3, 3),
658
- stride=(2, 2),
659
- padding=(1, 1),
660
- )
661
- )
662
- for i in range(K)
663
- ]
664
- self.convs = nn.ModuleList(convs)
665
- # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
666
-
667
- out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
668
- self.gru = nn.GRU(
669
- input_size=ref_enc_filters[-1] * out_channels,
670
- hidden_size=256 // 2,
671
- batch_first=True,
672
- )
673
- self.proj = nn.Linear(128, gin_channels)
674
-
675
- def forward(self, inputs, mask=None):
676
- N = inputs.size(0)
677
- out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
678
- for conv in self.convs:
679
- out = conv(out)
680
- # out = wn(out)
681
- out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
682
-
683
- out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
684
- T = out.size(1)
685
- N = out.size(0)
686
- out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
687
-
688
- self.gru.flatten_parameters()
689
- memory, out = self.gru(out) # out --- [1, N, 128]
690
-
691
- return self.proj(out.squeeze(0))
692
-
693
- def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
694
- for i in range(n_convs):
695
- L = (L - kernel_size + 2 * pad) // stride + 1
696
- return L
697
-
698
-
699
- class SynthesizerTrn(nn.Module):
700
- """
701
- Synthesizer for Training
702
- """
703
-
704
- def __init__(
705
- self,
706
- n_vocab,
707
- spec_channels,
708
- segment_size,
709
- inter_channels,
710
- hidden_channels,
711
- filter_channels,
712
- n_heads,
713
- n_layers,
714
- kernel_size,
715
- p_dropout,
716
- resblock,
717
- resblock_kernel_sizes,
718
- resblock_dilation_sizes,
719
- upsample_rates,
720
- upsample_initial_channel,
721
- upsample_kernel_sizes,
722
- n_speakers=256,
723
- gin_channels=256,
724
- use_sdp=True,
725
- n_flow_layer=4,
726
- n_layers_trans_flow=4,
727
- flow_share_parameter=False,
728
- use_transformer_flow=True,
729
- **kwargs,
730
- ):
731
- super().__init__()
732
- self.n_vocab = n_vocab
733
- self.spec_channels = spec_channels
734
- self.inter_channels = inter_channels
735
- self.hidden_channels = hidden_channels
736
- self.filter_channels = filter_channels
737
- self.n_heads = n_heads
738
- self.n_layers = n_layers
739
- self.kernel_size = kernel_size
740
- self.p_dropout = p_dropout
741
- self.resblock = resblock
742
- self.resblock_kernel_sizes = resblock_kernel_sizes
743
- self.resblock_dilation_sizes = resblock_dilation_sizes
744
- self.upsample_rates = upsample_rates
745
- self.upsample_initial_channel = upsample_initial_channel
746
- self.upsample_kernel_sizes = upsample_kernel_sizes
747
- self.segment_size = segment_size
748
- self.n_speakers = n_speakers
749
- self.gin_channels = gin_channels
750
- self.n_layers_trans_flow = n_layers_trans_flow
751
- self.use_spk_conditioned_encoder = kwargs.get(
752
- "use_spk_conditioned_encoder", True
753
- )
754
- self.use_sdp = use_sdp
755
- self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
756
- self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
757
- self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
758
- self.current_mas_noise_scale = self.mas_noise_scale_initial
759
- if self.use_spk_conditioned_encoder and gin_channels > 0:
760
- self.enc_gin_channels = gin_channels
761
- self.enc_p = TextEncoder(
762
- n_vocab,
763
- inter_channels,
764
- hidden_channels,
765
- filter_channels,
766
- n_heads,
767
- n_layers,
768
- kernel_size,
769
- p_dropout,
770
- gin_channels=self.enc_gin_channels,
771
- )
772
- self.dec = Generator(
773
- inter_channels,
774
- resblock,
775
- resblock_kernel_sizes,
776
- resblock_dilation_sizes,
777
- upsample_rates,
778
- upsample_initial_channel,
779
- upsample_kernel_sizes,
780
- gin_channels=gin_channels,
781
- )
782
- self.enc_q = PosteriorEncoder(
783
- spec_channels,
784
- inter_channels,
785
- hidden_channels,
786
- 5,
787
- 1,
788
- 16,
789
- gin_channels=gin_channels,
790
- )
791
- if use_transformer_flow:
792
- self.flow = TransformerCouplingBlock(
793
- inter_channels,
794
- hidden_channels,
795
- filter_channels,
796
- n_heads,
797
- n_layers_trans_flow,
798
- 5,
799
- p_dropout,
800
- n_flow_layer,
801
- gin_channels=gin_channels,
802
- share_parameter=flow_share_parameter,
803
- )
804
- else:
805
- self.flow = ResidualCouplingBlock(
806
- inter_channels,
807
- hidden_channels,
808
- 5,
809
- 1,
810
- n_flow_layer,
811
- gin_channels=gin_channels,
812
- )
813
- self.sdp = StochasticDurationPredictor(
814
- hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
815
- )
816
- self.dp = DurationPredictor(
817
- hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
818
- )
819
-
820
- if n_speakers >= 1:
821
- self.emb_g = nn.Embedding(n_speakers, gin_channels)
822
- else:
823
- self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
824
-
825
- def export_onnx(
826
- self,
827
- path,
828
- max_len=None,
829
- sdp_ratio=0,
830
- y=None,
831
- ):
832
- noise_scale = 0.667
833
- length_scale = 1
834
- noise_scale_w = 0.8
835
- x = (
836
- torch.LongTensor(
837
- [
838
- 0,
839
- 97,
840
- 0,
841
- 8,
842
- 0,
843
- 78,
844
- 0,
845
- 8,
846
- 0,
847
- 76,
848
- 0,
849
- 37,
850
- 0,
851
- 40,
852
- 0,
853
- 97,
854
- 0,
855
- 8,
856
- 0,
857
- 23,
858
- 0,
859
- 8,
860
- 0,
861
- 74,
862
- 0,
863
- 26,
864
- 0,
865
- 104,
866
- 0,
867
- ]
868
- )
869
- .unsqueeze(0)
870
- .cpu()
871
- )
872
- tone = torch.zeros_like(x).cpu()
873
- language = torch.zeros_like(x).cpu()
874
- x_lengths = torch.LongTensor([x.shape[1]]).cpu()
875
- sid = torch.LongTensor([0]).cpu()
876
- bert = torch.randn(size=(x.shape[1], 1024)).cpu()
877
- ja_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
878
- en_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
879
-
880
- if self.n_speakers > 0:
881
- g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
882
- torch.onnx.export(
883
- self.emb_g,
884
- (sid),
885
- f"onnx/{path}/{path}_emb.onnx",
886
- input_names=["sid"],
887
- output_names=["g"],
888
- verbose=True,
889
- )
890
- else:
891
- g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
892
-
893
- torch.onnx.export(
894
- self.enc_p,
895
- (x, x_lengths, tone, language, bert, ja_bert, en_bert, g),
896
- f"onnx/{path}/{path}_enc_p.onnx",
897
- input_names=[
898
- "x",
899
- "x_lengths",
900
- "t",
901
- "language",
902
- "bert_0",
903
- "bert_1",
904
- "bert_2",
905
- "g",
906
- ],
907
- output_names=["xout", "m_p", "logs_p", "x_mask"],
908
- dynamic_axes={
909
- "x": [0, 1],
910
- "t": [0, 1],
911
- "language": [0, 1],
912
- "bert_0": [0],
913
- "bert_1": [0],
914
- "bert_2": [0],
915
- "xout": [0, 2],
916
- "m_p": [0, 2],
917
- "logs_p": [0, 2],
918
- "x_mask": [0, 2],
919
- },
920
- verbose=True,
921
- opset_version=16,
922
- )
923
- x, m_p, logs_p, x_mask = self.enc_p(
924
- x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
925
- )
926
- zinput = (
927
- torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
928
- * noise_scale_w
929
- )
930
- torch.onnx.export(
931
- self.sdp,
932
- (x, x_mask, zinput, g),
933
- f"onnx/{path}/{path}_sdp.onnx",
934
- input_names=["x", "x_mask", "zin", "g"],
935
- output_names=["logw"],
936
- dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "zin": [0, 2], "logw": [0, 2]},
937
- verbose=True,
938
- )
939
- torch.onnx.export(
940
- self.dp,
941
- (x, x_mask, g),
942
- f"onnx/{path}/{path}_dp.onnx",
943
- input_names=["x", "x_mask", "g"],
944
- output_names=["logw"],
945
- dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "logw": [0, 2]},
946
- verbose=True,
947
- )
948
- logw = self.sdp(x, x_mask, zinput, g=g) * (sdp_ratio) + self.dp(
949
- x, x_mask, g=g
950
- ) * (1 - sdp_ratio)
951
- w = torch.exp(logw) * x_mask * length_scale
952
- w_ceil = torch.ceil(w)
953
- y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
954
- y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
955
- x_mask.dtype
956
- )
957
- attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
958
- attn = commons.generate_path(w_ceil, attn_mask)
959
-
960
- m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
961
- 1, 2
962
- ) # [b, t', t], [b, t, d] -> [b, d, t']
963
- logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
964
- 1, 2
965
- ) # [b, t', t], [b, t, d] -> [b, d, t']
966
-
967
- z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
968
- torch.onnx.export(
969
- self.flow,
970
- (z_p, y_mask, g),
971
- f"onnx/{path}/{path}_flow.onnx",
972
- input_names=["z_p", "y_mask", "g"],
973
- output_names=["z"],
974
- dynamic_axes={"z_p": [0, 2], "y_mask": [0, 2], "z": [0, 2]},
975
- verbose=True,
976
- )
977
-
978
- z = self.flow(z_p, y_mask, g=g, reverse=True)
979
- z_in = (z * y_mask)[:, :, :max_len]
980
-
981
- torch.onnx.export(
982
- self.dec,
983
- (z_in, g),
984
- f"onnx/{path}/{path}_dec.onnx",
985
- input_names=["z_in", "g"],
986
- output_names=["o"],
987
- dynamic_axes={"z_in": [0, 2], "o": [0, 2]},
988
- verbose=True,
989
- )
990
- o = self.dec((z * y_mask)[:, :, :max_len], g=g)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
onnx_modules/V200/text/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .symbols import *
 
 
onnx_modules/V200/text/bert_utils.py DELETED
@@ -1,23 +0,0 @@
1
- from pathlib import Path
2
-
3
- from huggingface_hub import hf_hub_download
4
-
5
- from config import config
6
-
7
-
8
- MIRROR: str = config.mirror
9
-
10
-
11
- def _check_bert(repo_id, files, local_path):
12
- for file in files:
13
- if not Path(local_path).joinpath(file).exists():
14
- if MIRROR.lower() == "openi":
15
- import openi
16
-
17
- openi.model.download_model(
18
- "Stardust_minus/Bert-VITS2", repo_id.split("/")[-1], "./bert"
19
- )
20
- else:
21
- hf_hub_download(
22
- repo_id, file, local_dir=local_path, local_dir_use_symlinks=False
23
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
onnx_modules/V200/text/chinese.py DELETED
@@ -1,198 +0,0 @@
1
- import os
2
- import re
3
-
4
- import cn2an
5
- from pypinyin import lazy_pinyin, Style
6
-
7
- from .symbols import punctuation
8
- from .tone_sandhi import ToneSandhi
9
-
10
- current_file_path = os.path.dirname(__file__)
11
- pinyin_to_symbol_map = {
12
- line.split("\t")[0]: line.strip().split("\t")[1]
13
- for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
14
- }
15
-
16
- import jieba.posseg as psg
17
-
18
-
19
- rep_map = {
20
- ":": ",",
21
- ";": ",",
22
- ",": ",",
23
- "。": ".",
24
- "!": "!",
25
- "?": "?",
26
- "\n": ".",
27
- "·": ",",
28
- "、": ",",
29
- "...": "…",
30
- "$": ".",
31
- "“": "'",
32
- "”": "'",
33
- "‘": "'",
34
- "’": "'",
35
- "(": "'",
36
- ")": "'",
37
- "(": "'",
38
- ")": "'",
39
- "《": "'",
40
- "》": "'",
41
- "【": "'",
42
- "】": "'",
43
- "[": "'",
44
- "]": "'",
45
- "—": "-",
46
- "~": "-",
47
- "~": "-",
48
- "「": "'",
49
- "」": "'",
50
- }
51
-
52
- tone_modifier = ToneSandhi()
53
-
54
-
55
- def replace_punctuation(text):
56
- text = text.replace("嗯", "恩").replace("呣", "母")
57
- pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
58
-
59
- replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
60
-
61
- replaced_text = re.sub(
62
- r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
63
- )
64
-
65
- return replaced_text
66
-
67
-
68
- def g2p(text):
69
- pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
70
- sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
71
- phones, tones, word2ph = _g2p(sentences)
72
- assert sum(word2ph) == len(phones)
73
- assert len(word2ph) == len(text) # Sometimes it will crash,you can add a try-catch.
74
- phones = ["_"] + phones + ["_"]
75
- tones = [0] + tones + [0]
76
- word2ph = [1] + word2ph + [1]
77
- return phones, tones, word2ph
78
-
79
-
80
- def _get_initials_finals(word):
81
- initials = []
82
- finals = []
83
- orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
84
- orig_finals = lazy_pinyin(
85
- word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
86
- )
87
- for c, v in zip(orig_initials, orig_finals):
88
- initials.append(c)
89
- finals.append(v)
90
- return initials, finals
91
-
92
-
93
- def _g2p(segments):
94
- phones_list = []
95
- tones_list = []
96
- word2ph = []
97
- for seg in segments:
98
- # Replace all English words in the sentence
99
- seg = re.sub("[a-zA-Z]+", "", seg)
100
- seg_cut = psg.lcut(seg)
101
- initials = []
102
- finals = []
103
- seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
104
- for word, pos in seg_cut:
105
- if pos == "eng":
106
- continue
107
- sub_initials, sub_finals = _get_initials_finals(word)
108
- sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
109
- initials.append(sub_initials)
110
- finals.append(sub_finals)
111
-
112
- # assert len(sub_initials) == len(sub_finals) == len(word)
113
- initials = sum(initials, [])
114
- finals = sum(finals, [])
115
- #
116
- for c, v in zip(initials, finals):
117
- raw_pinyin = c + v
118
- # NOTE: post process for pypinyin outputs
119
- # we discriminate i, ii and iii
120
- if c == v:
121
- assert c in punctuation
122
- phone = [c]
123
- tone = "0"
124
- word2ph.append(1)
125
- else:
126
- v_without_tone = v[:-1]
127
- tone = v[-1]
128
-
129
- pinyin = c + v_without_tone
130
- assert tone in "12345"
131
-
132
- if c:
133
- # 多音节
134
- v_rep_map = {
135
- "uei": "ui",
136
- "iou": "iu",
137
- "uen": "un",
138
- }
139
- if v_without_tone in v_rep_map.keys():
140
- pinyin = c + v_rep_map[v_without_tone]
141
- else:
142
- # 单音节
143
- pinyin_rep_map = {
144
- "ing": "ying",
145
- "i": "yi",
146
- "in": "yin",
147
- "u": "wu",
148
- }
149
- if pinyin in pinyin_rep_map.keys():
150
- pinyin = pinyin_rep_map[pinyin]
151
- else:
152
- single_rep_map = {
153
- "v": "yu",
154
- "e": "e",
155
- "i": "y",
156
- "u": "w",
157
- }
158
- if pinyin[0] in single_rep_map.keys():
159
- pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
160
-
161
- assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
162
- phone = pinyin_to_symbol_map[pinyin].split(" ")
163
- word2ph.append(len(phone))
164
-
165
- phones_list += phone
166
- tones_list += [int(tone)] * len(phone)
167
- return phones_list, tones_list, word2ph
168
-
169
-
170
- def text_normalize(text):
171
- numbers = re.findall(r"\d+(?:\.?\d+)?", text)
172
- for number in numbers:
173
- text = text.replace(number, cn2an.an2cn(number), 1)
174
- text = replace_punctuation(text)
175
- return text
176
-
177
-
178
- def get_bert_feature(text, word2ph):
179
- from text import chinese_bert
180
-
181
- return chinese_bert.get_bert_feature(text, word2ph)
182
-
183
-
184
- if __name__ == "__main__":
185
- from text.chinese_bert import get_bert_feature
186
-
187
- text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
188
- text = text_normalize(text)
189
- print(text)
190
- phones, tones, word2ph = g2p(text)
191
- bert = get_bert_feature(text, word2ph)
192
-
193
- print(phones, tones, word2ph, bert.shape)
194
-
195
-
196
- # # 示例用法
197
- # text = "这是一个示例文本:,你好!这是一个测试...."
198
- # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
onnx_modules/V200/text/chinese_bert.py DELETED
@@ -1,101 +0,0 @@
1
- import sys
2
-
3
- import torch
4
- from transformers import AutoModelForMaskedLM, AutoTokenizer
5
-
6
- from config import config
7
-
8
- LOCAL_PATH = "./bert/chinese-roberta-wwm-ext-large"
9
-
10
- tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
11
-
12
- models = dict()
13
-
14
-
15
- def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
16
- if (
17
- sys.platform == "darwin"
18
- and torch.backends.mps.is_available()
19
- and device == "cpu"
20
- ):
21
- device = "mps"
22
- if not device:
23
- device = "cuda"
24
- if device not in models.keys():
25
- models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)
26
- with torch.no_grad():
27
- inputs = tokenizer(text, return_tensors="pt")
28
- for i in inputs:
29
- inputs[i] = inputs[i].to(device)
30
- res = models[device](**inputs, output_hidden_states=True)
31
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
32
-
33
- assert len(word2ph) == len(text) + 2
34
- word2phone = word2ph
35
- phone_level_feature = []
36
- for i in range(len(word2phone)):
37
- repeat_feature = res[i].repeat(word2phone[i], 1)
38
- phone_level_feature.append(repeat_feature)
39
-
40
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
41
-
42
- return phone_level_feature.T
43
-
44
-
45
- if __name__ == "__main__":
46
- word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征
47
- word2phone = [
48
- 1,
49
- 2,
50
- 1,
51
- 2,
52
- 2,
53
- 1,
54
- 2,
55
- 2,
56
- 1,
57
- 2,
58
- 2,
59
- 1,
60
- 2,
61
- 2,
62
- 2,
63
- 2,
64
- 2,
65
- 1,
66
- 1,
67
- 2,
68
- 2,
69
- 1,
70
- 2,
71
- 2,
72
- 2,
73
- 2,
74
- 1,
75
- 2,
76
- 2,
77
- 2,
78
- 2,
79
- 2,
80
- 1,
81
- 2,
82
- 2,
83
- 2,
84
- 2,
85
- 1,
86
- ]
87
-
88
- # 计算总帧数
89
- total_frames = sum(word2phone)
90
- print(word_level_feature.shape)
91
- print(word2phone)
92
- phone_level_feature = []
93
- for i in range(len(word2phone)):
94
- print(word_level_feature[i].shape)
95
-
96
- # 对每个词重复word2phone[i]次
97
- repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
98
- phone_level_feature.append(repeat_feature)
99
-
100
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
101
- print(phone_level_feature.shape) # torch.Size([36, 1024])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
onnx_modules/V200/text/cleaner.py DELETED
@@ -1,28 +0,0 @@
1
- from . import chinese, japanese, english, cleaned_text_to_sequence
2
-
3
-
4
- language_module_map = {"ZH": chinese, "JP": japanese, "EN": english}
5
-
6
-
7
- def clean_text(text, language):
8
- language_module = language_module_map[language]
9
- norm_text = language_module.text_normalize(text)
10
- phones, tones, word2ph = language_module.g2p(norm_text)
11
- return norm_text, phones, tones, word2ph
12
-
13
-
14
- def clean_text_bert(text, language):
15
- language_module = language_module_map[language]
16
- norm_text = language_module.text_normalize(text)
17
- phones, tones, word2ph = language_module.g2p(norm_text)
18
- bert = language_module.get_bert_feature(norm_text, word2ph)
19
- return phones, tones, bert
20
-
21
-
22
- def text_to_sequence(text, language):
23
- norm_text, phones, tones, word2ph = clean_text(text, language)
24
- return cleaned_text_to_sequence(phones, tones, language)
25
-
26
-
27
- if __name__ == "__main__":
28
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
onnx_modules/V200/text/english.py DELETED
@@ -1,362 +0,0 @@
1
- import pickle
2
- import os
3
- import re
4
- from g2p_en import G2p
5
-
6
- from . import symbols
7
-
8
- current_file_path = os.path.dirname(__file__)
9
- CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
10
- CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
11
- _g2p = G2p()
12
-
13
- arpa = {
14
- "AH0",
15
- "S",
16
- "AH1",
17
- "EY2",
18
- "AE2",
19
- "EH0",
20
- "OW2",
21
- "UH0",
22
- "NG",
23
- "B",
24
- "G",
25
- "AY0",
26
- "M",
27
- "AA0",
28
- "F",
29
- "AO0",
30
- "ER2",
31
- "UH1",
32
- "IY1",
33
- "AH2",
34
- "DH",
35
- "IY0",
36
- "EY1",
37
- "IH0",
38
- "K",
39
- "N",
40
- "W",
41
- "IY2",
42
- "T",
43
- "AA1",
44
- "ER1",
45
- "EH2",
46
- "OY0",
47
- "UH2",
48
- "UW1",
49
- "Z",
50
- "AW2",
51
- "AW1",
52
- "V",
53
- "UW2",
54
- "AA2",
55
- "ER",
56
- "AW0",
57
- "UW0",
58
- "R",
59
- "OW1",
60
- "EH1",
61
- "ZH",
62
- "AE0",
63
- "IH2",
64
- "IH",
65
- "Y",
66
- "JH",
67
- "P",
68
- "AY1",
69
- "EY0",
70
- "OY2",
71
- "TH",
72
- "HH",
73
- "D",
74
- "ER0",
75
- "CH",
76
- "AO1",
77
- "AE1",
78
- "AO2",
79
- "OY1",
80
- "AY2",
81
- "IH1",
82
- "OW0",
83
- "L",
84
- "SH",
85
- }
86
-
87
-
88
- def post_replace_ph(ph):
89
- rep_map = {
90
- ":": ",",
91
- ";": ",",
92
- ",": ",",
93
- "。": ".",
94
- "!": "!",
95
- "?": "?",
96
- "\n": ".",
97
- "·": ",",
98
- "、": ",",
99
- "...": "…",
100
- "v": "V",
101
- }
102
- if ph in rep_map.keys():
103
- ph = rep_map[ph]
104
- if ph in symbols:
105
- return ph
106
- if ph not in symbols:
107
- ph = "UNK"
108
- return ph
109
-
110
-
111
- def read_dict():
112
- g2p_dict = {}
113
- start_line = 49
114
- with open(CMU_DICT_PATH) as f:
115
- line = f.readline()
116
- line_index = 1
117
- while line:
118
- if line_index >= start_line:
119
- line = line.strip()
120
- word_split = line.split(" ")
121
- word = word_split[0]
122
-
123
- syllable_split = word_split[1].split(" - ")
124
- g2p_dict[word] = []
125
- for syllable in syllable_split:
126
- phone_split = syllable.split(" ")
127
- g2p_dict[word].append(phone_split)
128
-
129
- line_index = line_index + 1
130
- line = f.readline()
131
-
132
- return g2p_dict
133
-
134
-
135
- def cache_dict(g2p_dict, file_path):
136
- with open(file_path, "wb") as pickle_file:
137
- pickle.dump(g2p_dict, pickle_file)
138
-
139
-
140
- def get_dict():
141
- if os.path.exists(CACHE_PATH):
142
- with open(CACHE_PATH, "rb") as pickle_file:
143
- g2p_dict = pickle.load(pickle_file)
144
- else:
145
- g2p_dict = read_dict()
146
- cache_dict(g2p_dict, CACHE_PATH)
147
-
148
- return g2p_dict
149
-
150
-
151
- eng_dict = get_dict()
152
-
153
-
154
- def refine_ph(phn):
155
- tone = 0
156
- if re.search(r"\d$", phn):
157
- tone = int(phn[-1]) + 1
158
- phn = phn[:-1]
159
- return phn.lower(), tone
160
-
161
-
162
- def refine_syllables(syllables):
163
- tones = []
164
- phonemes = []
165
- for phn_list in syllables:
166
- for i in range(len(phn_list)):
167
- phn = phn_list[i]
168
- phn, tone = refine_ph(phn)
169
- phonemes.append(phn)
170
- tones.append(tone)
171
- return phonemes, tones
172
-
173
-
174
- import re
175
- import inflect
176
-
177
- _inflect = inflect.engine()
178
- _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
179
- _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
180
- _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
181
- _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
182
- _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
183
- _number_re = re.compile(r"[0-9]+")
184
-
185
- # List of (regular expression, replacement) pairs for abbreviations:
186
- _abbreviations = [
187
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
188
- for x in [
189
- ("mrs", "misess"),
190
- ("mr", "mister"),
191
- ("dr", "doctor"),
192
- ("st", "saint"),
193
- ("co", "company"),
194
- ("jr", "junior"),
195
- ("maj", "major"),
196
- ("gen", "general"),
197
- ("drs", "doctors"),
198
- ("rev", "reverend"),
199
- ("lt", "lieutenant"),
200
- ("hon", "honorable"),
201
- ("sgt", "sergeant"),
202
- ("capt", "captain"),
203
- ("esq", "esquire"),
204
- ("ltd", "limited"),
205
- ("col", "colonel"),
206
- ("ft", "fort"),
207
- ]
208
- ]
209
-
210
-
211
- # List of (ipa, lazy ipa) pairs:
212
- _lazy_ipa = [
213
- (re.compile("%s" % x[0]), x[1])
214
- for x in [
215
- ("r", "ɹ"),
216
- ("æ", "e"),
217
- ("ɑ", "a"),
218
- ("ɔ", "o"),
219
- ("ð", "z"),
220
- ("θ", "s"),
221
- ("ɛ", "e"),
222
- ("ɪ", "i"),
223
- ("ʊ", "u"),
224
- ("ʒ", "ʥ"),
225
- ("ʤ", "ʥ"),
226
- ("ˈ", "↓"),
227
- ]
228
- ]
229
-
230
- # List of (ipa, lazy ipa2) pairs:
231
- _lazy_ipa2 = [
232
- (re.compile("%s" % x[0]), x[1])
233
- for x in [
234
- ("r", "ɹ"),
235
- ("ð", "z"),
236
- ("θ", "s"),
237
- ("ʒ", "ʑ"),
238
- ("ʤ", "dʑ"),
239
- ("ˈ", "↓"),
240
- ]
241
- ]
242
-
243
- # List of (ipa, ipa2) pairs
244
- _ipa_to_ipa2 = [
245
- (re.compile("%s" % x[0]), x[1]) for x in [("r", "ɹ"), ("ʤ", "dʒ"), ("ʧ", "tʃ")]
246
- ]
247
-
248
-
249
- def _expand_dollars(m):
250
- match = m.group(1)
251
- parts = match.split(".")
252
- if len(parts) > 2:
253
- return match + " dollars" # Unexpected format
254
- dollars = int(parts[0]) if parts[0] else 0
255
- cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
256
- if dollars and cents:
257
- dollar_unit = "dollar" if dollars == 1 else "dollars"
258
- cent_unit = "cent" if cents == 1 else "cents"
259
- return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
260
- elif dollars:
261
- dollar_unit = "dollar" if dollars == 1 else "dollars"
262
- return "%s %s" % (dollars, dollar_unit)
263
- elif cents:
264
- cent_unit = "cent" if cents == 1 else "cents"
265
- return "%s %s" % (cents, cent_unit)
266
- else:
267
- return "zero dollars"
268
-
269
-
270
- def _remove_commas(m):
271
- return m.group(1).replace(",", "")
272
-
273
-
274
- def _expand_ordinal(m):
275
- return _inflect.number_to_words(m.group(0))
276
-
277
-
278
- def _expand_number(m):
279
- num = int(m.group(0))
280
- if num > 1000 and num < 3000:
281
- if num == 2000:
282
- return "two thousand"
283
- elif num > 2000 and num < 2010:
284
- return "two thousand " + _inflect.number_to_words(num % 100)
285
- elif num % 100 == 0:
286
- return _inflect.number_to_words(num // 100) + " hundred"
287
- else:
288
- return _inflect.number_to_words(
289
- num, andword="", zero="oh", group=2
290
- ).replace(", ", " ")
291
- else:
292
- return _inflect.number_to_words(num, andword="")
293
-
294
-
295
- def _expand_decimal_point(m):
296
- return m.group(1).replace(".", " point ")
297
-
298
-
299
- def normalize_numbers(text):
300
- text = re.sub(_comma_number_re, _remove_commas, text)
301
- text = re.sub(_pounds_re, r"\1 pounds", text)
302
- text = re.sub(_dollars_re, _expand_dollars, text)
303
- text = re.sub(_decimal_number_re, _expand_decimal_point, text)
304
- text = re.sub(_ordinal_re, _expand_ordinal, text)
305
- text = re.sub(_number_re, _expand_number, text)
306
- return text
307
-
308
-
309
- def text_normalize(text):
310
- text = normalize_numbers(text)
311
- return text
312
-
313
-
314
- def g2p(text):
315
- phones = []
316
- tones = []
317
- word2ph = []
318
- words = re.split(r"([,;.\-\?\!\s+])", text)
319
- words = [word for word in words if word.strip() != ""]
320
- for word in words:
321
- if word.upper() in eng_dict:
322
- phns, tns = refine_syllables(eng_dict[word.upper()])
323
- phones += phns
324
- tones += tns
325
- word2ph.append(len(phns))
326
- else:
327
- phone_list = list(filter(lambda p: p != " ", _g2p(word)))
328
- for ph in phone_list:
329
- if ph in arpa:
330
- ph, tn = refine_ph(ph)
331
- phones.append(ph)
332
- tones.append(tn)
333
- else:
334
- phones.append(ph)
335
- tones.append(0)
336
- word2ph.append(len(phone_list))
337
-
338
- phones = [post_replace_ph(i) for i in phones]
339
-
340
- phones = ["_"] + phones + ["_"]
341
- tones = [0] + tones + [0]
342
- word2ph = [1] + word2ph + [1]
343
-
344
- return phones, tones, word2ph
345
-
346
-
347
- def get_bert_feature(text, word2ph):
348
- from text import english_bert_mock
349
-
350
- return english_bert_mock.get_bert_feature(text, word2ph)
351
-
352
-
353
- if __name__ == "__main__":
354
- # print(get_dict())
355
- # print(eng_word_to_phoneme("hello"))
356
- print(g2p("In this paper, we propose 1 DSPGAN, a GAN-based universal vocoder."))
357
- # all_phones = set()
358
- # for k, syllables in eng_dict.items():
359
- # for group in syllables:
360
- # for ph in group:
361
- # all_phones.add(ph)
362
- # print(all_phones)
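A small, hedged usage sketch of the English front end defined above. It only exercises text_normalize and checks the alignment invariants that g2p establishes at its end (a pad symbol on each side, one tone per phone, word2ph summing to the phone count). The module path "text.english" is an assumption about the package layout.

# Hedged sketch; importing the module loads the CMU dictionary/cache and g2p_en,
# so those assets must be available.
from text.english import g2p, text_normalize  # assumed import path

text = text_normalize("I bought 3 books for $5.")  # expands numbers and currency to words
phones, tones, word2ph = g2p(text)

# Invariants from the tail of g2p(): "_" padding, tones aligned to phones,
# and per-word phone counts summing to the total number of phones.
assert phones[0] == "_" and phones[-1] == "_"
assert len(phones) == len(tones) == sum(word2ph)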
onnx_modules/V200/text/english_bert_mock.py DELETED
@@ -1,42 +0,0 @@
1
- import sys
2
-
3
- import torch
4
- from transformers import DebertaV2Model, DebertaV2Tokenizer
5
-
6
- from config import config
7
-
8
-
9
- LOCAL_PATH = "./bert/deberta-v3-large"
10
-
11
- tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
12
-
13
- models = dict()
14
-
15
-
16
- def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
17
- if (
18
- sys.platform == "darwin"
19
- and torch.backends.mps.is_available()
20
- and device == "cpu"
21
- ):
22
- device = "mps"
23
- if not device:
24
- device = "cuda"
25
- if device not in models.keys():
26
- models[device] = DebertaV2Model.from_pretrained(LOCAL_PATH).to(device)
27
- with torch.no_grad():
28
- inputs = tokenizer(text, return_tensors="pt")
29
- for i in inputs:
30
- inputs[i] = inputs[i].to(device)
31
- res = models[device](**inputs, output_hidden_states=True)
32
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
33
- # assert len(word2ph) == len(text)+2
34
- word2phone = word2ph
35
- phone_level_feature = []
36
- for i in range(len(word2phone)):
37
- repeat_feature = res[i].repeat(word2phone[i], 1)
38
- phone_level_feature.append(repeat_feature)
39
-
40
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
41
-
42
- return phone_level_feature.T
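The tail of get_bert_feature() above expands token-level hidden states to phone level by repeating each token vector word2ph[i] times. A self-contained sketch of just that expansion, using dummy tensors instead of a loaded DeBERTa model:

# Hedged sketch of the phone-level expansion; dummy numbers only, no model loaded.
import torch

hidden = torch.randn(4, 1024)      # pretend hidden states for 4 tokens
word2ph = [1, 3, 2, 1]             # phones attributed to each token

phone_level = torch.cat(
    [hidden[i].repeat(word2ph[i], 1) for i in range(len(word2ph))], dim=0
)
features = phone_level.T           # shape (1024, sum(word2ph)) == (1024, 7)
assert features.shape == (1024, sum(word2ph))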
onnx_modules/V200/text/japanese.py DELETED
@@ -1,403 +0,0 @@
1
- # Convert Japanese text to phonemes which is
2
- # compatible with Julius https://github.com/julius-speech/segmentation-kit
3
- import re
4
- import unicodedata
5
-
6
- from transformers import AutoTokenizer
7
-
8
- from . import punctuation, symbols
9
-
10
- from num2words import num2words
11
-
12
- import pyopenjtalk
13
- import jaconv
14
-
15
-
16
- def kata2phoneme(text: str) -> str:
17
- """Convert katakana text to phonemes."""
18
- text = text.strip()
19
- if text == "ー":
20
- return ["ー"]
21
- elif text.startswith("ー"):
22
- return ["ー"] + kata2phoneme(text[1:])
23
- res = []
24
- prev = None
25
- while text:
26
- if re.match(_MARKS, text):
27
- res.append(text)
28
- text = text[1:]
29
- continue
30
- if text.startswith("ー"):
31
- if prev:
32
- res.append(prev[-1])
33
- text = text[1:]
34
- continue
35
- res += pyopenjtalk.g2p(text).lower().replace("cl", "q").split(" ")
36
- break
37
- # res = _COLON_RX.sub(":", res)
38
- return res
39
-
40
-
41
- def hira2kata(text: str) -> str:
42
- return jaconv.hira2kata(text)
43
-
44
-
45
- _SYMBOL_TOKENS = set(list("・、。?!"))
46
- _NO_YOMI_TOKENS = set(list("「」『』―()[][]"))
47
- _MARKS = re.compile(
48
- r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
49
- )
50
-
51
-
52
- def text2kata(text: str) -> str:
53
- parsed = pyopenjtalk.run_frontend(text)
54
-
55
- res = []
56
- for parts in parsed:
57
- word, yomi = replace_punctuation(parts["string"]), parts["pron"].replace(
58
- "’", ""
59
- )
60
- if yomi:
61
- if re.match(_MARKS, yomi):
62
- if len(word) > 1:
63
- word = [replace_punctuation(i) for i in list(word)]
64
- yomi = word
65
- res += yomi
66
- sep += word
67
- continue
68
- elif word not in rep_map.keys() and word not in rep_map.values():
69
- word = ","
70
- yomi = word
71
- res.append(yomi)
72
- else:
73
- if word in _SYMBOL_TOKENS:
74
- res.append(word)
75
- elif word in ("っ", "ッ"):
76
- res.append("ッ")
77
- elif word in _NO_YOMI_TOKENS:
78
- pass
79
- else:
80
- res.append(word)
81
- return hira2kata("".join(res))
82
-
83
-
84
- def text2sep_kata(text: str) -> (list, list):
85
- parsed = pyopenjtalk.run_frontend(text)
86
-
87
- res = []
88
- sep = []
89
- for parts in parsed:
90
- word, yomi = replace_punctuation(parts["string"]), parts["pron"].replace(
91
- "’", ""
92
- )
93
- if yomi:
94
- if re.match(_MARKS, yomi):
95
- if len(word) > 1:
96
- word = [replace_punctuation(i) for i in list(word)]
97
- yomi = word
98
- res += yomi
99
- sep += word
100
- continue
101
- elif word not in rep_map.keys() and word not in rep_map.values():
102
- word = ","
103
- yomi = word
104
- res.append(yomi)
105
- else:
106
- if word in _SYMBOL_TOKENS:
107
- res.append(word)
108
- elif word in ("っ", "ッ"):
109
- res.append("ッ")
110
- elif word in _NO_YOMI_TOKENS:
111
- pass
112
- else:
113
- res.append(word)
114
- sep.append(word)
115
- return sep, [hira2kata(i) for i in res], get_accent(parsed)
116
-
117
-
118
- def get_accent(parsed):
119
- labels = pyopenjtalk.make_label(parsed)
120
-
121
- phonemes = []
122
- accents = []
123
- for n, label in enumerate(labels):
124
- phoneme = re.search(r"\-([^\+]*)\+", label).group(1)
125
- if phoneme not in ["sil", "pau"]:
126
- phonemes.append(phoneme.replace("cl", "q").lower())
127
- else:
128
- continue
129
- a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
130
- a2 = int(re.search(r"\+(\d+)\+", label).group(1))
131
- if re.search(r"\-([^\+]*)\+", labels[n + 1]).group(1) in ["sil", "pau"]:
132
- a2_next = -1
133
- else:
134
- a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
135
- # Falling
136
- if a1 == 0 and a2_next == a2 + 1:
137
- accents.append(-1)
138
- # Rising
139
- elif a2 == 1 and a2_next == 2:
140
- accents.append(1)
141
- else:
142
- accents.append(0)
143
- return list(zip(phonemes, accents))
144
-
145
-
146
- _ALPHASYMBOL_YOMI = {
147
- "#": "シャープ",
148
- "%": "パーセント",
149
- "&": "アンド",
150
- "+": "プラス",
151
- "-": "マイナス",
152
- ":": "コロン",
153
- ";": "セミコロン",
154
- "<": "小なり",
155
- "=": "イコール",
156
- ">": "大なり",
157
- "@": "アット",
158
- "a": "エー",
159
- "b": "ビー",
160
- "c": "シー",
161
- "d": "ディー",
162
- "e": "イー",
163
- "f": "エフ",
164
- "g": "ジー",
165
- "h": "エイチ",
166
- "i": "アイ",
167
- "j": "ジェー",
168
- "k": "ケー",
169
- "l": "エル",
170
- "m": "エム",
171
- "n": "エヌ",
172
- "o": "オー",
173
- "p": "ピー",
174
- "q": "キュー",
175
- "r": "アール",
176
- "s": "エス",
177
- "t": "ティー",
178
- "u": "ユー",
179
- "v": "ブイ",
180
- "w": "ダブリュー",
181
- "x": "エックス",
182
- "y": "ワイ",
183
- "z": "ゼット",
184
- "α": "アルファ",
185
- "β": "ベータ",
186
- "γ": "ガンマ",
187
- "δ": "デルタ",
188
- "ε": "イプシロン",
189
- "ζ": "ゼータ",
190
- "η": "イータ",
191
- "θ": "シータ",
192
- "ι": "イオタ",
193
- "κ": "カッパ",
194
- "λ": "ラムダ",
195
- "μ": "ミュー",
196
- "ν": "ニュー",
197
- "ξ": "クサイ",
198
- "ο": "オミクロン",
199
- "π": "パイ",
200
- "ρ": "ロー",
201
- "σ": "シグマ",
202
- "τ": "タウ",
203
- "υ": "ウプシロン",
204
- "φ": "ファイ",
205
- "χ": "カイ",
206
- "ψ": "プサイ",
207
- "ω": "オメガ",
208
- }
209
-
210
-
211
- _NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+")
212
- _CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"}
213
- _CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])")
214
- _NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?")
215
-
216
-
217
- def japanese_convert_numbers_to_words(text: str) -> str:
218
- res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text)
219
- res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res)
220
- res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res)
221
- return res
222
-
223
-
224
- def japanese_convert_alpha_symbols_to_words(text: str) -> str:
225
- return "".join([_ALPHASYMBOL_YOMI.get(ch, ch) for ch in text.lower()])
226
-
227
-
228
- def japanese_text_to_phonemes(text: str) -> str:
229
- """Convert Japanese text to phonemes."""
230
- res = unicodedata.normalize("NFKC", text)
231
- res = japanese_convert_numbers_to_words(res)
232
- # res = japanese_convert_alpha_symbols_to_words(res)
233
- res = text2kata(res)
234
- res = kata2phoneme(res)
235
- return res
236
-
237
-
238
- def is_japanese_character(char):
239
- # Define the Unicode ranges for the Japanese writing systems
240
- japanese_ranges = [
241
- (0x3040, 0x309F), # Hiragana
243
- (0x30A0, 0x30FF), # Katakana
244
- (0x4E00, 0x9FFF), # Kanji (CJK Unified Ideographs)
245
- (0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
246
- (0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
246
- # More CJK extension ranges can be added here as needed
247
- ]
248
-
249
- # Convert the character to its integer Unicode code point
250
- char_code = ord(char)
251
-
252
- # Check whether the character falls within any of the Japanese ranges
253
- for start, end in japanese_ranges:
254
- if start <= char_code <= end:
255
- return True
256
-
257
- return False
258
-
259
-
260
- rep_map = {
261
- ":": ",",
262
- ";": ",",
263
- ",": ",",
264
- "。": ".",
265
- "!": "!",
266
- "?": "?",
267
- "\n": ".",
268
- ".": ".",
269
- "...": "…",
270
- "···": "…",
271
- "・・・": "…",
272
- "·": ",",
273
- "・": ",",
274
- "、": ",",
275
- "$": ".",
276
- "“": "'",
277
- "”": "'",
278
- "‘": "'",
279
- "’": "'",
280
- "(": "'",
281
- ")": "'",
282
- "(": "'",
283
- ")": "'",
284
- "《": "'",
285
- "》": "'",
286
- "【": "'",
287
- "】": "'",
288
- "[": "'",
289
- "]": "'",
290
- "—": "-",
291
- "−": "-",
292
- "~": "-",
293
- "~": "-",
294
- "「": "'",
295
- "」": "'",
296
- }
297
-
298
-
299
- def replace_punctuation(text):
300
- pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
301
-
302
- replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
303
-
304
- replaced_text = re.sub(
305
- r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
306
- + "".join(punctuation)
307
- + r"]+",
308
- "",
309
- replaced_text,
310
- )
311
-
312
- return replaced_text
313
-
314
-
315
- def text_normalize(text):
316
- res = unicodedata.normalize("NFKC", text)
317
- res = japanese_convert_numbers_to_words(res)
318
- # res = "".join([i for i in res if is_japanese_character(i)])
319
- res = replace_punctuation(res)
320
- return res
321
-
322
-
323
- def distribute_phone(n_phone, n_word):
324
- phones_per_word = [0] * n_word
325
- for task in range(n_phone):
326
- min_tasks = min(phones_per_word)
327
- min_index = phones_per_word.index(min_tasks)
328
- phones_per_word[min_index] += 1
329
- return phones_per_word
330
-
331
-
332
- def handle_long(sep_phonemes):
333
- for i in range(len(sep_phonemes)):
334
- if sep_phonemes[i][0] == "ー":
335
- sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
336
- if "ー" in sep_phonemes[i]:
337
- for j in range(len(sep_phonemes[i])):
338
- if sep_phonemes[i][j] == "ー":
339
- sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
340
- return sep_phonemes
341
-
342
-
343
- tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese")
344
-
345
-
346
- def align_tones(phones, tones):
347
- res = []
348
- for pho in phones:
349
- temp = [0] * len(pho)
350
- for idx, p in enumerate(pho):
351
- if len(tones) == 0:
352
- break
353
- if p == tones[0][0]:
354
- temp[idx] = tones[0][1]
355
- if idx > 0:
356
- temp[idx] += temp[idx - 1]
357
- tones.pop(0)
358
- temp = [0] + temp
359
- temp = temp[:-1]
360
- if -1 in temp:
361
- temp = [i + 1 for i in temp]
362
- res.append(temp)
363
- res = [i for j in res for i in j]
364
- assert not any([i < 0 for i in res]) and not any([i > 1 for i in res])
365
- return res
366
-
367
-
368
- def g2p(norm_text):
369
- sep_text, sep_kata, acc = text2sep_kata(norm_text)
370
- sep_tokenized = [tokenizer.tokenize(i) for i in sep_text]
371
- sep_phonemes = handle_long([kata2phoneme(i) for i in sep_kata])
372
- # Error handling: words MeCab does not recognize propagate all the way here and crash; so far only extremely rare characters trigger this
373
- for i in sep_phonemes:
374
- for j in i:
375
- assert j in symbols, (sep_text, sep_kata, sep_phonemes)
376
- tones = align_tones(sep_phonemes, acc)
377
-
378
- word2ph = []
379
- for token, phoneme in zip(sep_tokenized, sep_phonemes):
380
- phone_len = len(phoneme)
381
- word_len = len(token)
382
-
383
- aaa = distribute_phone(phone_len, word_len)
384
- word2ph += aaa
385
- phones = ["_"] + [j for i in sep_phonemes for j in i] + ["_"]
386
- tones = [0] + tones + [0]
387
- word2ph = [1] + word2ph + [1]
388
- assert len(phones) == len(tones)
389
- return phones, tones, word2ph
390
-
391
-
392
- if __name__ == "__main__":
393
- tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese")
394
- text = "hello,こんにちは、世界ー!……"
395
- from text.japanese_bert import get_bert_feature
396
-
397
- text = text_normalize(text)
398
- print(text)
399
-
400
- phones, tones, word2ph = g2p(text)
401
- bert = get_bert_feature(text, word2ph)
402
-
403
- print(phones, tones, word2ph, bert.shape)
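Two helpers above, distribute_phone and handle_long, are pure functions and easy to exercise in isolation. A hedged sketch (the "text.japanese" import path is an assumption, and importing the module also initializes pyopenjtalk and loads the local tokenizer from ./bert/deberta-v2-large-japanese, so those assets must exist):

# Hedged sketch of the small pure helpers from the Japanese front end above.
from text.japanese import distribute_phone, handle_long  # assumed import path

# distribute_phone spreads n_phone phones as evenly as possible over n_word tokens.
assert distribute_phone(5, 2) == [3, 2]

# handle_long replaces the long-vowel mark "ー" with the preceding phoneme.
assert handle_long([["k", "o"], ["ー"]]) == [["k", "o"], ["o"]]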
onnx_modules/V200/text/japanese_bert.py DELETED
@@ -1,58 +0,0 @@
1
- import sys
2
-
3
- import torch
4
- from transformers import AutoModelForMaskedLM, AutoTokenizer
5
-
6
- from config import config
7
- from .japanese import text2sep_kata
8
-
9
- LOCAL_PATH = "./bert/deberta-v2-large-japanese"
10
-
11
- tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
12
-
13
- models = dict()
14
-
15
-
16
- def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
17
- sep_text, _, _ = text2sep_kata(text)
18
- sep_tokens = [tokenizer.tokenize(t) for t in sep_text]
19
- sep_ids = [tokenizer.convert_tokens_to_ids(t) for t in sep_tokens]
20
- sep_ids = [2] + [item for sublist in sep_ids for item in sublist] + [3]
21
- return get_bert_feature_with_token(sep_ids, word2ph, device)
22
-
23
-
24
- def get_bert_feature_with_token(tokens, word2ph, device=config.bert_gen_config.device):
25
- if (
26
- sys.platform == "darwin"
27
- and torch.backends.mps.is_available()
28
- and device == "cpu"
29
- ):
30
- device = "mps"
31
- if not device:
32
- device = "cuda"
33
- if device not in models.keys():
34
- models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)
35
- with torch.no_grad():
36
- inputs = torch.tensor(tokens).to(device).unsqueeze(0)
37
- token_type_ids = torch.zeros_like(inputs).to(device)
38
- attention_mask = torch.ones_like(inputs).to(device)
39
- inputs = {
40
- "input_ids": inputs,
41
- "token_type_ids": token_type_ids,
42
- "attention_mask": attention_mask,
43
- }
44
-
45
- # for i in inputs:
46
- # inputs[i] = inputs[i].to(device)
47
- res = models[device](**inputs, output_hidden_states=True)
48
- res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
49
- assert inputs["input_ids"].shape[-1] == len(word2ph)
50
- word2phone = word2ph
51
- phone_level_feature = []
52
- for i in range(len(word2phone)):
53
- repeat_feature = res[i].repeat(word2phone[i], 1)
54
- phone_level_feature.append(repeat_feature)
55
-
56
- phone_level_feature = torch.cat(phone_level_feature, dim=0)
57
-
58
- return phone_level_feature.T
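A hedged sketch of the token/word2ph contract that get_bert_feature_with_token() above relies on: ids are wrapped as [2] + flattened subword ids + [3] (presumably this tokenizer's CLS/SEP ids, as hardcoded in the diff), and word2ph must hold one entry per id. Dummy ids only; no model or tokenizer is loaded here.

# Hedged sketch with made-up subword ids.
sub_token_ids = [[102, 57], [880]]                 # pretend per-word subword ids
sep_ids = [2] + [i for ids in sub_token_ids for i in ids] + [3]

word2ph = [1, 2, 3, 1, 1]                          # one count per id, pads included
assert len(sep_ids) == len(word2ph)                # mirrors the assert in the diff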
onnx_modules/V200/text/opencpop-strict.txt DELETED
@@ -1,429 +0,0 @@
1
- a AA a
2
- ai AA ai
3
- an AA an
4
- ang AA ang
5
- ao AA ao
6
- ba b a
7
- bai b ai
8
- ban b an
9
- bang b ang
10
- bao b ao
11
- bei b ei
12
- ben b en
13
- beng b eng
14
- bi b i
15
- bian b ian
16
- biao b iao
17
- bie b ie
18
- bin b in
19
- bing b ing
20
- bo b o
21
- bu b u
22
- ca c a
23
- cai c ai
24
- can c an
25
- cang c ang
26
- cao c ao
27
- ce c e
28
- cei c ei
29
- cen c en
30
- ceng c eng
31
- cha ch a
32
- chai ch ai
33
- chan ch an
34
- chang ch ang
35
- chao ch ao
36
- che ch e
37
- chen ch en
38
- cheng ch eng
39
- chi ch ir
40
- chong ch ong
41
- chou ch ou
42
- chu ch u
43
- chua ch ua
44
- chuai ch uai
45
- chuan ch uan
46
- chuang ch uang
47
- chui ch ui
48
- chun ch un
49
- chuo ch uo
50
- ci c i0
51
- cong c ong
52
- cou c ou
53
- cu c u
54
- cuan c uan
55
- cui c ui
56
- cun c un
57
- cuo c uo
58
- da d a
59
- dai d ai
60
- dan d an
61
- dang d ang
62
- dao d ao
63
- de d e
64
- dei d ei
65
- den d en
66
- deng d eng
67
- di d i
68
- dia d ia
69
- dian d ian
70
- diao d iao
71
- die d ie
72
- ding d ing
73
- diu d iu
74
- dong d ong
75
- dou d ou
76
- du d u
77
- duan d uan
78
- dui d ui
79
- dun d un
80
- duo d uo
81
- e EE e
82
- ei EE ei
83
- en EE en
84
- eng EE eng
85
- er EE er
86
- fa f a
87
- fan f an
88
- fang f ang
89
- fei f ei
90
- fen f en
91
- feng f eng
92
- fo f o
93
- fou f ou
94
- fu f u
95
- ga g a
96
- gai g ai
97
- gan g an
98
- gang g ang
99
- gao g ao
100
- ge g e
101
- gei g ei
102
- gen g en
103
- geng g eng
104
- gong g ong
105
- gou g ou
106
- gu g u
107
- gua g ua
108
- guai g uai
109
- guan g uan
110
- guang g uang
111
- gui g ui
112
- gun g un
113
- guo g uo
114
- ha h a
115
- hai h ai
116
- han h an
117
- hang h ang
118
- hao h ao
119
- he h e
120
- hei h ei
121
- hen h en
122
- heng h eng
123
- hong h ong
124
- hou h ou
125
- hu h u
126
- hua h ua
127
- huai h uai
128
- huan h uan
129
- huang h uang
130
- hui h ui
131
- hun h un
132
- huo h uo
133
- ji j i
134
- jia j ia
135
- jian j ian
136
- jiang j iang
137
- jiao j iao
138
- jie j ie
139
- jin j in
140
- jing j ing
141
- jiong j iong
142
- jiu j iu
143
- ju j v
144
- jv j v
145
- juan j van
146
- jvan j van
147
- jue j ve
148
- jve j ve
149
- jun j vn
150
- jvn j vn
151
- ka k a
152
- kai k ai
153
- kan k an
154
- kang k ang
155
- kao k ao
156
- ke k e
157
- kei k ei
158
- ken k en
159
- keng k eng
160
- kong k ong
161
- kou k ou
162
- ku k u
163
- kua k ua
164
- kuai k uai
165
- kuan k uan
166
- kuang k uang
167
- kui k ui
168
- kun k un
169
- kuo k uo
170
- la l a
171
- lai l ai
172
- lan l an
173
- lang l ang
174
- lao l ao
175
- le l e
176
- lei l ei
177
- leng l eng
178
- li l i
179
- lia l ia
180
- lian l ian
181
- liang l iang
182
- liao l iao
183
- lie l ie
184
- lin l in
185
- ling l ing
186
- liu l iu
187
- lo l o
188
- long l ong
189
- lou l ou
190
- lu l u
191
- luan l uan
192
- lun l un
193
- luo l uo
194
- lv l v
195
- lve l ve
196
- ma m a
197
- mai m ai
198
- man m an
199
- mang m ang
200
- mao m ao
201
- me m e
202
- mei m ei
203
- men m en
204
- meng m eng
205
- mi m i
206
- mian m ian
207
- miao m iao
208
- mie m ie
209
- min m in
210
- ming m ing
211
- miu m iu
212
- mo m o
213
- mou m ou
214
- mu m u
215
- na n a
216
- nai n ai
217
- nan n an
218
- nang n ang
219
- nao n ao
220
- ne n e
221
- nei n ei
222
- nen n en
223
- neng n eng
224
- ni n i
225
- nian n ian
226
- niang n iang
227
- niao n iao
228
- nie n ie
229
- nin n in
230
- ning n ing
231
- niu n iu
232
- nong n ong
233
- nou n ou
234
- nu n u
235
- nuan n uan
236
- nun n un
237
- nuo n uo
238
- nv n v
239
- nve n ve
240
- o OO o
241
- ou OO ou
242
- pa p a
243
- pai p ai
244
- pan p an
245
- pang p ang
246
- pao p ao
247
- pei p ei
248
- pen p en
249
- peng p eng
250
- pi p i
251
- pian p ian
252
- piao p iao
253
- pie p ie
254
- pin p in
255
- ping p ing
256
- po p o
257
- pou p ou
258
- pu p u
259
- qi q i
260
- qia q ia
261
- qian q ian
262
- qiang q iang
263
- qiao q iao
264
- qie q ie
265
- qin q in
266
- qing q ing
267
- qiong q iong
268
- qiu q iu
269
- qu q v
270
- qv q v
271
- quan q van
272
- qvan q van
273
- que q ve
274
- qve q ve
275
- qun q vn
276
- qvn q vn
277
- ran r an
278
- rang r ang
279
- rao r ao
280
- re r e
281
- ren r en
282
- reng r eng
283
- ri r ir
284
- rong r ong
285
- rou r ou
286
- ru r u
287
- rua r ua
288
- ruan r uan
289
- rui r ui
290
- run r un
291
- ruo r uo
292
- sa s a
293
- sai s ai
294
- san s an
295
- sang s ang
296
- sao s ao
297
- se s e
298
- sen s en
299
- seng s eng
300
- sha sh a
301
- shai sh ai
302
- shan sh an
303
- shang sh ang
304
- shao sh ao
305
- she sh e
306
- shei sh ei
307
- shen sh en
308
- sheng sh eng
309
- shi sh ir
310
- shou sh ou
311
- shu sh u
312
- shua sh ua
313
- shuai sh uai
314
- shuan sh uan
315
- shuang sh uang
316
- shui sh ui
317
- shun sh un
318
- shuo sh uo
319
- si s i0
320
- song s ong
321
- sou s ou
322
- su s u
323
- suan s uan
324
- sui s ui
325
- sun s un
326
- suo s uo
327
- ta t a
328
- tai t ai
329
- tan t an
330
- tang t ang
331
- tao t ao
332
- te t e
333
- tei t ei
334
- teng t eng
335
- ti t i
336
- tian t ian
337
- tiao t iao
338
- tie t ie
339
- ting t ing
340
- tong t ong
341
- tou t ou
342
- tu t u
343
- tuan t uan
344
- tui t ui
345
- tun t un
346
- tuo t uo
347
- wa w a
348
- wai w ai
349
- wan w an
350
- wang w ang
351
- wei w ei
352
- wen w en
353
- weng w eng
354
- wo w o
355
- wu w u
356
- xi x i
357
- xia x ia
358
- xian x ian
359
- xiang x iang
360
- xiao x iao
361
- xie x ie
362
- xin x in
363
- xing x ing
364
- xiong x iong
365
- xiu x iu
366
- xu x v
367
- xv x v
368
- xuan x van
369
- xvan x van
370
- xue x ve
371
- xve x ve
372
- xun x vn
373
- xvn x vn
374
- ya y a
375
- yan y En
376
- yang y ang
377
- yao y ao
378
- ye y E
379
- yi y i
380
- yin y in
381
- ying y ing
382
- yo y o
383
- yong y ong
384
- you y ou
385
- yu y v
386
- yv y v
387
- yuan y van
388
- yvan y van
389
- yue y ve
390
- yve y ve
391
- yun y vn
392
- yvn y vn
393
- za z a
394
- zai z ai
395
- zan z an
396
- zang z ang
397
- zao z ao
398
- ze z e
399
- zei z ei
400
- zen z en
401
- zeng z eng
402
- zha zh a
403
- zhai zh ai
404
- zhan zh an
405
- zhang zh ang
406
- zhao zh ao
407
- zhe zh e
408
- zhei zh ei
409
- zhen zh en
410
- zheng zh eng
411
- zhi zh ir
412
- zhong zh ong
413
- zhou zh ou
414
- zhu zh u
415
- zhua zh ua
416
- zhuai zh uai
417
- zhuan zh uan
418
- zhuang zh uang
419
- zhui zh ui
420
- zhun zh un
421
- zhuo zh uo
422
- zi z i0
423
- zong z ong
424
- zou z ou
425
- zu z u
426
- zuan z uan
427
- zui z ui
428
- zun z un
429
- zuo z uo
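The table above maps each pinyin syllable to an initial/final pair (single-vowel syllables use the AA/EE/OO placeholder initials seen at the top). A hedged loader sketch, assuming the file sits next to the script; split() is used so it works whether the columns are tab- or space-separated:

# Hedged sketch: parse opencpop-strict.txt into {syllable: [initial, final]}.
pinyin_to_symbol = {}
with open("opencpop-strict.txt", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split()
        if len(parts) == 3:
            pinyin_to_symbol[parts[0]] = parts[1:]

assert pinyin_to_symbol.get("zhuang") == ["zh", "uang"]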
onnx_modules/V200/text/symbols.py DELETED
@@ -1,187 +0,0 @@
1
- punctuation = ["!", "?", "…", ",", ".", "'", "-"]
2
- pu_symbols = punctuation + ["SP", "UNK"]
3
- pad = "_"
4
-
5
- # chinese
6
- zh_symbols = [
7
- "E",
8
- "En",
9
- "a",
10
- "ai",
11
- "an",
12
- "ang",
13
- "ao",
14
- "b",
15
- "c",
16
- "ch",
17
- "d",
18
- "e",
19
- "ei",
20
- "en",
21
- "eng",
22
- "er",
23
- "f",
24
- "g",
25
- "h",
26
- "i",
27
- "i0",
28
- "ia",
29
- "ian",
30
- "iang",
31
- "iao",
32
- "ie",
33
- "in",
34
- "ing",
35
- "iong",
36
- "ir",
37
- "iu",
38
- "j",
39
- "k",
40
- "l",
41
- "m",
42
- "n",
43
- "o",
44
- "ong",
45
- "ou",
46
- "p",
47
- "q",
48
- "r",
49
- "s",
50
- "sh",
51
- "t",
52
- "u",
53
- "ua",
54
- "uai",
55
- "uan",
56
- "uang",
57
- "ui",
58
- "un",
59
- "uo",
60
- "v",
61
- "van",
62
- "ve",
63
- "vn",
64
- "w",
65
- "x",
66
- "y",
67
- "z",
68
- "zh",
69
- "AA",
70
- "EE",
71
- "OO",
72
- ]
73
- num_zh_tones = 6
74
-
75
- # japanese
76
- ja_symbols = [
77
- "N",
78
- "a",
79
- "a:",
80
- "b",
81
- "by",
82
- "ch",
83
- "d",
84
- "dy",
85
- "e",
86
- "e:",
87
- "f",
88
- "g",
89
- "gy",
90
- "h",
91
- "hy",
92
- "i",
93
- "i:",
94
- "j",
95
- "k",
96
- "ky",
97
- "m",
98
- "my",
99
- "n",
100
- "ny",
101
- "o",
102
- "o:",
103
- "p",
104
- "py",
105
- "q",
106
- "r",
107
- "ry",
108
- "s",
109
- "sh",
110
- "t",
111
- "ts",
112
- "ty",
113
- "u",
114
- "u:",
115
- "w",
116
- "y",
117
- "z",
118
- "zy",
119
- ]
120
- num_ja_tones = 2
121
-
122
- # English
123
- en_symbols = [
124
- "aa",
125
- "ae",
126
- "ah",
127
- "ao",
128
- "aw",
129
- "ay",
130
- "b",
131
- "ch",
132
- "d",
133
- "dh",
134
- "eh",
135
- "er",
136
- "ey",
137
- "f",
138
- "g",
139
- "hh",
140
- "ih",
141
- "iy",
142
- "jh",
143
- "k",
144
- "l",
145
- "m",
146
- "n",
147
- "ng",
148
- "ow",
149
- "oy",
150
- "p",
151
- "r",
152
- "s",
153
- "sh",
154
- "t",
155
- "th",
156
- "uh",
157
- "uw",
158
- "V",
159
- "w",
160
- "y",
161
- "z",
162
- "zh",
163
- ]
164
- num_en_tones = 4
165
-
166
- # combine all symbols
167
- normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
168
- symbols = [pad] + normal_symbols + pu_symbols
169
- sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
170
-
171
- # combine all tones
172
- num_tones = num_zh_tones + num_ja_tones + num_en_tones
173
-
174
- # language maps
175
- language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
176
- num_languages = len(language_id_map.keys())
177
-
178
- language_tone_start_map = {
179
- "ZH": 0,
180
- "JP": num_zh_tones,
181
- "EN": num_zh_tones + num_ja_tones,
182
- }
183
-
184
- if __name__ == "__main__":
185
- a = set(zh_symbols)
186
- b = set(en_symbols)
187
- print(sorted(a & b))
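A hedged sketch of how the tables above are typically consumed: build a symbol-to-id lookup and sanity-check the combined tone count (6 Chinese + 2 Japanese + 4 English). The "text.symbols" import path is an assumption; the asserted values follow directly from the file.

from text.symbols import symbols, num_tones, language_tone_start_map, language_id_map

symbol_to_id = {s: i for i, s in enumerate(symbols)}

assert symbols[0] == "_"                   # pad symbol sits at index 0
assert num_tones == 12                     # 6 + 2 + 4
assert language_tone_start_map["EN"] == 8  # EN tones start after ZH (6) + JP (2)
assert len(language_id_map) == 3           # ZH / JP / EN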
onnx_modules/V200/text/tone_sandhi.py DELETED
@@ -1,769 +0,0 @@
1
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import List
15
- from typing import Tuple
16
-
17
- import jieba
18
- from pypinyin import lazy_pinyin
19
- from pypinyin import Style
20
-
21
-
22
- class ToneSandhi:
23
- def __init__(self):
24
- self.must_neural_tone_words = {
25
- "麻烦",
26
- "麻利",
27
- "鸳鸯",
28
- "高粱",
29
- "骨头",
30
- "骆驼",
31
- "马虎",
32
- "首饰",
33
- "馒头",
34
- "馄饨",
35
- "风筝",
36
- "难为",
37
- "队伍",
38
- "阔气",
39
- "闺女",
40
- "门道",
41
- "锄头",
42
- "铺盖",
43
- "铃铛",
44
- "铁匠",
45
- "钥匙",
46
- "里脊",
47
- "里头",
48
- "部分",
49
- "那么",
50
- "道士",
51
- "造化",
52
- "迷糊",
53
- "连累",
54
- "这么",
55
- "这个",
56
- "运气",
57
- "过去",
58
- "软和",
59
- "转悠",
60
- "踏实",
61
- "跳蚤",
62
- "跟头",
63
- "趔趄",
64
- "财主",
65
- "豆腐",
66
- "讲究",
67
- "记性",
68
- "记号",
69
- "认识",
70
- "规矩",
71
- "见识",
72
- "裁缝",
73
- "补丁",
74
- "衣裳",
75
- "衣服",
76
- "衙门",
77
- "街坊",
78
- "行李",
79
- "行当",
80
- "蛤蟆",
81
- "蘑菇",
82
- "薄荷",
83
- "葫芦",
84
- "葡萄",
85
- "萝卜",
86
- "荸荠",
87
- "苗条",
88
- "苗头",
89
- "苍蝇",
90
- "芝麻",
91
- "舒服",
92
- "舒坦",
93
- "舌头",
94
- "自在",
95
- "膏药",
96
- "脾气",
97
- "脑袋",
98
- "脊梁",
99
- "能耐",
100
- "胳膊",
101
- "胭脂",
102
- "胡萝",
103
- "胡琴",
104
- "胡同",
105
- "聪明",
106
- "耽误",
107
- "耽搁",
108
- "耷拉",
109
- "耳朵",
110
- "老爷",
111
- "老实",
112
- "老婆",
113
- "老头",
114
- "老太",
115
- "翻腾",
116
- "罗嗦",
117
- "罐头",
118
- "编辑",
119
- "结实",
120
- "红火",
121
- "累赘",
122
- "糨糊",
123
- "糊涂",
124
- "精神",
125
- "粮食",
126
- "簸箕",
127
- "篱笆",
128
- "算计",
129
- "算盘",
130
- "答应",
131
- "笤帚",
132
- "笑语",
133
- "笑话",
134
- "窟窿",
135
- "窝囊",
136
- "窗户",
137
- "稳当",
138
- "稀罕",
139
- "称呼",
140
- "秧歌",
141
- "秀气",
142
- "秀才",
143
- "福气",
144
- "祖宗",
145
- "砚台",
146
- "码头",
147
- "石榴",
148
- "石头",
149
- "石匠",
150
- "知识",
151
- "眼睛",
152
- "眯缝",
153
- "眨巴",
154
- "眉毛",
155
- "相声",
156
- "盘算",
157
- "白净",
158
- "痢疾",
159
- "痛快",
160
- "疟疾",
161
- "疙瘩",
162
- "疏忽",
163
- "畜生",
164
- "生意",
165
- "甘蔗",
166
- "琵琶",
167
- "琢磨",
168
- "琉璃",
169
- "玻璃",
170
- "玫瑰",
171
- "玄乎",
172
- "狐狸",
173
- "状元",
174
- "特务",
175
- "牲口",
176
- "牙碜",
177
- "牌楼",
178
- "爽快",
179
- "爱人",
180
- "热闹",
181
- "烧饼",
182
- "烟筒",
183
- "烂糊",
184
- "点心",
185
- "炊帚",
186
- "灯笼",
187
- "火候",
188
- "漂亮",
189
- "滑溜",
190
- "溜达",
191
- "温和",
192
- "清楚",
193
- "消息",
194
- "浪头",
195
- "活泼",
196
- "比方",
197
- "正经",
198
- "欺负",
199
- "模糊",
200
- "槟榔",
201
- "棺材",
202
- "棒槌",
203
- "棉花",
204
- "核桃",
205
- "栅栏",
206
- "柴火",
207
- "架势",
208
- "枕头",
209
- "枇杷",
210
- "机灵",
211
- "本事",
212
- "木头",
213
- "木匠",
214
- "朋友",
215
- "月饼",
216
- "月亮",
217
- "暖和",
218
- "明白",
219
- "时候",
220
- "新鲜",
221
- "故事",
222
- "收拾",
223
- "收成",
224
- "提防",
225
- "挖苦",
226
- "挑剔",
227
- "指甲",
228
- "指头",
229
- "拾掇",
230
- "拳头",
231
- "拨弄",
232
- "招牌",
233
- "招呼",
234
- "抬举",
235
- "护士",
236
- "折腾",
237
- "扫帚",
238
- "打量",
239
- "打算",
240
- "打点",
241
- "打扮",
242
- "打听",
243
- "打发",
244
- "扎实",
245
- "扁担",
246
- "戒指",
247
- "懒得",
248
- "意识",
249
- "意思",
250
- "情形",
251
- "悟性",
252
- "怪物",
253
- "思量",
254
- "怎么",
255
- "念头",
256
- "念叨",
257
- "快活",
258
- "忙活",
259
- "志气",
260
- "心思",
261
- "得罪",
262
- "张罗",
263
- "弟兄",
264
- "开通",
265
- "应酬",
266
- "庄稼",
267
- "干事",
268
- "帮手",
269
- "帐篷",
270
- "希罕",
271
- "师父",
272
- "师傅",
273
- "巴结",
274
- "巴掌",
275
- "差事",
276
- "工夫",
277
- "岁数",
278
- "屁股",
279
- "尾巴",
280
- "少爷",
281
- "小气",
282
- "小伙",
283
- "将就",
284
- "对头",
285
- "对付",
286
- "寡妇",
287
- "家伙",
288
- "客气",
289
- "实在",
290
- "官司",
291
- "学问",
292
- "学生",
293
- "字号",
294
- "嫁妆",
295
- "媳妇",
296
- "媒人",
297
- "婆家",
298
- "娘家",
299
- "委屈",
300
- "姑娘",
301
- "姐夫",
302
- "妯娌",
303
- "妥当",
304
- "妖精",
305
- "奴才",
306
- "女婿",
307
- "头发",
308
- "太阳",
309
- "大爷",
310
- "大方",
311
- "大意",
312
- "大夫",
313
- "多少",
314
- "多么",
315
- "外甥",
316
- "壮实",
317
- "地道",
318
- "地方",
319
- "在乎",
320
- "困难",
321
- "嘴巴",
322
- "嘱咐",
323
- "嘟囔",
324
- "嘀咕",
325
- "喜欢",
326
- "喇嘛",
327
- "喇叭",
328
- "商量",
329
- "唾沫",
330
- "哑巴",
331
- "哈欠",
332
- "哆嗦",
333
- "咳嗽",
334
- "和尚",
335
- "告诉",
336
- "告示",
337
- "含糊",
338
- "吓唬",
339
- "后头",
340
- "名字",
341
- "名堂",
342
- "合同",
343
- "吆喝",
344
- "叫唤",
345
- "口袋",
346
- "厚道",
347
- "厉害",
348
- "千斤",
349
- "包袱",
350
- "包涵",
351
- "匀称",
352
- "勤快",
353
- "动静",
354
- "动弹",
355
- "功夫",
356
- "力气",
357
- "前头",
358
- "刺猬",
359
- "刺激",
360
- "别扭",
361
- "利落",
362
- "利索",
363
- "利害",
364
- "分析",
365
- "出息",
366
- "凑合",
367
- "凉快",
368
- "冷战",
369
- "冤枉",
370
- "冒失",
371
- "养活",
372
- "关系",
373
- "先生",
374
- "兄弟",
375
- "便宜",
376
- "使唤",
377
- "佩服",
378
- "作坊",
379
- "体面",
380
- "位置",
381
- "似的",
382
- "伙计",
383
- "休息",
384
- "什么",
385
- "人家",
386
- "亲戚",
387
- "亲家",
388
- "交情",
389
- "云彩",
390
- "事情",
391
- "买卖",
392
- "主意",
393
- "丫头",
394
- "丧气",
395
- "两口",
396
- "东西",
397
- "东家",
398
- "世故",
399
- "不由",
400
- "不在",
401
- "下水",
402
- "下巴",
403
- "上头",
404
- "上司",
405
- "丈夫",
406
- "丈人",
407
- "一辈",
408
- "那个",
409
- "菩萨",
410
- "父亲",
411
- "母亲",
412
- "咕噜",
413
- "邋遢",
414
- "费用",
415
- "冤家",
416
- "甜头",
417
- "介绍",
418
- "荒唐",
419
- "大人",
420
- "泥鳅",
421
- "幸福",
422
- "熟悉",
423
- "计划",
424
- "扑腾",
425
- "蜡烛",
426
- "姥爷",
427
- "照顾",
428
- "喉咙",
429
- "吉他",
430
- "弄堂",
431
- "蚂蚱",
432
- "凤凰",
433
- "拖沓",
434
- "寒碜",
435
- "糟蹋",
436
- "倒腾",
437
- "报复",
438
- "逻辑",
439
- "盘缠",
440
- "喽啰",
441
- "牢骚",
442
- "咖喱",
443
- "扫把",
444
- "惦记",
445
- }
446
- self.must_not_neural_tone_words = {
447
- "男子",
448
- "女子",
449
- "分子",
450
- "原子",
451
- "量子",
452
- "莲子",
453
- "石子",
454
- "瓜子",
455
- "电子",
456
- "人人",
457
- "虎虎",
458
- }
459
- self.punc = ":,;。?!“”‘’':,;.?!"
460
-
461
- # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
462
- # e.g.
463
- # word: "家里"
464
- # pos: "s"
465
- # finals: ['ia1', 'i3']
466
- def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
467
- # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
468
- for j, item in enumerate(word):
469
- if (
470
- j - 1 >= 0
471
- and item == word[j - 1]
472
- and pos[0] in {"n", "v", "a"}
473
- and word not in self.must_not_neural_tone_words
474
- ):
475
- finals[j] = finals[j][:-1] + "5"
476
- ge_idx = word.find("个")
477
- if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
478
- finals[-1] = finals[-1][:-1] + "5"
479
- elif len(word) >= 1 and word[-1] in "的地得":
480
- finals[-1] = finals[-1][:-1] + "5"
481
- # e.g. 走了, 看着, 去过
482
- # elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
483
- # finals[-1] = finals[-1][:-1] + "5"
484
- elif (
485
- len(word) > 1
486
- and word[-1] in "们子"
487
- and pos in {"r", "n"}
488
- and word not in self.must_not_neural_tone_words
489
- ):
490
- finals[-1] = finals[-1][:-1] + "5"
491
- # e.g. 桌上, 地下, 家里
492
- elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
493
- finals[-1] = finals[-1][:-1] + "5"
494
- # e.g. 上来, 下去
495
- elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
496
- finals[-1] = finals[-1][:-1] + "5"
497
- # 个做量词
498
- elif (
499
- ge_idx >= 1
500
- and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
501
- ) or word == "个":
502
- finals[ge_idx] = finals[ge_idx][:-1] + "5"
503
- else:
504
- if (
505
- word in self.must_neural_tone_words
506
- or word[-2:] in self.must_neural_tone_words
507
- ):
508
- finals[-1] = finals[-1][:-1] + "5"
509
-
510
- word_list = self._split_word(word)
511
- finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
512
- for i, word in enumerate(word_list):
513
- # conventional neural in Chinese
514
- if (
515
- word in self.must_neural_tone_words
516
- or word[-2:] in self.must_neural_tone_words
517
- ):
518
- finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
519
- finals = sum(finals_list, [])
520
- return finals
521
-
522
- def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
523
- # e.g. 看不懂
524
- if len(word) == 3 and word[1] == "不":
525
- finals[1] = finals[1][:-1] + "5"
526
- else:
527
- for i, char in enumerate(word):
528
- # "不" before tone4 should be bu2, e.g. 不怕
529
- if char == "不" and i + 1 < len(word) and finals[i + 1][-1] == "4":
530
- finals[i] = finals[i][:-1] + "2"
531
- return finals
532
-
533
- def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
534
- # "一" in number sequences, e.g. 一零零, 二一零
535
- if word.find("一") != -1 and all(
536
- [item.isnumeric() for item in word if item != "一"]
537
- ):
538
- return finals
539
- # "一" between reduplication words should be yi5, e.g. 看一看
540
- elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
541
- finals[1] = finals[1][:-1] + "5"
542
- # when "一" is ordinal word, it should be yi1
543
- elif word.startswith("第一"):
544
- finals[1] = finals[1][:-1] + "1"
545
- else:
546
- for i, char in enumerate(word):
547
- if char == "一" and i + 1 < len(word):
548
- # "一" before tone4 should be yi2, e.g. 一段
549
- if finals[i + 1][-1] == "4":
550
- finals[i] = finals[i][:-1] + "2"
551
- # "一" before non-tone4 should be yi4, e.g. 一天
552
- else:
553
- # "一" 后面如果是标点,还读一声
554
- if word[i + 1] not in self.punc:
555
- finals[i] = finals[i][:-1] + "4"
556
- return finals
557
-
558
- def _split_word(self, word: str) -> List[str]:
559
- word_list = jieba.cut_for_search(word)
560
- word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
561
- first_subword = word_list[0]
562
- first_begin_idx = word.find(first_subword)
563
- if first_begin_idx == 0:
564
- second_subword = word[len(first_subword) :]
565
- new_word_list = [first_subword, second_subword]
566
- else:
567
- second_subword = word[: -len(first_subword)]
568
- new_word_list = [second_subword, first_subword]
569
- return new_word_list
570
-
571
- def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
572
- if len(word) == 2 and self._all_tone_three(finals):
573
- finals[0] = finals[0][:-1] + "2"
574
- elif len(word) == 3:
575
- word_list = self._split_word(word)
576
- if self._all_tone_three(finals):
577
- # disyllabic + monosyllabic, e.g. 蒙古/包
578
- if len(word_list[0]) == 2:
579
- finals[0] = finals[0][:-1] + "2"
580
- finals[1] = finals[1][:-1] + "2"
581
- # monosyllabic + disyllabic, e.g. 纸/老虎
582
- elif len(word_list[0]) == 1:
583
- finals[1] = finals[1][:-1] + "2"
584
- else:
585
- finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
586
- if len(finals_list) == 2:
587
- for i, sub in enumerate(finals_list):
588
- # e.g. 所有/人
589
- if self._all_tone_three(sub) and len(sub) == 2:
590
- finals_list[i][0] = finals_list[i][0][:-1] + "2"
591
- # e.g. 好/喜欢
592
- elif (
593
- i == 1
594
- and not self._all_tone_three(sub)
595
- and finals_list[i][0][-1] == "3"
596
- and finals_list[0][-1][-1] == "3"
597
- ):
598
- finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
599
- finals = sum(finals_list, [])
600
- # split idiom into two words who's length is 2
601
- elif len(word) == 4:
602
- finals_list = [finals[:2], finals[2:]]
603
- finals = []
604
- for sub in finals_list:
605
- if self._all_tone_three(sub):
606
- sub[0] = sub[0][:-1] + "2"
607
- finals += sub
608
-
609
- return finals
610
-
611
- def _all_tone_three(self, finals: List[str]) -> bool:
612
- return all(x[-1] == "3" for x in finals)
613
-
614
- # merge "不" and the word behind it
615
- # if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error
616
- def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
617
- new_seg = []
618
- last_word = ""
619
- for word, pos in seg:
620
- if last_word == "不":
621
- word = last_word + word
622
- if word != "不":
623
- new_seg.append((word, pos))
624
- last_word = word[:]
625
- if last_word == "不":
626
- new_seg.append((last_word, "d"))
627
- last_word = ""
628
- return new_seg
629
-
630
- # function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听"
631
- # function 2: merge single "一" and the word behind it
632
- # if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error
633
- # e.g.
634
- # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
635
- # output seg: [['听一听', 'v']]
636
- def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
637
- new_seg = []
638
- # function 1
639
- for i, (word, pos) in enumerate(seg):
640
- if (
641
- i - 1 >= 0
642
- and word == "一"
643
- and i + 1 < len(seg)
644
- and seg[i - 1][0] == seg[i + 1][0]
645
- and seg[i - 1][1] == "v"
646
- ):
647
- new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
648
- else:
649
- if (
650
- i - 2 >= 0
651
- and seg[i - 1][0] == "一"
652
- and seg[i - 2][0] == word
653
- and pos == "v"
654
- ):
655
- continue
656
- else:
657
- new_seg.append([word, pos])
658
- seg = new_seg
659
- new_seg = []
660
- # function 2
661
- for i, (word, pos) in enumerate(seg):
662
- if new_seg and new_seg[-1][0] == "一":
663
- new_seg[-1][0] = new_seg[-1][0] + word
664
- else:
665
- new_seg.append([word, pos])
666
- return new_seg
667
-
668
- # the first and the second words are all_tone_three
669
- def _merge_continuous_three_tones(
670
- self, seg: List[Tuple[str, str]]
671
- ) -> List[Tuple[str, str]]:
672
- new_seg = []
673
- sub_finals_list = [
674
- lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
675
- for (word, pos) in seg
676
- ]
677
- assert len(sub_finals_list) == len(seg)
678
- merge_last = [False] * len(seg)
679
- for i, (word, pos) in enumerate(seg):
680
- if (
681
- i - 1 >= 0
682
- and self._all_tone_three(sub_finals_list[i - 1])
683
- and self._all_tone_three(sub_finals_list[i])
684
- and not merge_last[i - 1]
685
- ):
686
- # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
687
- if (
688
- not self._is_reduplication(seg[i - 1][0])
689
- and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
690
- ):
691
- new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
692
- merge_last[i] = True
693
- else:
694
- new_seg.append([word, pos])
695
- else:
696
- new_seg.append([word, pos])
697
-
698
- return new_seg
699
-
700
- def _is_reduplication(self, word: str) -> bool:
701
- return len(word) == 2 and word[0] == word[1]
702
-
703
- # the last char of first word and the first char of second word is tone_three
704
- def _merge_continuous_three_tones_2(
705
- self, seg: List[Tuple[str, str]]
706
- ) -> List[Tuple[str, str]]:
707
- new_seg = []
708
- sub_finals_list = [
709
- lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
710
- for (word, pos) in seg
711
- ]
712
- assert len(sub_finals_list) == len(seg)
713
- merge_last = [False] * len(seg)
714
- for i, (word, pos) in enumerate(seg):
715
- if (
716
- i - 1 >= 0
717
- and sub_finals_list[i - 1][-1][-1] == "3"
718
- and sub_finals_list[i][0][-1] == "3"
719
- and not merge_last[i - 1]
720
- ):
721
- # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
722
- if (
723
- not self._is_reduplication(seg[i - 1][0])
724
- and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
725
- ):
726
- new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
727
- merge_last[i] = True
728
- else:
729
- new_seg.append([word, pos])
730
- else:
731
- new_seg.append([word, pos])
732
- return new_seg
733
-
734
- def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
735
- new_seg = []
736
- for i, (word, pos) in enumerate(seg):
737
- if i - 1 >= 0 and word == "儿" and seg[i - 1][0] != "#":
738
- new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
739
- else:
740
- new_seg.append([word, pos])
741
- return new_seg
742
-
743
- def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
744
- new_seg = []
745
- for i, (word, pos) in enumerate(seg):
746
- if new_seg and word == new_seg[-1][0]:
747
- new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
748
- else:
749
- new_seg.append([word, pos])
750
- return new_seg
751
-
752
- def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
753
- seg = self._merge_bu(seg)
754
- try:
755
- seg = self._merge_yi(seg)
756
- except:
757
- print("_merge_yi failed")
758
- seg = self._merge_reduplication(seg)
759
- seg = self._merge_continuous_three_tones(seg)
760
- seg = self._merge_continuous_three_tones_2(seg)
761
- seg = self._merge_er(seg)
762
- return seg
763
-
764
- def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
765
- finals = self._bu_sandhi(word, finals)
766
- finals = self._yi_sandhi(word, finals)
767
- finals = self._neural_sandhi(word, pos, finals)
768
- finals = self._three_sandhi(word, finals)
769
- return finals
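A hedged sketch of how ToneSandhi above is usually driven: segment with jieba's POS tagger, pre-merge the segments, then rewrite each word's pinyin finals. The lazy_pinyin call mirrors the one used inside _merge_continuous_three_tones; the import path and the surrounding glue are assumptions, not code from the deleted files.

import jieba.posseg as psg
from pypinyin import lazy_pinyin, Style

from text.tone_sandhi import ToneSandhi  # assumed import path

sandhi = ToneSandhi()
seg = [(w, p) for w, p in psg.lcut("你好,我想买一个苹果。")]
seg = sandhi.pre_merge_for_modify(seg)

for word, pos in seg:
    finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
    finals = sandhi.modified_tone(word, pos, finals)
    print(word, finals)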
onnx_modules/V210/__init__.py DELETED
File without changes
onnx_modules/V210/attentions_onnx.py DELETED
@@ -1,378 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn
4
- from torch.nn import functional as F
5
-
6
- import commons
7
- import logging
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
-
12
- class LayerNorm(nn.Module):
13
- def __init__(self, channels, eps=1e-5):
14
- super().__init__()
15
- self.channels = channels
16
- self.eps = eps
17
-
18
- self.gamma = nn.Parameter(torch.ones(channels))
19
- self.beta = nn.Parameter(torch.zeros(channels))
20
-
21
- def forward(self, x):
22
- x = x.transpose(1, -1)
23
- x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
24
- return x.transpose(1, -1)
25
-
26
-
27
- @torch.jit.script
28
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
29
- n_channels_int = n_channels[0]
30
- in_act = input_a + input_b
31
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
32
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
33
- acts = t_act * s_act
34
- return acts
35
-
36
-
37
- class Encoder(nn.Module):
38
- def __init__(
39
- self,
40
- hidden_channels,
41
- filter_channels,
42
- n_heads,
43
- n_layers,
44
- kernel_size=1,
45
- p_dropout=0.0,
46
- window_size=4,
47
- isflow=True,
48
- **kwargs
49
- ):
50
- super().__init__()
51
- self.hidden_channels = hidden_channels
52
- self.filter_channels = filter_channels
53
- self.n_heads = n_heads
54
- self.n_layers = n_layers
55
- self.kernel_size = kernel_size
56
- self.p_dropout = p_dropout
57
- self.window_size = window_size
58
- # if isflow:
59
- # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
60
- # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
61
- # self.cond_layer = weight_norm(cond_layer, name='weight')
62
- # self.gin_channels = 256
63
- self.cond_layer_idx = self.n_layers
64
- if "gin_channels" in kwargs:
65
- self.gin_channels = kwargs["gin_channels"]
66
- if self.gin_channels != 0:
67
- self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
68
- # vits2 says 3rd block, so idx is 2 by default
69
- self.cond_layer_idx = (
70
- kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
71
- )
72
- logging.debug(self.gin_channels, self.cond_layer_idx)
73
- assert (
74
- self.cond_layer_idx < self.n_layers
75
- ), "cond_layer_idx should be less than n_layers"
76
- self.drop = nn.Dropout(p_dropout)
77
- self.attn_layers = nn.ModuleList()
78
- self.norm_layers_1 = nn.ModuleList()
79
- self.ffn_layers = nn.ModuleList()
80
- self.norm_layers_2 = nn.ModuleList()
81
- for i in range(self.n_layers):
82
- self.attn_layers.append(
83
- MultiHeadAttention(
84
- hidden_channels,
85
- hidden_channels,
86
- n_heads,
87
- p_dropout=p_dropout,
88
- window_size=window_size,
89
- )
90
- )
91
- self.norm_layers_1.append(LayerNorm(hidden_channels))
92
- self.ffn_layers.append(
93
- FFN(
94
- hidden_channels,
95
- hidden_channels,
96
- filter_channels,
97
- kernel_size,
98
- p_dropout=p_dropout,
99
- )
100
- )
101
- self.norm_layers_2.append(LayerNorm(hidden_channels))
102
-
103
- def forward(self, x, x_mask, g=None):
104
- attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
105
- x = x * x_mask
106
- for i in range(self.n_layers):
107
- if i == self.cond_layer_idx and g is not None:
108
- g = self.spk_emb_linear(g.transpose(1, 2))
109
- g = g.transpose(1, 2)
110
- x = x + g
111
- x = x * x_mask
112
- y = self.attn_layers[i](x, x, attn_mask)
113
- y = self.drop(y)
114
- x = self.norm_layers_1[i](x + y)
115
-
116
- y = self.ffn_layers[i](x, x_mask)
117
- y = self.drop(y)
118
- x = self.norm_layers_2[i](x + y)
119
- x = x * x_mask
120
- return x
121
-
122
-
123
- class MultiHeadAttention(nn.Module):
124
- def __init__(
125
- self,
126
- channels,
127
- out_channels,
128
- n_heads,
129
- p_dropout=0.0,
130
- window_size=None,
131
- heads_share=True,
132
- block_length=None,
133
- proximal_bias=False,
134
- proximal_init=False,
135
- ):
136
- super().__init__()
137
- assert channels % n_heads == 0
138
-
139
- self.channels = channels
140
- self.out_channels = out_channels
141
- self.n_heads = n_heads
142
- self.p_dropout = p_dropout
143
- self.window_size = window_size
144
- self.heads_share = heads_share
145
- self.block_length = block_length
146
- self.proximal_bias = proximal_bias
147
- self.proximal_init = proximal_init
148
- self.attn = None
149
-
150
- self.k_channels = channels // n_heads
151
- self.conv_q = nn.Conv1d(channels, channels, 1)
152
- self.conv_k = nn.Conv1d(channels, channels, 1)
153
- self.conv_v = nn.Conv1d(channels, channels, 1)
154
- self.conv_o = nn.Conv1d(channels, out_channels, 1)
155
- self.drop = nn.Dropout(p_dropout)
156
-
157
- if window_size is not None:
158
- n_heads_rel = 1 if heads_share else n_heads
159
- rel_stddev = self.k_channels**-0.5
160
- self.emb_rel_k = nn.Parameter(
161
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
162
- * rel_stddev
163
- )
164
- self.emb_rel_v = nn.Parameter(
165
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
166
- * rel_stddev
167
- )
168
-
169
- nn.init.xavier_uniform_(self.conv_q.weight)
170
- nn.init.xavier_uniform_(self.conv_k.weight)
171
- nn.init.xavier_uniform_(self.conv_v.weight)
172
- if proximal_init:
173
- with torch.no_grad():
174
- self.conv_k.weight.copy_(self.conv_q.weight)
175
- self.conv_k.bias.copy_(self.conv_q.bias)
176
-
177
- def forward(self, x, c, attn_mask=None):
178
- q = self.conv_q(x)
179
- k = self.conv_k(c)
180
- v = self.conv_v(c)
181
-
182
- x, self.attn = self.attention(q, k, v, mask=attn_mask)
183
-
184
- x = self.conv_o(x)
185
- return x
186
-
187
- def attention(self, query, key, value, mask=None):
188
- # reshape [b, d, t] -> [b, n_h, t, d_k]
189
- b, d, t_s, t_t = (*key.size(), query.size(2))
190
- query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
191
- key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
192
- value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
193
-
194
- scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
195
- if self.window_size is not None:
196
- assert (
197
- t_s == t_t
198
- ), "Relative attention is only available for self-attention."
199
- key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
200
- rel_logits = self._matmul_with_relative_keys(
201
- query / math.sqrt(self.k_channels), key_relative_embeddings
202
- )
203
- scores_local = self._relative_position_to_absolute_position(rel_logits)
204
- scores = scores + scores_local
205
- if self.proximal_bias:
206
- assert t_s == t_t, "Proximal bias is only available for self-attention."
207
- scores = scores + self._attention_bias_proximal(t_s).to(
208
- device=scores.device, dtype=scores.dtype
209
- )
210
- if mask is not None:
211
- scores = scores.masked_fill(mask == 0, -1e4)
212
- if self.block_length is not None:
213
- assert (
214
- t_s == t_t
215
- ), "Local attention is only available for self-attention."
216
- block_mask = (
217
- torch.ones_like(scores)
218
- .triu(-self.block_length)
219
- .tril(self.block_length)
220
- )
221
- scores = scores.masked_fill(block_mask == 0, -1e4)
222
- p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
223
- p_attn = self.drop(p_attn)
224
- output = torch.matmul(p_attn, value)
225
- if self.window_size is not None:
226
- relative_weights = self._absolute_position_to_relative_position(p_attn)
227
- value_relative_embeddings = self._get_relative_embeddings(
228
- self.emb_rel_v, t_s
229
- )
230
- output = output + self._matmul_with_relative_values(
231
- relative_weights, value_relative_embeddings
232
- )
233
- output = (
234
- output.transpose(2, 3).contiguous().view(b, d, t_t)
235
- ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
236
- return output, p_attn
237
-
238
- def _matmul_with_relative_values(self, x, y):
239
- """
240
- x: [b, h, l, m]
241
- y: [h or 1, m, d]
242
- ret: [b, h, l, d]
243
- """
244
- ret = torch.matmul(x, y.unsqueeze(0))
245
- return ret
246
-
247
- def _matmul_with_relative_keys(self, x, y):
248
- """
249
- x: [b, h, l, d]
250
- y: [h or 1, m, d]
251
- ret: [b, h, l, m]
252
- """
253
- ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
254
- return ret
255
-
256
- def _get_relative_embeddings(self, relative_embeddings, length):
257
- max_relative_position = 2 * self.window_size + 1
258
- # Pad first before slice to avoid using cond ops.
259
- pad_length = max(length - (self.window_size + 1), 0)
260
- slice_start_position = max((self.window_size + 1) - length, 0)
261
- slice_end_position = slice_start_position + 2 * length - 1
262
- if pad_length > 0:
263
- padded_relative_embeddings = F.pad(
264
- relative_embeddings,
265
- commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
266
- )
267
- else:
268
- padded_relative_embeddings = relative_embeddings
269
- used_relative_embeddings = padded_relative_embeddings[
270
- :, slice_start_position:slice_end_position
271
- ]
272
- return used_relative_embeddings
273
-
274
- def _relative_position_to_absolute_position(self, x):
275
- """
276
- x: [b, h, l, 2*l-1]
277
- ret: [b, h, l, l]
278
- """
279
- batch, heads, length, _ = x.size()
280
- # Concat columns of pad to shift from relative to absolute indexing.
281
- x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
282
-
283
- # Concat extra elements so to add up to shape (len+1, 2*len-1).
284
- x_flat = x.view([batch, heads, length * 2 * length])
285
- x_flat = F.pad(
286
- x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
287
- )
288
-
289
- # Reshape and slice out the padded elements.
290
- x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
291
- :, :, :length, length - 1 :
292
- ]
293
- return x_final
294
-
295
- def _absolute_position_to_relative_position(self, x):
296
- """
297
- x: [b, h, l, l]
298
- ret: [b, h, l, 2*l-1]
299
- """
300
- batch, heads, length, _ = x.size()
301
- # padd along column
302
- x = F.pad(
303
- x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
304
- )
305
- x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
306
- # add 0's in the beginning that will skew the elements after reshape
307
- x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
308
- x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
309
- return x_final
310
-
311
- def _attention_bias_proximal(self, length):
312
- """Bias for self-attention to encourage attention to close positions.
313
- Args:
314
- length: an integer scalar.
315
- Returns:
316
- a Tensor with shape [1, 1, length, length]
317
- """
318
- r = torch.arange(length, dtype=torch.float32)
319
- diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
320
- return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
321
-
322
-
323
- class FFN(nn.Module):
324
- def __init__(
325
- self,
326
- in_channels,
327
- out_channels,
328
- filter_channels,
329
- kernel_size,
330
- p_dropout=0.0,
331
- activation=None,
332
- causal=False,
333
- ):
334
- super().__init__()
335
- self.in_channels = in_channels
336
- self.out_channels = out_channels
337
- self.filter_channels = filter_channels
338
- self.kernel_size = kernel_size
339
- self.p_dropout = p_dropout
340
- self.activation = activation
341
- self.causal = causal
342
-
343
- if causal:
344
- self.padding = self._causal_padding
345
- else:
346
- self.padding = self._same_padding
347
-
348
- self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
349
- self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
350
- self.drop = nn.Dropout(p_dropout)
351
-
352
- def forward(self, x, x_mask):
353
- x = self.conv_1(self.padding(x * x_mask))
354
- if self.activation == "gelu":
355
- x = x * torch.sigmoid(1.702 * x)
356
- else:
357
- x = torch.relu(x)
358
- x = self.drop(x)
359
- x = self.conv_2(self.padding(x * x_mask))
360
- return x * x_mask
361
-
362
- def _causal_padding(self, x):
363
- if self.kernel_size == 1:
364
- return x
365
- pad_l = self.kernel_size - 1
366
- pad_r = 0
367
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
368
- x = F.pad(x, commons.convert_pad_shape(padding))
369
- return x
370
-
371
- def _same_padding(self, x):
372
- if self.kernel_size == 1:
373
- return x
374
- pad_l = (self.kernel_size - 1) // 2
375
- pad_r = self.kernel_size // 2
376
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
377
- x = F.pad(x, commons.convert_pad_shape(padding))
378
- return x
 
 
onnx_modules/V210/models_onnx.py DELETED
@@ -1,1044 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn
4
- from torch.nn import functional as F
5
-
6
- import commons
7
- import modules
8
- from . import attentions_onnx
9
- from vector_quantize_pytorch import VectorQuantize
10
-
11
- from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
- from commons import init_weights, get_padding
14
- from .text import symbols, num_tones, num_languages
15
-
16
-
17
- class DurationDiscriminator(nn.Module): # vits2
18
- def __init__(
19
- self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
20
- ):
21
- super().__init__()
22
-
23
- self.in_channels = in_channels
24
- self.filter_channels = filter_channels
25
- self.kernel_size = kernel_size
26
- self.p_dropout = p_dropout
27
- self.gin_channels = gin_channels
28
-
29
- self.drop = nn.Dropout(p_dropout)
30
- self.conv_1 = nn.Conv1d(
31
- in_channels, filter_channels, kernel_size, padding=kernel_size // 2
32
- )
33
- self.norm_1 = modules.LayerNorm(filter_channels)
34
- self.conv_2 = nn.Conv1d(
35
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
36
- )
37
- self.norm_2 = modules.LayerNorm(filter_channels)
38
- self.dur_proj = nn.Conv1d(1, filter_channels, 1)
39
-
40
- self.pre_out_conv_1 = nn.Conv1d(
41
- 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
42
- )
43
- self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
44
- self.pre_out_conv_2 = nn.Conv1d(
45
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
46
- )
47
- self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
48
-
49
- if gin_channels != 0:
50
- self.cond = nn.Conv1d(gin_channels, in_channels, 1)
51
-
52
- self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
53
-
54
- def forward_probability(self, x, x_mask, dur, g=None):
55
- dur = self.dur_proj(dur)
56
- x = torch.cat([x, dur], dim=1)
57
- x = self.pre_out_conv_1(x * x_mask)
58
- x = torch.relu(x)
59
- x = self.pre_out_norm_1(x)
60
- x = self.drop(x)
61
- x = self.pre_out_conv_2(x * x_mask)
62
- x = torch.relu(x)
63
- x = self.pre_out_norm_2(x)
64
- x = self.drop(x)
65
- x = x * x_mask
66
- x = x.transpose(1, 2)
67
- output_prob = self.output_layer(x)
68
- return output_prob
69
-
70
- def forward(self, x, x_mask, dur_r, dur_hat, g=None):
71
- x = torch.detach(x)
72
- if g is not None:
73
- g = torch.detach(g)
74
- x = x + self.cond(g)
75
- x = self.conv_1(x * x_mask)
76
- x = torch.relu(x)
77
- x = self.norm_1(x)
78
- x = self.drop(x)
79
- x = self.conv_2(x * x_mask)
80
- x = torch.relu(x)
81
- x = self.norm_2(x)
82
- x = self.drop(x)
83
-
84
- output_probs = []
85
- for dur in [dur_r, dur_hat]:
86
- output_prob = self.forward_probability(x, x_mask, dur, g)
87
- output_probs.append(output_prob)
88
-
89
- return output_probs
90
-
91
-
92
- class TransformerCouplingBlock(nn.Module):
93
- def __init__(
94
- self,
95
- channels,
96
- hidden_channels,
97
- filter_channels,
98
- n_heads,
99
- n_layers,
100
- kernel_size,
101
- p_dropout,
102
- n_flows=4,
103
- gin_channels=0,
104
- share_parameter=False,
105
- ):
106
- super().__init__()
107
- self.channels = channels
108
- self.hidden_channels = hidden_channels
109
- self.kernel_size = kernel_size
110
- self.n_layers = n_layers
111
- self.n_flows = n_flows
112
- self.gin_channels = gin_channels
113
-
114
- self.flows = nn.ModuleList()
115
-
116
- self.wn = (
117
- attentions_onnx.FFT(
118
- hidden_channels,
119
- filter_channels,
120
- n_heads,
121
- n_layers,
122
- kernel_size,
123
- p_dropout,
124
- isflow=True,
125
- gin_channels=self.gin_channels,
126
- )
127
- if share_parameter
128
- else None
129
- )
130
-
131
- for i in range(n_flows):
132
- self.flows.append(
133
- modules.TransformerCouplingLayer(
134
- channels,
135
- hidden_channels,
136
- kernel_size,
137
- n_layers,
138
- n_heads,
139
- p_dropout,
140
- filter_channels,
141
- mean_only=True,
142
- wn_sharing_parameter=self.wn,
143
- gin_channels=self.gin_channels,
144
- )
145
- )
146
- self.flows.append(modules.Flip())
147
-
148
- def forward(self, x, x_mask, g=None, reverse=True):
149
- if not reverse:
150
- for flow in self.flows:
151
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
152
- else:
153
- for flow in reversed(self.flows):
154
- x = flow(x, x_mask, g=g, reverse=reverse)
155
- return x
156
-
157
-
158
- class StochasticDurationPredictor(nn.Module):
159
- def __init__(
160
- self,
161
- in_channels,
162
- filter_channels,
163
- kernel_size,
164
- p_dropout,
165
- n_flows=4,
166
- gin_channels=0,
167
- ):
168
- super().__init__()
169
- filter_channels = in_channels # it needs to be removed from future version.
170
- self.in_channels = in_channels
171
- self.filter_channels = filter_channels
172
- self.kernel_size = kernel_size
173
- self.p_dropout = p_dropout
174
- self.n_flows = n_flows
175
- self.gin_channels = gin_channels
176
-
177
- self.log_flow = modules.Log()
178
- self.flows = nn.ModuleList()
179
- self.flows.append(modules.ElementwiseAffine(2))
180
- for i in range(n_flows):
181
- self.flows.append(
182
- modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
183
- )
184
- self.flows.append(modules.Flip())
185
-
186
- self.post_pre = nn.Conv1d(1, filter_channels, 1)
187
- self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
188
- self.post_convs = modules.DDSConv(
189
- filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
190
- )
191
- self.post_flows = nn.ModuleList()
192
- self.post_flows.append(modules.ElementwiseAffine(2))
193
- for i in range(4):
194
- self.post_flows.append(
195
- modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
196
- )
197
- self.post_flows.append(modules.Flip())
198
-
199
- self.pre = nn.Conv1d(in_channels, filter_channels, 1)
200
- self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
201
- self.convs = modules.DDSConv(
202
- filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
203
- )
204
- if gin_channels != 0:
205
- self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
206
-
207
- def forward(self, x, x_mask, z, g=None):
208
- x = torch.detach(x)
209
- x = self.pre(x)
210
- if g is not None:
211
- g = torch.detach(g)
212
- x = x + self.cond(g)
213
- x = self.convs(x, x_mask)
214
- x = self.proj(x) * x_mask
215
-
216
- flows = list(reversed(self.flows))
217
- flows = flows[:-2] + [flows[-1]] # remove a useless vflow
218
- for flow in flows:
219
- z = flow(z, x_mask, g=x, reverse=True)
220
- z0, z1 = torch.split(z, [1, 1], 1)
221
- logw = z0
222
- return logw
223
-
224
-
225
- class DurationPredictor(nn.Module):
226
- def __init__(
227
- self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
228
- ):
229
- super().__init__()
230
-
231
- self.in_channels = in_channels
232
- self.filter_channels = filter_channels
233
- self.kernel_size = kernel_size
234
- self.p_dropout = p_dropout
235
- self.gin_channels = gin_channels
236
-
237
- self.drop = nn.Dropout(p_dropout)
238
- self.conv_1 = nn.Conv1d(
239
- in_channels, filter_channels, kernel_size, padding=kernel_size // 2
240
- )
241
- self.norm_1 = modules.LayerNorm(filter_channels)
242
- self.conv_2 = nn.Conv1d(
243
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
244
- )
245
- self.norm_2 = modules.LayerNorm(filter_channels)
246
- self.proj = nn.Conv1d(filter_channels, 1, 1)
247
-
248
- if gin_channels != 0:
249
- self.cond = nn.Conv1d(gin_channels, in_channels, 1)
250
-
251
- def forward(self, x, x_mask, g=None):
252
- x = torch.detach(x)
253
- if g is not None:
254
- g = torch.detach(g)
255
- x = x + self.cond(g)
256
- x = self.conv_1(x * x_mask)
257
- x = torch.relu(x)
258
- x = self.norm_1(x)
259
- x = self.drop(x)
260
- x = self.conv_2(x * x_mask)
261
- x = torch.relu(x)
262
- x = self.norm_2(x)
263
- x = self.drop(x)
264
- x = self.proj(x * x_mask)
265
- return x * x_mask
266
-
267
-
268
- class TextEncoder(nn.Module):
269
- def __init__(
270
- self,
271
- n_vocab,
272
- out_channels,
273
- hidden_channels,
274
- filter_channels,
275
- n_heads,
276
- n_layers,
277
- kernel_size,
278
- p_dropout,
279
- n_speakers,
280
- gin_channels=0,
281
- ):
282
- super().__init__()
283
- self.n_vocab = n_vocab
284
- self.out_channels = out_channels
285
- self.hidden_channels = hidden_channels
286
- self.filter_channels = filter_channels
287
- self.n_heads = n_heads
288
- self.n_layers = n_layers
289
- self.kernel_size = kernel_size
290
- self.p_dropout = p_dropout
291
- self.gin_channels = gin_channels
292
- self.emb = nn.Embedding(len(symbols), hidden_channels)
293
- nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
294
- self.tone_emb = nn.Embedding(num_tones, hidden_channels)
295
- nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
296
- self.language_emb = nn.Embedding(num_languages, hidden_channels)
297
- nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
298
- self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
299
- self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
300
- self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
301
- self.emo_proj = nn.Linear(1024, 1024)
302
- self.emo_quantizer = nn.ModuleList()
303
- for i in range(0, n_speakers):
304
- self.emo_quantizer.append(
305
- VectorQuantize(
306
- dim=1024,
307
- codebook_size=10,
308
- decay=0.8,
309
- commitment_weight=1.0,
310
- learnable_codebook=True,
311
- ema_update=False,
312
- )
313
- )
314
- self.emo_q_proj = nn.Linear(1024, hidden_channels)
315
- self.n_speakers = n_speakers
316
-
317
- self.encoder = attentions_onnx.Encoder(
318
- hidden_channels,
319
- filter_channels,
320
- n_heads,
321
- n_layers,
322
- kernel_size,
323
- p_dropout,
324
- gin_channels=self.gin_channels,
325
- )
326
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
327
-
328
- def init_vq(self):
329
- self.emb_vq = nn.Embedding(10 * self.n_speakers, 1024)
330
- self.emb_vq_weight = torch.zeros(10 * self.n_speakers, 1024).float()
331
- for i in range(self.n_speakers):
332
- for j in range(10):
333
- self.emb_vq_weight[i * 10 + j] = self.emo_quantizer[
334
- i
335
- ].get_output_from_indices(torch.LongTensor([j]))
336
- self.emb_vq.weight = nn.Parameter(self.emb_vq_weight.clone())
337
-
338
- def forward(
339
- self,
340
- x,
341
- x_lengths,
342
- tone,
343
- language,
344
- bert,
345
- ja_bert,
346
- en_bert,
347
- g=None,
348
- vqidx=None,
349
- sid=None,
350
- ):
351
- x_mask = torch.ones_like(x).unsqueeze(0)
352
- bert_emb = self.bert_proj(bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
353
- ja_bert_emb = self.ja_bert_proj(ja_bert.transpose(0, 1).unsqueeze(0)).transpose(
354
- 1, 2
355
- )
356
- en_bert_emb = self.en_bert_proj(en_bert.transpose(0, 1).unsqueeze(0)).transpose(
357
- 1, 2
358
- )
359
-
360
- emb_vq_idx = torch.clamp(
361
- (sid * 10) + vqidx, min=0, max=(self.n_speakers * 10) - 1
362
- )
363
-
364
- vqval = self.emb_vq(emb_vq_idx)
365
-
366
- x = (
367
- self.emb(x)
368
- + self.tone_emb(tone)
369
- + self.language_emb(language)
370
- + bert_emb
371
- + ja_bert_emb
372
- + en_bert_emb
373
- + self.emo_q_proj(vqval)
374
- ) * math.sqrt(
375
- self.hidden_channels
376
- ) # [b, t, h]
377
- x = torch.transpose(x, 1, -1) # [b, h, t]
378
- x_mask = x_mask.to(x.dtype)
379
-
380
- x = self.encoder(x * x_mask, x_mask, g=g)
381
- stats = self.proj(x) * x_mask
382
-
383
- m, logs = torch.split(stats, self.out_channels, dim=1)
384
- return x, m, logs, x_mask
385
-
386
-
387
- class ResidualCouplingBlock(nn.Module):
388
- def __init__(
389
- self,
390
- channels,
391
- hidden_channels,
392
- kernel_size,
393
- dilation_rate,
394
- n_layers,
395
- n_flows=4,
396
- gin_channels=0,
397
- ):
398
- super().__init__()
399
- self.channels = channels
400
- self.hidden_channels = hidden_channels
401
- self.kernel_size = kernel_size
402
- self.dilation_rate = dilation_rate
403
- self.n_layers = n_layers
404
- self.n_flows = n_flows
405
- self.gin_channels = gin_channels
406
-
407
- self.flows = nn.ModuleList()
408
- for i in range(n_flows):
409
- self.flows.append(
410
- modules.ResidualCouplingLayer(
411
- channels,
412
- hidden_channels,
413
- kernel_size,
414
- dilation_rate,
415
- n_layers,
416
- gin_channels=gin_channels,
417
- mean_only=True,
418
- )
419
- )
420
- self.flows.append(modules.Flip())
421
-
422
- def forward(self, x, x_mask, g=None, reverse=True):
423
- if not reverse:
424
- for flow in self.flows:
425
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
426
- else:
427
- for flow in reversed(self.flows):
428
- x = flow(x, x_mask, g=g, reverse=reverse)
429
- return x
430
-
431
-
432
- class PosteriorEncoder(nn.Module):
433
- def __init__(
434
- self,
435
- in_channels,
436
- out_channels,
437
- hidden_channels,
438
- kernel_size,
439
- dilation_rate,
440
- n_layers,
441
- gin_channels=0,
442
- ):
443
- super().__init__()
444
- self.in_channels = in_channels
445
- self.out_channels = out_channels
446
- self.hidden_channels = hidden_channels
447
- self.kernel_size = kernel_size
448
- self.dilation_rate = dilation_rate
449
- self.n_layers = n_layers
450
- self.gin_channels = gin_channels
451
-
452
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
453
- self.enc = modules.WN(
454
- hidden_channels,
455
- kernel_size,
456
- dilation_rate,
457
- n_layers,
458
- gin_channels=gin_channels,
459
- )
460
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
461
-
462
- def forward(self, x, x_lengths, g=None):
463
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
464
- x.dtype
465
- )
466
- x = self.pre(x) * x_mask
467
- x = self.enc(x, x_mask, g=g)
468
- stats = self.proj(x) * x_mask
469
- m, logs = torch.split(stats, self.out_channels, dim=1)
470
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
471
- return z, m, logs, x_mask
472
-
473
-
474
- class Generator(torch.nn.Module):
475
- def __init__(
476
- self,
477
- initial_channel,
478
- resblock,
479
- resblock_kernel_sizes,
480
- resblock_dilation_sizes,
481
- upsample_rates,
482
- upsample_initial_channel,
483
- upsample_kernel_sizes,
484
- gin_channels=0,
485
- ):
486
- super(Generator, self).__init__()
487
- self.num_kernels = len(resblock_kernel_sizes)
488
- self.num_upsamples = len(upsample_rates)
489
- self.conv_pre = Conv1d(
490
- initial_channel, upsample_initial_channel, 7, 1, padding=3
491
- )
492
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
493
-
494
- self.ups = nn.ModuleList()
495
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
496
- self.ups.append(
497
- weight_norm(
498
- ConvTranspose1d(
499
- upsample_initial_channel // (2**i),
500
- upsample_initial_channel // (2 ** (i + 1)),
501
- k,
502
- u,
503
- padding=(k - u) // 2,
504
- )
505
- )
506
- )
507
-
508
- self.resblocks = nn.ModuleList()
509
- for i in range(len(self.ups)):
510
- ch = upsample_initial_channel // (2 ** (i + 1))
511
- for j, (k, d) in enumerate(
512
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
513
- ):
514
- self.resblocks.append(resblock(ch, k, d))
515
-
516
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
517
- self.ups.apply(init_weights)
518
-
519
- if gin_channels != 0:
520
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
521
-
522
- def forward(self, x, g=None):
523
- x = self.conv_pre(x)
524
- if g is not None:
525
- x = x + self.cond(g)
526
-
527
- for i in range(self.num_upsamples):
528
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
529
- x = self.ups[i](x)
530
- xs = None
531
- for j in range(self.num_kernels):
532
- if xs is None:
533
- xs = self.resblocks[i * self.num_kernels + j](x)
534
- else:
535
- xs += self.resblocks[i * self.num_kernels + j](x)
536
- x = xs / self.num_kernels
537
- x = F.leaky_relu(x)
538
- x = self.conv_post(x)
539
- x = torch.tanh(x)
540
-
541
- return x
542
-
543
- def remove_weight_norm(self):
544
- print("Removing weight norm...")
545
- for layer in self.ups:
546
- remove_weight_norm(layer)
547
- for layer in self.resblocks:
548
- layer.remove_weight_norm()
549
-
550
-
551
- class DiscriminatorP(torch.nn.Module):
552
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
553
- super(DiscriminatorP, self).__init__()
554
- self.period = period
555
- self.use_spectral_norm = use_spectral_norm
556
- norm_f = weight_norm if use_spectral_norm is False else spectral_norm
557
- self.convs = nn.ModuleList(
558
- [
559
- norm_f(
560
- Conv2d(
561
- 1,
562
- 32,
563
- (kernel_size, 1),
564
- (stride, 1),
565
- padding=(get_padding(kernel_size, 1), 0),
566
- )
567
- ),
568
- norm_f(
569
- Conv2d(
570
- 32,
571
- 128,
572
- (kernel_size, 1),
573
- (stride, 1),
574
- padding=(get_padding(kernel_size, 1), 0),
575
- )
576
- ),
577
- norm_f(
578
- Conv2d(
579
- 128,
580
- 512,
581
- (kernel_size, 1),
582
- (stride, 1),
583
- padding=(get_padding(kernel_size, 1), 0),
584
- )
585
- ),
586
- norm_f(
587
- Conv2d(
588
- 512,
589
- 1024,
590
- (kernel_size, 1),
591
- (stride, 1),
592
- padding=(get_padding(kernel_size, 1), 0),
593
- )
594
- ),
595
- norm_f(
596
- Conv2d(
597
- 1024,
598
- 1024,
599
- (kernel_size, 1),
600
- 1,
601
- padding=(get_padding(kernel_size, 1), 0),
602
- )
603
- ),
604
- ]
605
- )
606
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
607
-
608
- def forward(self, x):
609
- fmap = []
610
-
611
- # 1d to 2d
612
- b, c, t = x.shape
613
- if t % self.period != 0: # pad first
614
- n_pad = self.period - (t % self.period)
615
- x = F.pad(x, (0, n_pad), "reflect")
616
- t = t + n_pad
617
- x = x.view(b, c, t // self.period, self.period)
618
-
619
- for layer in self.convs:
620
- x = layer(x)
621
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
622
- fmap.append(x)
623
- x = self.conv_post(x)
624
- fmap.append(x)
625
- x = torch.flatten(x, 1, -1)
626
-
627
- return x, fmap
628
-
629
-
630
- class DiscriminatorS(torch.nn.Module):
631
- def __init__(self, use_spectral_norm=False):
632
- super(DiscriminatorS, self).__init__()
633
- norm_f = weight_norm if use_spectral_norm is False else spectral_norm
634
- self.convs = nn.ModuleList(
635
- [
636
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
637
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
638
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
639
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
640
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
641
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
642
- ]
643
- )
644
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
645
-
646
- def forward(self, x):
647
- fmap = []
648
-
649
- for layer in self.convs:
650
- x = layer(x)
651
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
652
- fmap.append(x)
653
- x = self.conv_post(x)
654
- fmap.append(x)
655
- x = torch.flatten(x, 1, -1)
656
-
657
- return x, fmap
658
-
659
-
660
- class MultiPeriodDiscriminator(torch.nn.Module):
661
- def __init__(self, use_spectral_norm=False):
662
- super(MultiPeriodDiscriminator, self).__init__()
663
- periods = [2, 3, 5, 7, 11]
664
-
665
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
666
- discs = discs + [
667
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
668
- ]
669
- self.discriminators = nn.ModuleList(discs)
670
-
671
- def forward(self, y, y_hat):
672
- y_d_rs = []
673
- y_d_gs = []
674
- fmap_rs = []
675
- fmap_gs = []
676
- for i, d in enumerate(self.discriminators):
677
- y_d_r, fmap_r = d(y)
678
- y_d_g, fmap_g = d(y_hat)
679
- y_d_rs.append(y_d_r)
680
- y_d_gs.append(y_d_g)
681
- fmap_rs.append(fmap_r)
682
- fmap_gs.append(fmap_g)
683
-
684
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
685
-
686
-
687
- class ReferenceEncoder(nn.Module):
688
- """
689
- inputs --- [N, Ty/r, n_mels*r] mels
690
- outputs --- [N, ref_enc_gru_size]
691
- """
692
-
693
- def __init__(self, spec_channels, gin_channels=0):
694
- super().__init__()
695
- self.spec_channels = spec_channels
696
- ref_enc_filters = [32, 32, 64, 64, 128, 128]
697
- K = len(ref_enc_filters)
698
- filters = [1] + ref_enc_filters
699
- convs = [
700
- weight_norm(
701
- nn.Conv2d(
702
- in_channels=filters[i],
703
- out_channels=filters[i + 1],
704
- kernel_size=(3, 3),
705
- stride=(2, 2),
706
- padding=(1, 1),
707
- )
708
- )
709
- for i in range(K)
710
- ]
711
- self.convs = nn.ModuleList(convs)
712
- # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
713
-
714
- out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
715
- self.gru = nn.GRU(
716
- input_size=ref_enc_filters[-1] * out_channels,
717
- hidden_size=256 // 2,
718
- batch_first=True,
719
- )
720
- self.proj = nn.Linear(128, gin_channels)
721
-
722
- def forward(self, inputs, mask=None):
723
- N = inputs.size(0)
724
- out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
725
- for conv in self.convs:
726
- out = conv(out)
727
- # out = wn(out)
728
- out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
729
-
730
- out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
731
- T = out.size(1)
732
- N = out.size(0)
733
- out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
734
-
735
- self.gru.flatten_parameters()
736
- memory, out = self.gru(out) # out --- [1, N, 128]
737
-
738
- return self.proj(out.squeeze(0))
739
-
740
- def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
741
- for i in range(n_convs):
742
- L = (L - kernel_size + 2 * pad) // stride + 1
743
- return L
744
-
745
-
746
- class SynthesizerTrn(nn.Module):
747
- """
748
- Synthesizer for Training
749
- """
750
-
751
- def __init__(
752
- self,
753
- n_vocab,
754
- spec_channels,
755
- segment_size,
756
- inter_channels,
757
- hidden_channels,
758
- filter_channels,
759
- n_heads,
760
- n_layers,
761
- kernel_size,
762
- p_dropout,
763
- resblock,
764
- resblock_kernel_sizes,
765
- resblock_dilation_sizes,
766
- upsample_rates,
767
- upsample_initial_channel,
768
- upsample_kernel_sizes,
769
- n_speakers=256,
770
- gin_channels=256,
771
- use_sdp=True,
772
- n_flow_layer=4,
773
- n_layers_trans_flow=4,
774
- flow_share_parameter=False,
775
- use_transformer_flow=True,
776
- **kwargs,
777
- ):
778
- super().__init__()
779
- self.n_vocab = n_vocab
780
- self.spec_channels = spec_channels
781
- self.inter_channels = inter_channels
782
- self.hidden_channels = hidden_channels
783
- self.filter_channels = filter_channels
784
- self.n_heads = n_heads
785
- self.n_layers = n_layers
786
- self.kernel_size = kernel_size
787
- self.p_dropout = p_dropout
788
- self.resblock = resblock
789
- self.resblock_kernel_sizes = resblock_kernel_sizes
790
- self.resblock_dilation_sizes = resblock_dilation_sizes
791
- self.upsample_rates = upsample_rates
792
- self.upsample_initial_channel = upsample_initial_channel
793
- self.upsample_kernel_sizes = upsample_kernel_sizes
794
- self.segment_size = segment_size
795
- self.n_speakers = n_speakers
796
- self.gin_channels = gin_channels
797
- self.n_layers_trans_flow = n_layers_trans_flow
798
- self.use_spk_conditioned_encoder = kwargs.get(
799
- "use_spk_conditioned_encoder", True
800
- )
801
- self.use_sdp = use_sdp
802
- self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
803
- self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
804
- self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
805
- self.current_mas_noise_scale = self.mas_noise_scale_initial
806
- if self.use_spk_conditioned_encoder and gin_channels > 0:
807
- self.enc_gin_channels = gin_channels
808
- self.enc_p = TextEncoder(
809
- n_vocab,
810
- inter_channels,
811
- hidden_channels,
812
- filter_channels,
813
- n_heads,
814
- n_layers,
815
- kernel_size,
816
- p_dropout,
817
- n_speakers,
818
- gin_channels=self.enc_gin_channels,
819
- )
820
- self.dec = Generator(
821
- inter_channels,
822
- resblock,
823
- resblock_kernel_sizes,
824
- resblock_dilation_sizes,
825
- upsample_rates,
826
- upsample_initial_channel,
827
- upsample_kernel_sizes,
828
- gin_channels=gin_channels,
829
- )
830
- self.enc_q = PosteriorEncoder(
831
- spec_channels,
832
- inter_channels,
833
- hidden_channels,
834
- 5,
835
- 1,
836
- 16,
837
- gin_channels=gin_channels,
838
- )
839
- if use_transformer_flow:
840
- self.flow = TransformerCouplingBlock(
841
- inter_channels,
842
- hidden_channels,
843
- filter_channels,
844
- n_heads,
845
- n_layers_trans_flow,
846
- 5,
847
- p_dropout,
848
- n_flow_layer,
849
- gin_channels=gin_channels,
850
- share_parameter=flow_share_parameter,
851
- )
852
- else:
853
- self.flow = ResidualCouplingBlock(
854
- inter_channels,
855
- hidden_channels,
856
- 5,
857
- 1,
858
- n_flow_layer,
859
- gin_channels=gin_channels,
860
- )
861
- self.sdp = StochasticDurationPredictor(
862
- hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
863
- )
864
- self.dp = DurationPredictor(
865
- hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
866
- )
867
-
868
- if n_speakers >= 1:
869
- self.emb_g = nn.Embedding(n_speakers, gin_channels)
870
- else:
871
- self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
872
-
873
- def export_onnx(
874
- self,
875
- path,
876
- max_len=None,
877
- sdp_ratio=0,
878
- y=None,
879
- ):
880
- noise_scale = 0.667
881
- length_scale = 1
882
- noise_scale_w = 0.8
883
- x = (
884
- torch.LongTensor(
885
- [
886
- 0,
887
- 97,
888
- 0,
889
- 8,
890
- 0,
891
- 78,
892
- 0,
893
- 8,
894
- 0,
895
- 76,
896
- 0,
897
- 37,
898
- 0,
899
- 40,
900
- 0,
901
- 97,
902
- 0,
903
- 8,
904
- 0,
905
- 23,
906
- 0,
907
- 8,
908
- 0,
909
- 74,
910
- 0,
911
- 26,
912
- 0,
913
- 104,
914
- 0,
915
- ]
916
- )
917
- .unsqueeze(0)
918
- .cpu()
919
- )
920
- tone = torch.zeros_like(x).cpu()
921
- language = torch.zeros_like(x).cpu()
922
- x_lengths = torch.LongTensor([x.shape[1]]).cpu()
923
- sid = torch.LongTensor([0]).cpu()
924
- bert = torch.randn(size=(x.shape[1], 1024)).cpu()
925
- ja_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
926
- en_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
927
-
928
- if self.n_speakers > 0:
929
- g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
930
- torch.onnx.export(
931
- self.emb_g,
932
- (sid),
933
- f"onnx/{path}/{path}_emb.onnx",
934
- input_names=["sid"],
935
- output_names=["g"],
936
- verbose=True,
937
- )
938
- else:
939
- g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
940
-
941
- self.enc_p.init_vq()
942
-
943
- torch.onnx.export(
944
- self.enc_p,
945
- (x, x_lengths, tone, language, bert, ja_bert, en_bert, g, sid, sid),
946
- f"onnx/{path}/{path}_enc_p.onnx",
947
- input_names=[
948
- "x",
949
- "x_lengths",
950
- "t",
951
- "language",
952
- "bert_0",
953
- "bert_1",
954
- "bert_2",
955
- "g",
956
- "vqidx",
957
- "sid",
958
- ],
959
- output_names=["xout", "m_p", "logs_p", "x_mask"],
960
- dynamic_axes={
961
- "x": [0, 1],
962
- "t": [0, 1],
963
- "language": [0, 1],
964
- "bert_0": [0],
965
- "bert_1": [0],
966
- "bert_2": [0],
967
- "xout": [0, 2],
968
- "m_p": [0, 2],
969
- "logs_p": [0, 2],
970
- "x_mask": [0, 2],
971
- },
972
- verbose=True,
973
- opset_version=16,
974
- )
975
-
976
- x, m_p, logs_p, x_mask = self.enc_p(
977
- x, x_lengths, tone, language, bert, ja_bert, en_bert, g, sid, sid
978
- )
979
-
980
- zinput = (
981
- torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
982
- * noise_scale_w
983
- )
984
- torch.onnx.export(
985
- self.sdp,
986
- (x, x_mask, zinput, g),
987
- f"onnx/{path}/{path}_sdp.onnx",
988
- input_names=["x", "x_mask", "zin", "g"],
989
- output_names=["logw"],
990
- dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "zin": [0, 2], "logw": [0, 2]},
991
- verbose=True,
992
- )
993
- torch.onnx.export(
994
- self.dp,
995
- (x, x_mask, g),
996
- f"onnx/{path}/{path}_dp.onnx",
997
- input_names=["x", "x_mask", "g"],
998
- output_names=["logw"],
999
- dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "logw": [0, 2]},
1000
- verbose=True,
1001
- )
1002
- logw = self.sdp(x, x_mask, zinput, g=g) * (sdp_ratio) + self.dp(
1003
- x, x_mask, g=g
1004
- ) * (1 - sdp_ratio)
1005
- w = torch.exp(logw) * x_mask * length_scale
1006
- w_ceil = torch.ceil(w)
1007
- y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1008
- y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1009
- x_mask.dtype
1010
- )
1011
- attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1012
- attn = commons.generate_path(w_ceil, attn_mask)
1013
-
1014
- m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1015
- 1, 2
1016
- ) # [b, t', t], [b, t, d] -> [b, d, t']
1017
- logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1018
- 1, 2
1019
- ) # [b, t', t], [b, t, d] -> [b, d, t']
1020
-
1021
- z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1022
- torch.onnx.export(
1023
- self.flow,
1024
- (z_p, y_mask, g),
1025
- f"onnx/{path}/{path}_flow.onnx",
1026
- input_names=["z_p", "y_mask", "g"],
1027
- output_names=["z"],
1028
- dynamic_axes={"z_p": [0, 2], "y_mask": [0, 2], "z": [0, 2]},
1029
- verbose=True,
1030
- )
1031
-
1032
- z = self.flow(z_p, y_mask, g=g, reverse=True)
1033
- z_in = (z * y_mask)[:, :, :max_len]
1034
-
1035
- torch.onnx.export(
1036
- self.dec,
1037
- (z_in, g),
1038
- f"onnx/{path}/{path}_dec.onnx",
1039
- input_names=["z_in", "g"],
1040
- output_names=["o"],
1041
- dynamic_axes={"z_in": [0, 2], "o": [0, 2]},
1042
- verbose=True,
1043
- )
1044
- o = self.dec((z * y_mask)[:, :, :max_len], g=g)
 
 
onnx_modules/V210/text/__init__.py DELETED
@@ -1 +0,0 @@
- from .symbols import *
 
 
onnx_modules/V210/text/symbols.py DELETED
@@ -1,187 +0,0 @@
1
- punctuation = ["!", "?", "…", ",", ".", "'", "-"]
2
- pu_symbols = punctuation + ["SP", "UNK"]
3
- pad = "_"
4
-
5
- # chinese
6
- zh_symbols = [
7
- "E",
8
- "En",
9
- "a",
10
- "ai",
11
- "an",
12
- "ang",
13
- "ao",
14
- "b",
15
- "c",
16
- "ch",
17
- "d",
18
- "e",
19
- "ei",
20
- "en",
21
- "eng",
22
- "er",
23
- "f",
24
- "g",
25
- "h",
26
- "i",
27
- "i0",
28
- "ia",
29
- "ian",
30
- "iang",
31
- "iao",
32
- "ie",
33
- "in",
34
- "ing",
35
- "iong",
36
- "ir",
37
- "iu",
38
- "j",
39
- "k",
40
- "l",
41
- "m",
42
- "n",
43
- "o",
44
- "ong",
45
- "ou",
46
- "p",
47
- "q",
48
- "r",
49
- "s",
50
- "sh",
51
- "t",
52
- "u",
53
- "ua",
54
- "uai",
55
- "uan",
56
- "uang",
57
- "ui",
58
- "un",
59
- "uo",
60
- "v",
61
- "van",
62
- "ve",
63
- "vn",
64
- "w",
65
- "x",
66
- "y",
67
- "z",
68
- "zh",
69
- "AA",
70
- "EE",
71
- "OO",
72
- ]
73
- num_zh_tones = 6
74
-
75
- # japanese
76
- ja_symbols = [
77
- "N",
78
- "a",
79
- "a:",
80
- "b",
81
- "by",
82
- "ch",
83
- "d",
84
- "dy",
85
- "e",
86
- "e:",
87
- "f",
88
- "g",
89
- "gy",
90
- "h",
91
- "hy",
92
- "i",
93
- "i:",
94
- "j",
95
- "k",
96
- "ky",
97
- "m",
98
- "my",
99
- "n",
100
- "ny",
101
- "o",
102
- "o:",
103
- "p",
104
- "py",
105
- "q",
106
- "r",
107
- "ry",
108
- "s",
109
- "sh",
110
- "t",
111
- "ts",
112
- "ty",
113
- "u",
114
- "u:",
115
- "w",
116
- "y",
117
- "z",
118
- "zy",
119
- ]
120
- num_ja_tones = 2
121
-
122
- # English
123
- en_symbols = [
124
- "aa",
125
- "ae",
126
- "ah",
127
- "ao",
128
- "aw",
129
- "ay",
130
- "b",
131
- "ch",
132
- "d",
133
- "dh",
134
- "eh",
135
- "er",
136
- "ey",
137
- "f",
138
- "g",
139
- "hh",
140
- "ih",
141
- "iy",
142
- "jh",
143
- "k",
144
- "l",
145
- "m",
146
- "n",
147
- "ng",
148
- "ow",
149
- "oy",
150
- "p",
151
- "r",
152
- "s",
153
- "sh",
154
- "t",
155
- "th",
156
- "uh",
157
- "uw",
158
- "V",
159
- "w",
160
- "y",
161
- "z",
162
- "zh",
163
- ]
164
- num_en_tones = 4
165
-
166
- # combine all symbols
167
- normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
168
- symbols = [pad] + normal_symbols + pu_symbols
169
- sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
170
-
171
- # combine all tones
172
- num_tones = num_zh_tones + num_ja_tones + num_en_tones
173
-
174
- # language maps
175
- language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
176
- num_languages = len(language_id_map.keys())
177
-
178
- language_tone_start_map = {
179
- "ZH": 0,
180
- "JP": num_zh_tones,
181
- "EN": num_zh_tones + num_ja_tones,
182
- }
183
-
184
- if __name__ == "__main__":
185
- a = set(zh_symbols)
186
- b = set(en_symbols)
187
- print(sorted(a & b))
 
 
onnx_modules/V220/__init__.py DELETED
File without changes
onnx_modules/V220/attentions_onnx.py DELETED
@@ -1,378 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn
4
- from torch.nn import functional as F
5
-
6
- import commons
7
- import logging
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
-
12
- class LayerNorm(nn.Module):
13
- def __init__(self, channels, eps=1e-5):
14
- super().__init__()
15
- self.channels = channels
16
- self.eps = eps
17
-
18
- self.gamma = nn.Parameter(torch.ones(channels))
19
- self.beta = nn.Parameter(torch.zeros(channels))
20
-
21
- def forward(self, x):
22
- x = x.transpose(1, -1)
23
- x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
24
- return x.transpose(1, -1)
25
-
26
-
27
- @torch.jit.script
28
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
29
- n_channels_int = n_channels[0]
30
- in_act = input_a + input_b
31
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
32
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
33
- acts = t_act * s_act
34
- return acts
35
-
36
-
37
- class Encoder(nn.Module):
38
- def __init__(
39
- self,
40
- hidden_channels,
41
- filter_channels,
42
- n_heads,
43
- n_layers,
44
- kernel_size=1,
45
- p_dropout=0.0,
46
- window_size=4,
47
- isflow=True,
48
- **kwargs
49
- ):
50
- super().__init__()
51
- self.hidden_channels = hidden_channels
52
- self.filter_channels = filter_channels
53
- self.n_heads = n_heads
54
- self.n_layers = n_layers
55
- self.kernel_size = kernel_size
56
- self.p_dropout = p_dropout
57
- self.window_size = window_size
58
- # if isflow:
59
- # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
60
- # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
61
- # self.cond_layer = weight_norm(cond_layer, name='weight')
62
- # self.gin_channels = 256
63
- self.cond_layer_idx = self.n_layers
64
- if "gin_channels" in kwargs:
65
- self.gin_channels = kwargs["gin_channels"]
66
- if self.gin_channels != 0:
67
- self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
68
- # vits2 says 3rd block, so idx is 2 by default
69
- self.cond_layer_idx = (
70
- kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
71
- )
72
- logging.debug(self.gin_channels, self.cond_layer_idx)
73
- assert (
74
- self.cond_layer_idx < self.n_layers
75
- ), "cond_layer_idx should be less than n_layers"
76
- self.drop = nn.Dropout(p_dropout)
77
- self.attn_layers = nn.ModuleList()
78
- self.norm_layers_1 = nn.ModuleList()
79
- self.ffn_layers = nn.ModuleList()
80
- self.norm_layers_2 = nn.ModuleList()
81
- for i in range(self.n_layers):
82
- self.attn_layers.append(
83
- MultiHeadAttention(
84
- hidden_channels,
85
- hidden_channels,
86
- n_heads,
87
- p_dropout=p_dropout,
88
- window_size=window_size,
89
- )
90
- )
91
- self.norm_layers_1.append(LayerNorm(hidden_channels))
92
- self.ffn_layers.append(
93
- FFN(
94
- hidden_channels,
95
- hidden_channels,
96
- filter_channels,
97
- kernel_size,
98
- p_dropout=p_dropout,
99
- )
100
- )
101
- self.norm_layers_2.append(LayerNorm(hidden_channels))
102
-
103
- def forward(self, x, x_mask, g=None):
104
- attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
105
- x = x * x_mask
106
- for i in range(self.n_layers):
107
- if i == self.cond_layer_idx and g is not None:
108
- g = self.spk_emb_linear(g.transpose(1, 2))
109
- g = g.transpose(1, 2)
110
- x = x + g
111
- x = x * x_mask
112
- y = self.attn_layers[i](x, x, attn_mask)
113
- y = self.drop(y)
114
- x = self.norm_layers_1[i](x + y)
115
-
116
- y = self.ffn_layers[i](x, x_mask)
117
- y = self.drop(y)
118
- x = self.norm_layers_2[i](x + y)
119
- x = x * x_mask
120
- return x
121
-
122
-
123
- class MultiHeadAttention(nn.Module):
124
- def __init__(
125
- self,
126
- channels,
127
- out_channels,
128
- n_heads,
129
- p_dropout=0.0,
130
- window_size=None,
131
- heads_share=True,
132
- block_length=None,
133
- proximal_bias=False,
134
- proximal_init=False,
135
- ):
136
- super().__init__()
137
- assert channels % n_heads == 0
138
-
139
- self.channels = channels
140
- self.out_channels = out_channels
141
- self.n_heads = n_heads
142
- self.p_dropout = p_dropout
143
- self.window_size = window_size
144
- self.heads_share = heads_share
145
- self.block_length = block_length
146
- self.proximal_bias = proximal_bias
147
- self.proximal_init = proximal_init
148
- self.attn = None
149
-
150
- self.k_channels = channels // n_heads
151
- self.conv_q = nn.Conv1d(channels, channels, 1)
152
- self.conv_k = nn.Conv1d(channels, channels, 1)
153
- self.conv_v = nn.Conv1d(channels, channels, 1)
154
- self.conv_o = nn.Conv1d(channels, out_channels, 1)
155
- self.drop = nn.Dropout(p_dropout)
156
-
157
- if window_size is not None:
158
- n_heads_rel = 1 if heads_share else n_heads
159
- rel_stddev = self.k_channels**-0.5
160
- self.emb_rel_k = nn.Parameter(
161
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
162
- * rel_stddev
163
- )
164
- self.emb_rel_v = nn.Parameter(
165
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
166
- * rel_stddev
167
- )
168
-
169
- nn.init.xavier_uniform_(self.conv_q.weight)
170
- nn.init.xavier_uniform_(self.conv_k.weight)
171
- nn.init.xavier_uniform_(self.conv_v.weight)
172
- if proximal_init:
173
- with torch.no_grad():
174
- self.conv_k.weight.copy_(self.conv_q.weight)
175
- self.conv_k.bias.copy_(self.conv_q.bias)
176
-
177
- def forward(self, x, c, attn_mask=None):
178
- q = self.conv_q(x)
179
- k = self.conv_k(c)
180
- v = self.conv_v(c)
181
-
182
- x, self.attn = self.attention(q, k, v, mask=attn_mask)
183
-
184
- x = self.conv_o(x)
185
- return x
186
-
187
- def attention(self, query, key, value, mask=None):
188
- # reshape [b, d, t] -> [b, n_h, t, d_k]
189
- b, d, t_s, t_t = (*key.size(), query.size(2))
190
- query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
191
- key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
192
- value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
193
-
194
- scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
195
- if self.window_size is not None:
196
- assert (
197
- t_s == t_t
198
- ), "Relative attention is only available for self-attention."
199
- key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
200
- rel_logits = self._matmul_with_relative_keys(
201
- query / math.sqrt(self.k_channels), key_relative_embeddings
202
- )
203
- scores_local = self._relative_position_to_absolute_position(rel_logits)
204
- scores = scores + scores_local
205
- if self.proximal_bias:
206
- assert t_s == t_t, "Proximal bias is only available for self-attention."
207
- scores = scores + self._attention_bias_proximal(t_s).to(
208
- device=scores.device, dtype=scores.dtype
209
- )
210
- if mask is not None:
211
- scores = scores.masked_fill(mask == 0, -1e4)
212
- if self.block_length is not None:
213
- assert (
214
- t_s == t_t
215
- ), "Local attention is only available for self-attention."
216
- block_mask = (
217
- torch.ones_like(scores)
218
- .triu(-self.block_length)
219
- .tril(self.block_length)
220
- )
221
- scores = scores.masked_fill(block_mask == 0, -1e4)
222
- p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
223
- p_attn = self.drop(p_attn)
224
- output = torch.matmul(p_attn, value)
225
- if self.window_size is not None:
226
- relative_weights = self._absolute_position_to_relative_position(p_attn)
227
- value_relative_embeddings = self._get_relative_embeddings(
228
- self.emb_rel_v, t_s
229
- )
230
- output = output + self._matmul_with_relative_values(
231
- relative_weights, value_relative_embeddings
232
- )
233
- output = (
234
- output.transpose(2, 3).contiguous().view(b, d, t_t)
235
- ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
236
- return output, p_attn
237
-
238
- def _matmul_with_relative_values(self, x, y):
239
- """
240
- x: [b, h, l, m]
241
- y: [h or 1, m, d]
242
- ret: [b, h, l, d]
243
- """
244
- ret = torch.matmul(x, y.unsqueeze(0))
245
- return ret
246
-
247
- def _matmul_with_relative_keys(self, x, y):
248
- """
249
- x: [b, h, l, d]
250
- y: [h or 1, m, d]
251
- ret: [b, h, l, m]
252
- """
253
- ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
254
- return ret
255
-
256
- def _get_relative_embeddings(self, relative_embeddings, length):
257
- max_relative_position = 2 * self.window_size + 1
258
- # Pad first before slice to avoid using cond ops.
259
- pad_length = max(length - (self.window_size + 1), 0)
260
- slice_start_position = max((self.window_size + 1) - length, 0)
261
- slice_end_position = slice_start_position + 2 * length - 1
262
- if pad_length > 0:
263
- padded_relative_embeddings = F.pad(
264
- relative_embeddings,
265
- commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
266
- )
267
- else:
268
- padded_relative_embeddings = relative_embeddings
269
- used_relative_embeddings = padded_relative_embeddings[
270
- :, slice_start_position:slice_end_position
271
- ]
272
- return used_relative_embeddings
273
-
274
- def _relative_position_to_absolute_position(self, x):
275
- """
276
- x: [b, h, l, 2*l-1]
277
- ret: [b, h, l, l]
278
- """
279
- batch, heads, length, _ = x.size()
280
- # Concat columns of pad to shift from relative to absolute indexing.
281
- x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
282
-
283
- # Concat extra elements so to add up to shape (len+1, 2*len-1).
284
- x_flat = x.view([batch, heads, length * 2 * length])
285
- x_flat = F.pad(
286
- x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
287
- )
288
-
289
- # Reshape and slice out the padded elements.
290
- x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
291
- :, :, :length, length - 1 :
292
- ]
293
- return x_final
294
-
295
- def _absolute_position_to_relative_position(self, x):
296
- """
297
- x: [b, h, l, l]
298
- ret: [b, h, l, 2*l-1]
299
- """
300
- batch, heads, length, _ = x.size()
301
- # padd along column
302
- x = F.pad(
303
- x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
304
- )
305
- x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
306
- # add 0's in the beginning that will skew the elements after reshape
307
- x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
308
- x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
309
- return x_final
310
-
311
- def _attention_bias_proximal(self, length):
312
- """Bias for self-attention to encourage attention to close positions.
313
- Args:
314
- length: an integer scalar.
315
- Returns:
316
- a Tensor with shape [1, 1, length, length]
317
- """
318
- r = torch.arange(length, dtype=torch.float32)
319
- diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
320
- return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
321
-
322
-
323
- class FFN(nn.Module):
324
- def __init__(
325
- self,
326
- in_channels,
327
- out_channels,
328
- filter_channels,
329
- kernel_size,
330
- p_dropout=0.0,
331
- activation=None,
332
- causal=False,
333
- ):
334
- super().__init__()
335
- self.in_channels = in_channels
336
- self.out_channels = out_channels
337
- self.filter_channels = filter_channels
338
- self.kernel_size = kernel_size
339
- self.p_dropout = p_dropout
340
- self.activation = activation
341
- self.causal = causal
342
-
343
- if causal:
344
- self.padding = self._causal_padding
345
- else:
346
- self.padding = self._same_padding
347
-
348
- self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
349
- self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
350
- self.drop = nn.Dropout(p_dropout)
351
-
352
- def forward(self, x, x_mask):
353
- x = self.conv_1(self.padding(x * x_mask))
354
- if self.activation == "gelu":
355
- x = x * torch.sigmoid(1.702 * x)
356
- else:
357
- x = torch.relu(x)
358
- x = self.drop(x)
359
- x = self.conv_2(self.padding(x * x_mask))
360
- return x * x_mask
361
-
362
- def _causal_padding(self, x):
363
- if self.kernel_size == 1:
364
- return x
365
- pad_l = self.kernel_size - 1
366
- pad_r = 0
367
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
368
- x = F.pad(x, commons.convert_pad_shape(padding))
369
- return x
370
-
371
- def _same_padding(self, x):
372
- if self.kernel_size == 1:
373
- return x
374
- pad_l = (self.kernel_size - 1) // 2
375
- pad_r = self.kernel_size // 2
376
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
377
- x = F.pad(x, commons.convert_pad_shape(padding))
378
- return x
 
 
onnx_modules/V220/models_onnx.py DELETED
@@ -1,1076 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn
4
- from torch.nn import functional as F
5
-
6
- import commons
7
- import modules
8
- from . import attentions_onnx
9
- from vector_quantize_pytorch import VectorQuantize
10
-
11
- from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
- from commons import init_weights, get_padding
14
- from .text import symbols, num_tones, num_languages
15
-
16
-
17
- class DurationDiscriminator(nn.Module): # vits2
18
- def __init__(
19
- self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
20
- ):
21
- super().__init__()
22
-
23
- self.in_channels = in_channels
24
- self.filter_channels = filter_channels
25
- self.kernel_size = kernel_size
26
- self.p_dropout = p_dropout
27
- self.gin_channels = gin_channels
28
-
29
- self.drop = nn.Dropout(p_dropout)
30
- self.conv_1 = nn.Conv1d(
31
- in_channels, filter_channels, kernel_size, padding=kernel_size // 2
32
- )
33
- self.norm_1 = modules.LayerNorm(filter_channels)
34
- self.conv_2 = nn.Conv1d(
35
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
36
- )
37
- self.norm_2 = modules.LayerNorm(filter_channels)
38
- self.dur_proj = nn.Conv1d(1, filter_channels, 1)
39
-
40
- self.pre_out_conv_1 = nn.Conv1d(
41
- 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
42
- )
43
- self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
44
- self.pre_out_conv_2 = nn.Conv1d(
45
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
46
- )
47
- self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
48
-
49
- if gin_channels != 0:
50
- self.cond = nn.Conv1d(gin_channels, in_channels, 1)
51
-
52
- self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
53
-
54
- def forward_probability(self, x, x_mask, dur, g=None):
55
- dur = self.dur_proj(dur)
56
- x = torch.cat([x, dur], dim=1)
57
- x = self.pre_out_conv_1(x * x_mask)
58
- x = torch.relu(x)
59
- x = self.pre_out_norm_1(x)
60
- x = self.drop(x)
61
- x = self.pre_out_conv_2(x * x_mask)
62
- x = torch.relu(x)
63
- x = self.pre_out_norm_2(x)
64
- x = self.drop(x)
65
- x = x * x_mask
66
- x = x.transpose(1, 2)
67
- output_prob = self.output_layer(x)
68
- return output_prob
69
-
70
- def forward(self, x, x_mask, dur_r, dur_hat, g=None):
71
- x = torch.detach(x)
72
- if g is not None:
73
- g = torch.detach(g)
74
- x = x + self.cond(g)
75
- x = self.conv_1(x * x_mask)
76
- x = torch.relu(x)
77
- x = self.norm_1(x)
78
- x = self.drop(x)
79
- x = self.conv_2(x * x_mask)
80
- x = torch.relu(x)
81
- x = self.norm_2(x)
82
- x = self.drop(x)
83
-
84
- output_probs = []
85
- for dur in [dur_r, dur_hat]:
86
- output_prob = self.forward_probability(x, x_mask, dur, g)
87
- output_probs.append(output_prob)
88
-
89
- return output_probs
90
-
91
-
92
- class TransformerCouplingBlock(nn.Module):
93
- def __init__(
94
- self,
95
- channels,
96
- hidden_channels,
97
- filter_channels,
98
- n_heads,
99
- n_layers,
100
- kernel_size,
101
- p_dropout,
102
- n_flows=4,
103
- gin_channels=0,
104
- share_parameter=False,
105
- ):
106
- super().__init__()
107
- self.channels = channels
108
- self.hidden_channels = hidden_channels
109
- self.kernel_size = kernel_size
110
- self.n_layers = n_layers
111
- self.n_flows = n_flows
112
- self.gin_channels = gin_channels
113
-
114
- self.flows = nn.ModuleList()
115
-
116
- self.wn = (
117
- attentions_onnx.FFT(
118
- hidden_channels,
119
- filter_channels,
120
- n_heads,
121
- n_layers,
122
- kernel_size,
123
- p_dropout,
124
- isflow=True,
125
- gin_channels=self.gin_channels,
126
- )
127
- if share_parameter
128
- else None
129
- )
130
-
131
- for i in range(n_flows):
132
- self.flows.append(
133
- modules.TransformerCouplingLayer(
134
- channels,
135
- hidden_channels,
136
- kernel_size,
137
- n_layers,
138
- n_heads,
139
- p_dropout,
140
- filter_channels,
141
- mean_only=True,
142
- wn_sharing_parameter=self.wn,
143
- gin_channels=self.gin_channels,
144
- )
145
- )
146
- self.flows.append(modules.Flip())
147
-
148
- def forward(self, x, x_mask, g=None, reverse=True):
149
- if not reverse:
150
- for flow in self.flows:
151
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
152
- else:
153
- for flow in reversed(self.flows):
154
- x = flow(x, x_mask, g=g, reverse=reverse)
155
- return x
156
-
157
-
158
- class StochasticDurationPredictor(nn.Module):
159
- def __init__(
160
- self,
161
- in_channels,
162
- filter_channels,
163
- kernel_size,
164
- p_dropout,
165
- n_flows=4,
166
- gin_channels=0,
167
- ):
168
- super().__init__()
169
- filter_channels = in_channels # it needs to be removed from future version.
170
- self.in_channels = in_channels
171
- self.filter_channels = filter_channels
172
- self.kernel_size = kernel_size
173
- self.p_dropout = p_dropout
174
- self.n_flows = n_flows
175
- self.gin_channels = gin_channels
176
-
177
- self.log_flow = modules.Log()
178
- self.flows = nn.ModuleList()
179
- self.flows.append(modules.ElementwiseAffine(2))
180
- for i in range(n_flows):
181
- self.flows.append(
182
- modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
183
- )
184
- self.flows.append(modules.Flip())
185
-
186
- self.post_pre = nn.Conv1d(1, filter_channels, 1)
187
- self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
188
- self.post_convs = modules.DDSConv(
189
- filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
190
- )
191
- self.post_flows = nn.ModuleList()
192
- self.post_flows.append(modules.ElementwiseAffine(2))
193
- for i in range(4):
194
- self.post_flows.append(
195
- modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
196
- )
197
- self.post_flows.append(modules.Flip())
198
-
199
- self.pre = nn.Conv1d(in_channels, filter_channels, 1)
200
- self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
201
- self.convs = modules.DDSConv(
202
- filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
203
- )
204
- if gin_channels != 0:
205
- self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
206
-
207
- def forward(self, x, x_mask, z, g=None):
208
- x = torch.detach(x)
209
- x = self.pre(x)
210
- if g is not None:
211
- g = torch.detach(g)
212
- x = x + self.cond(g)
213
- x = self.convs(x, x_mask)
214
- x = self.proj(x) * x_mask
215
-
216
- flows = list(reversed(self.flows))
217
- flows = flows[:-2] + [flows[-1]] # remove a useless vflow
218
- for flow in flows:
219
- z = flow(z, x_mask, g=x, reverse=True)
220
- z0, z1 = torch.split(z, [1, 1], 1)
221
- logw = z0
222
- return logw
223
-
224
-
225
- class DurationPredictor(nn.Module):
226
- def __init__(
227
- self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
228
- ):
229
- super().__init__()
230
-
231
- self.in_channels = in_channels
232
- self.filter_channels = filter_channels
233
- self.kernel_size = kernel_size
234
- self.p_dropout = p_dropout
235
- self.gin_channels = gin_channels
236
-
237
- self.drop = nn.Dropout(p_dropout)
238
- self.conv_1 = nn.Conv1d(
239
- in_channels, filter_channels, kernel_size, padding=kernel_size // 2
240
- )
241
- self.norm_1 = modules.LayerNorm(filter_channels)
242
- self.conv_2 = nn.Conv1d(
243
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
244
- )
245
- self.norm_2 = modules.LayerNorm(filter_channels)
246
- self.proj = nn.Conv1d(filter_channels, 1, 1)
247
-
248
- if gin_channels != 0:
249
- self.cond = nn.Conv1d(gin_channels, in_channels, 1)
250
-
251
- def forward(self, x, x_mask, g=None):
252
- x = torch.detach(x)
253
- if g is not None:
254
- g = torch.detach(g)
255
- x = x + self.cond(g)
256
- x = self.conv_1(x * x_mask)
257
- x = torch.relu(x)
258
- x = self.norm_1(x)
259
- x = self.drop(x)
260
- x = self.conv_2(x * x_mask)
261
- x = torch.relu(x)
262
- x = self.norm_2(x)
263
- x = self.drop(x)
264
- x = self.proj(x * x_mask)
265
- return x * x_mask
266
-
267
-
268
- class Bottleneck(nn.Sequential):
269
- def __init__(self, in_dim, hidden_dim):
270
- c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
271
- c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
272
- super().__init__(*[c_fc1, c_fc2])
273
-
274
-
275
- class Block(nn.Module):
276
- def __init__(self, in_dim, hidden_dim) -> None:
277
- super().__init__()
278
- self.norm = nn.LayerNorm(in_dim)
279
- self.mlp = MLP(in_dim, hidden_dim)
280
-
281
- def forward(self, x: torch.Tensor) -> torch.Tensor:
282
- x = x + self.mlp(self.norm(x))
283
- return x
284
-
285
-
286
- class MLP(nn.Module):
287
- def __init__(self, in_dim, hidden_dim):
288
- super().__init__()
289
- self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
290
- self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
291
- self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
292
-
293
- def forward(self, x: torch.Tensor):
294
- x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
295
- x = self.c_proj(x)
296
- return x
297
-
298
-
299
- class TextEncoder(nn.Module):
300
- def __init__(
301
- self,
302
- n_vocab,
303
- out_channels,
304
- hidden_channels,
305
- filter_channels,
306
- n_heads,
307
- n_layers,
308
- kernel_size,
309
- p_dropout,
310
- n_speakers,
311
- gin_channels=0,
312
- ):
313
- super().__init__()
314
- self.n_vocab = n_vocab
315
- self.out_channels = out_channels
316
- self.hidden_channels = hidden_channels
317
- self.filter_channels = filter_channels
318
- self.n_heads = n_heads
319
- self.n_layers = n_layers
320
- self.kernel_size = kernel_size
321
- self.p_dropout = p_dropout
322
- self.gin_channels = gin_channels
323
- self.emb = nn.Embedding(len(symbols), hidden_channels)
324
- nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
325
- self.tone_emb = nn.Embedding(num_tones, hidden_channels)
326
- nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
327
- self.language_emb = nn.Embedding(num_languages, hidden_channels)
328
- nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
329
- self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
330
- self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
331
- self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
332
- # self.emo_proj = nn.Linear(1024, 1024)
333
- # self.emo_quantizer = nn.ModuleList()
334
- # for i in range(0, n_speakers):
335
- # self.emo_quantizer.append(
336
- # VectorQuantize(
337
- # dim=1024,
338
- # codebook_size=10,
339
- # decay=0.8,
340
- # commitment_weight=1.0,
341
- # learnable_codebook=True,
342
- # ema_update=False,
343
- # )
344
- # )
345
- # self.emo_q_proj = nn.Linear(1024, hidden_channels)
346
- self.n_speakers = n_speakers
347
- self.in_feature_net = nn.Sequential(
348
- # input is assumed to an already normalized embedding
349
- nn.Linear(512, 1028, bias=False),
350
- nn.GELU(),
351
- nn.LayerNorm(1028),
352
- *[Block(1028, 512) for _ in range(1)],
353
- nn.Linear(1028, 512, bias=False),
354
- # normalize before passing to VQ?
355
- # nn.GELU(),
356
- # nn.LayerNorm(512),
357
- )
358
- self.emo_vq = VectorQuantize(
359
- dim=512,
360
- codebook_size=64,
361
- codebook_dim=32,
362
- commitment_weight=0.1,
363
- decay=0.85,
364
- heads=32,
365
- kmeans_iters=20,
366
- separate_codebook_per_head=True,
367
- stochastic_sample_codes=True,
368
- threshold_ema_dead_code=2,
369
- )
370
- self.out_feature_net = nn.Linear(512, hidden_channels)
371
-
372
- self.encoder = attentions_onnx.Encoder(
373
- hidden_channels,
374
- filter_channels,
375
- n_heads,
376
- n_layers,
377
- kernel_size,
378
- p_dropout,
379
- gin_channels=self.gin_channels,
380
- )
381
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
382
-
383
- def forward(
384
- self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g=None
385
- ):
386
- x_mask = torch.ones_like(x).unsqueeze(0)
387
- bert_emb = self.bert_proj(bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
388
- ja_bert_emb = self.ja_bert_proj(ja_bert.transpose(0, 1).unsqueeze(0)).transpose(
389
- 1, 2
390
- )
391
- en_bert_emb = self.en_bert_proj(en_bert.transpose(0, 1).unsqueeze(0)).transpose(
392
- 1, 2
393
- )
394
- emo_emb = self.in_feature_net(emo.transpose(0, 1))
395
- emo_emb, _, _ = self.emo_vq(emo_emb.unsqueeze(1))
396
-
397
- emo_emb = self.out_feature_net(emo_emb)
398
-
399
- x = (
400
- self.emb(x)
401
- + self.tone_emb(tone)
402
- + self.language_emb(language)
403
- + bert_emb
404
- + ja_bert_emb
405
- + en_bert_emb
406
- + emo_emb
407
- ) * math.sqrt(
408
- self.hidden_channels
409
- ) # [b, t, h]
410
- x = torch.transpose(x, 1, -1) # [b, h, t]
411
- x_mask = x_mask.to(x.dtype)
412
-
413
- x = self.encoder(x * x_mask, x_mask, g=g)
414
- stats = self.proj(x) * x_mask
415
-
416
- m, logs = torch.split(stats, self.out_channels, dim=1)
417
- return x, m, logs, x_mask
418
-
419
-
420
- class ResidualCouplingBlock(nn.Module):
421
- def __init__(
422
- self,
423
- channels,
424
- hidden_channels,
425
- kernel_size,
426
- dilation_rate,
427
- n_layers,
428
- n_flows=4,
429
- gin_channels=0,
430
- ):
431
- super().__init__()
432
- self.channels = channels
433
- self.hidden_channels = hidden_channels
434
- self.kernel_size = kernel_size
435
- self.dilation_rate = dilation_rate
436
- self.n_layers = n_layers
437
- self.n_flows = n_flows
438
- self.gin_channels = gin_channels
439
-
440
- self.flows = nn.ModuleList()
441
- for i in range(n_flows):
442
- self.flows.append(
443
- modules.ResidualCouplingLayer(
444
- channels,
445
- hidden_channels,
446
- kernel_size,
447
- dilation_rate,
448
- n_layers,
449
- gin_channels=gin_channels,
450
- mean_only=True,
451
- )
452
- )
453
- self.flows.append(modules.Flip())
454
-
455
- def forward(self, x, x_mask, g=None, reverse=True):
456
- if not reverse:
457
- for flow in self.flows:
458
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
459
- else:
460
- for flow in reversed(self.flows):
461
- x = flow(x, x_mask, g=g, reverse=reverse)
462
- return x
463
-
464
-
465
- class PosteriorEncoder(nn.Module):
466
- def __init__(
467
- self,
468
- in_channels,
469
- out_channels,
470
- hidden_channels,
471
- kernel_size,
472
- dilation_rate,
473
- n_layers,
474
- gin_channels=0,
475
- ):
476
- super().__init__()
477
- self.in_channels = in_channels
478
- self.out_channels = out_channels
479
- self.hidden_channels = hidden_channels
480
- self.kernel_size = kernel_size
481
- self.dilation_rate = dilation_rate
482
- self.n_layers = n_layers
483
- self.gin_channels = gin_channels
484
-
485
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
486
- self.enc = modules.WN(
487
- hidden_channels,
488
- kernel_size,
489
- dilation_rate,
490
- n_layers,
491
- gin_channels=gin_channels,
492
- )
493
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
494
-
495
- def forward(self, x, x_lengths, g=None):
496
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
497
- x.dtype
498
- )
499
- x = self.pre(x) * x_mask
500
- x = self.enc(x, x_mask, g=g)
501
- stats = self.proj(x) * x_mask
502
- m, logs = torch.split(stats, self.out_channels, dim=1)
503
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
504
- return z, m, logs, x_mask
505
-
506
-
507
- class Generator(torch.nn.Module):
508
- def __init__(
509
- self,
510
- initial_channel,
511
- resblock,
512
- resblock_kernel_sizes,
513
- resblock_dilation_sizes,
514
- upsample_rates,
515
- upsample_initial_channel,
516
- upsample_kernel_sizes,
517
- gin_channels=0,
518
- ):
519
- super(Generator, self).__init__()
520
- self.num_kernels = len(resblock_kernel_sizes)
521
- self.num_upsamples = len(upsample_rates)
522
- self.conv_pre = Conv1d(
523
- initial_channel, upsample_initial_channel, 7, 1, padding=3
524
- )
525
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
526
-
527
- self.ups = nn.ModuleList()
528
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
529
- self.ups.append(
530
- weight_norm(
531
- ConvTranspose1d(
532
- upsample_initial_channel // (2**i),
533
- upsample_initial_channel // (2 ** (i + 1)),
534
- k,
535
- u,
536
- padding=(k - u) // 2,
537
- )
538
- )
539
- )
540
-
541
- self.resblocks = nn.ModuleList()
542
- for i in range(len(self.ups)):
543
- ch = upsample_initial_channel // (2 ** (i + 1))
544
- for j, (k, d) in enumerate(
545
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
546
- ):
547
- self.resblocks.append(resblock(ch, k, d))
548
-
549
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
550
- self.ups.apply(init_weights)
551
-
552
- if gin_channels != 0:
553
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
554
-
555
- def forward(self, x, g=None):
556
- x = self.conv_pre(x)
557
- if g is not None:
558
- x = x + self.cond(g)
559
-
560
- for i in range(self.num_upsamples):
561
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
562
- x = self.ups[i](x)
563
- xs = None
564
- for j in range(self.num_kernels):
565
- if xs is None:
566
- xs = self.resblocks[i * self.num_kernels + j](x)
567
- else:
568
- xs += self.resblocks[i * self.num_kernels + j](x)
569
- x = xs / self.num_kernels
570
- x = F.leaky_relu(x)
571
- x = self.conv_post(x)
572
- x = torch.tanh(x)
573
-
574
- return x
575
-
576
- def remove_weight_norm(self):
577
- print("Removing weight norm...")
578
- for layer in self.ups:
579
- remove_weight_norm(layer)
580
- for layer in self.resblocks:
581
- layer.remove_weight_norm()
582
-
583
-
584
- class DiscriminatorP(torch.nn.Module):
585
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
586
- super(DiscriminatorP, self).__init__()
587
- self.period = period
588
- self.use_spectral_norm = use_spectral_norm
589
- norm_f = weight_norm if use_spectral_norm is False else spectral_norm
590
- self.convs = nn.ModuleList(
591
- [
592
- norm_f(
593
- Conv2d(
594
- 1,
595
- 32,
596
- (kernel_size, 1),
597
- (stride, 1),
598
- padding=(get_padding(kernel_size, 1), 0),
599
- )
600
- ),
601
- norm_f(
602
- Conv2d(
603
- 32,
604
- 128,
605
- (kernel_size, 1),
606
- (stride, 1),
607
- padding=(get_padding(kernel_size, 1), 0),
608
- )
609
- ),
610
- norm_f(
611
- Conv2d(
612
- 128,
613
- 512,
614
- (kernel_size, 1),
615
- (stride, 1),
616
- padding=(get_padding(kernel_size, 1), 0),
617
- )
618
- ),
619
- norm_f(
620
- Conv2d(
621
- 512,
622
- 1024,
623
- (kernel_size, 1),
624
- (stride, 1),
625
- padding=(get_padding(kernel_size, 1), 0),
626
- )
627
- ),
628
- norm_f(
629
- Conv2d(
630
- 1024,
631
- 1024,
632
- (kernel_size, 1),
633
- 1,
634
- padding=(get_padding(kernel_size, 1), 0),
635
- )
636
- ),
637
- ]
638
- )
639
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
640
-
641
- def forward(self, x):
642
- fmap = []
643
-
644
- # 1d to 2d
645
- b, c, t = x.shape
646
- if t % self.period != 0: # pad first
647
- n_pad = self.period - (t % self.period)
648
- x = F.pad(x, (0, n_pad), "reflect")
649
- t = t + n_pad
650
- x = x.view(b, c, t // self.period, self.period)
651
-
652
- for layer in self.convs:
653
- x = layer(x)
654
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
655
- fmap.append(x)
656
- x = self.conv_post(x)
657
- fmap.append(x)
658
- x = torch.flatten(x, 1, -1)
659
-
660
- return x, fmap
661
-
662
-
663
- class DiscriminatorS(torch.nn.Module):
664
- def __init__(self, use_spectral_norm=False):
665
- super(DiscriminatorS, self).__init__()
666
- norm_f = weight_norm if use_spectral_norm is False else spectral_norm
667
- self.convs = nn.ModuleList(
668
- [
669
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
670
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
671
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
672
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
673
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
674
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
675
- ]
676
- )
677
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
678
-
679
- def forward(self, x):
680
- fmap = []
681
-
682
- for layer in self.convs:
683
- x = layer(x)
684
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
685
- fmap.append(x)
686
- x = self.conv_post(x)
687
- fmap.append(x)
688
- x = torch.flatten(x, 1, -1)
689
-
690
- return x, fmap
691
-
692
-
693
- class MultiPeriodDiscriminator(torch.nn.Module):
694
- def __init__(self, use_spectral_norm=False):
695
- super(MultiPeriodDiscriminator, self).__init__()
696
- periods = [2, 3, 5, 7, 11]
697
-
698
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
699
- discs = discs + [
700
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
701
- ]
702
- self.discriminators = nn.ModuleList(discs)
703
-
704
- def forward(self, y, y_hat):
705
- y_d_rs = []
706
- y_d_gs = []
707
- fmap_rs = []
708
- fmap_gs = []
709
- for i, d in enumerate(self.discriminators):
710
- y_d_r, fmap_r = d(y)
711
- y_d_g, fmap_g = d(y_hat)
712
- y_d_rs.append(y_d_r)
713
- y_d_gs.append(y_d_g)
714
- fmap_rs.append(fmap_r)
715
- fmap_gs.append(fmap_g)
716
-
717
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
718
-
719
-
720
- class ReferenceEncoder(nn.Module):
721
- """
722
- inputs --- [N, Ty/r, n_mels*r] mels
723
- outputs --- [N, ref_enc_gru_size]
724
- """
725
-
726
- def __init__(self, spec_channels, gin_channels=0):
727
- super().__init__()
728
- self.spec_channels = spec_channels
729
- ref_enc_filters = [32, 32, 64, 64, 128, 128]
730
- K = len(ref_enc_filters)
731
- filters = [1] + ref_enc_filters
732
- convs = [
733
- weight_norm(
734
- nn.Conv2d(
735
- in_channels=filters[i],
736
- out_channels=filters[i + 1],
737
- kernel_size=(3, 3),
738
- stride=(2, 2),
739
- padding=(1, 1),
740
- )
741
- )
742
- for i in range(K)
743
- ]
744
- self.convs = nn.ModuleList(convs)
745
- # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
746
-
747
- out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
748
- self.gru = nn.GRU(
749
- input_size=ref_enc_filters[-1] * out_channels,
750
- hidden_size=256 // 2,
751
- batch_first=True,
752
- )
753
- self.proj = nn.Linear(128, gin_channels)
754
-
755
- def forward(self, inputs, mask=None):
756
- N = inputs.size(0)
757
- out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
758
- for conv in self.convs:
759
- out = conv(out)
760
- # out = wn(out)
761
- out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
762
-
763
- out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
764
- T = out.size(1)
765
- N = out.size(0)
766
- out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
767
-
768
- self.gru.flatten_parameters()
769
- memory, out = self.gru(out) # out --- [1, N, 128]
770
-
771
- return self.proj(out.squeeze(0))
772
-
773
- def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
774
- for i in range(n_convs):
775
- L = (L - kernel_size + 2 * pad) // stride + 1
776
- return L
777
-
778
-
779
- class SynthesizerTrn(nn.Module):
780
- """
781
- Synthesizer for Training
782
- """
783
-
784
- def __init__(
785
- self,
786
- n_vocab,
787
- spec_channels,
788
- segment_size,
789
- inter_channels,
790
- hidden_channels,
791
- filter_channels,
792
- n_heads,
793
- n_layers,
794
- kernel_size,
795
- p_dropout,
796
- resblock,
797
- resblock_kernel_sizes,
798
- resblock_dilation_sizes,
799
- upsample_rates,
800
- upsample_initial_channel,
801
- upsample_kernel_sizes,
802
- n_speakers=256,
803
- gin_channels=256,
804
- use_sdp=True,
805
- n_flow_layer=4,
806
- n_layers_trans_flow=4,
807
- flow_share_parameter=False,
808
- use_transformer_flow=True,
809
- **kwargs,
810
- ):
811
- super().__init__()
812
- self.n_vocab = n_vocab
813
- self.spec_channels = spec_channels
814
- self.inter_channels = inter_channels
815
- self.hidden_channels = hidden_channels
816
- self.filter_channels = filter_channels
817
- self.n_heads = n_heads
818
- self.n_layers = n_layers
819
- self.kernel_size = kernel_size
820
- self.p_dropout = p_dropout
821
- self.resblock = resblock
822
- self.resblock_kernel_sizes = resblock_kernel_sizes
823
- self.resblock_dilation_sizes = resblock_dilation_sizes
824
- self.upsample_rates = upsample_rates
825
- self.upsample_initial_channel = upsample_initial_channel
826
- self.upsample_kernel_sizes = upsample_kernel_sizes
827
- self.segment_size = segment_size
828
- self.n_speakers = n_speakers
829
- self.gin_channels = gin_channels
830
- self.n_layers_trans_flow = n_layers_trans_flow
831
- self.use_spk_conditioned_encoder = kwargs.get(
832
- "use_spk_conditioned_encoder", True
833
- )
834
- self.use_sdp = use_sdp
835
- self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
836
- self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
837
- self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
838
- self.current_mas_noise_scale = self.mas_noise_scale_initial
839
- if self.use_spk_conditioned_encoder and gin_channels > 0:
840
- self.enc_gin_channels = gin_channels
841
- self.enc_p = TextEncoder(
842
- n_vocab,
843
- inter_channels,
844
- hidden_channels,
845
- filter_channels,
846
- n_heads,
847
- n_layers,
848
- kernel_size,
849
- p_dropout,
850
- self.n_speakers,
851
- gin_channels=self.enc_gin_channels,
852
- )
853
- self.dec = Generator(
854
- inter_channels,
855
- resblock,
856
- resblock_kernel_sizes,
857
- resblock_dilation_sizes,
858
- upsample_rates,
859
- upsample_initial_channel,
860
- upsample_kernel_sizes,
861
- gin_channels=gin_channels,
862
- )
863
- self.enc_q = PosteriorEncoder(
864
- spec_channels,
865
- inter_channels,
866
- hidden_channels,
867
- 5,
868
- 1,
869
- 16,
870
- gin_channels=gin_channels,
871
- )
872
- if use_transformer_flow:
873
- self.flow = TransformerCouplingBlock(
874
- inter_channels,
875
- hidden_channels,
876
- filter_channels,
877
- n_heads,
878
- n_layers_trans_flow,
879
- 5,
880
- p_dropout,
881
- n_flow_layer,
882
- gin_channels=gin_channels,
883
- share_parameter=flow_share_parameter,
884
- )
885
- else:
886
- self.flow = ResidualCouplingBlock(
887
- inter_channels,
888
- hidden_channels,
889
- 5,
890
- 1,
891
- n_flow_layer,
892
- gin_channels=gin_channels,
893
- )
894
- self.sdp = StochasticDurationPredictor(
895
- hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
896
- )
897
- self.dp = DurationPredictor(
898
- hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
899
- )
900
-
901
- if n_speakers >= 1:
902
- self.emb_g = nn.Embedding(n_speakers, gin_channels)
903
- else:
904
- self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
905
-
906
- def export_onnx(
907
- self,
908
- path,
909
- max_len=None,
910
- sdp_ratio=0,
911
- y=None,
912
- ):
913
- noise_scale = 0.667
914
- length_scale = 1
915
- noise_scale_w = 0.8
916
- x = (
917
- torch.LongTensor(
918
- [
919
- 0,
920
- 97,
921
- 0,
922
- 8,
923
- 0,
924
- 78,
925
- 0,
926
- 8,
927
- 0,
928
- 76,
929
- 0,
930
- 37,
931
- 0,
932
- 40,
933
- 0,
934
- 97,
935
- 0,
936
- 8,
937
- 0,
938
- 23,
939
- 0,
940
- 8,
941
- 0,
942
- 74,
943
- 0,
944
- 26,
945
- 0,
946
- 104,
947
- 0,
948
- ]
949
- )
950
- .unsqueeze(0)
951
- .cpu()
952
- )
953
- tone = torch.zeros_like(x).cpu()
954
- language = torch.zeros_like(x).cpu()
955
- x_lengths = torch.LongTensor([x.shape[1]]).cpu()
956
- sid = torch.LongTensor([0]).cpu()
957
- bert = torch.randn(size=(x.shape[1], 1024)).cpu()
958
- ja_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
959
- en_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
960
-
961
- if self.n_speakers > 0:
962
- g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
963
- torch.onnx.export(
964
- self.emb_g,
965
- (sid),
966
- f"onnx/{path}/{path}_emb.onnx",
967
- input_names=["sid"],
968
- output_names=["g"],
969
- verbose=True,
970
- )
971
- else:
972
- g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
973
-
974
- emo = torch.randn(512, 1)
975
-
976
- torch.onnx.export(
977
- self.enc_p,
978
- (x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g),
979
- f"onnx/{path}/{path}_enc_p.onnx",
980
- input_names=[
981
- "x",
982
- "x_lengths",
983
- "t",
984
- "language",
985
- "bert_0",
986
- "bert_1",
987
- "bert_2",
988
- "emo",
989
- "g",
990
- ],
991
- output_names=["xout", "m_p", "logs_p", "x_mask"],
992
- dynamic_axes={
993
- "x": [0, 1],
994
- "t": [0, 1],
995
- "language": [0, 1],
996
- "bert_0": [0],
997
- "bert_1": [0],
998
- "bert_2": [0],
999
- "xout": [0, 2],
1000
- "m_p": [0, 2],
1001
- "logs_p": [0, 2],
1002
- "x_mask": [0, 2],
1003
- },
1004
- verbose=True,
1005
- opset_version=16,
1006
- )
1007
-
1008
- x, m_p, logs_p, x_mask = self.enc_p(
1009
- x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g
1010
- )
1011
-
1012
- zinput = (
1013
- torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
1014
- * noise_scale_w
1015
- )
1016
- torch.onnx.export(
1017
- self.sdp,
1018
- (x, x_mask, zinput, g),
1019
- f"onnx/{path}/{path}_sdp.onnx",
1020
- input_names=["x", "x_mask", "zin", "g"],
1021
- output_names=["logw"],
1022
- dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "zin": [0, 2], "logw": [0, 2]},
1023
- verbose=True,
1024
- )
1025
- torch.onnx.export(
1026
- self.dp,
1027
- (x, x_mask, g),
1028
- f"onnx/{path}/{path}_dp.onnx",
1029
- input_names=["x", "x_mask", "g"],
1030
- output_names=["logw"],
1031
- dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "logw": [0, 2]},
1032
- verbose=True,
1033
- )
1034
- logw = self.sdp(x, x_mask, zinput, g=g) * (sdp_ratio) + self.dp(
1035
- x, x_mask, g=g
1036
- ) * (1 - sdp_ratio)
1037
- w = torch.exp(logw) * x_mask * length_scale
1038
- w_ceil = torch.ceil(w)
1039
- y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1040
- y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1041
- x_mask.dtype
1042
- )
1043
- attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1044
- attn = commons.generate_path(w_ceil, attn_mask)
1045
-
1046
- m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1047
- 1, 2
1048
- ) # [b, t', t], [b, t, d] -> [b, d, t']
1049
- logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1050
- 1, 2
1051
- ) # [b, t', t], [b, t, d] -> [b, d, t']
1052
-
1053
- z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1054
- torch.onnx.export(
1055
- self.flow,
1056
- (z_p, y_mask, g),
1057
- f"onnx/{path}/{path}_flow.onnx",
1058
- input_names=["z_p", "y_mask", "g"],
1059
- output_names=["z"],
1060
- dynamic_axes={"z_p": [0, 2], "y_mask": [0, 2], "z": [0, 2]},
1061
- verbose=True,
1062
- )
1063
-
1064
- z = self.flow(z_p, y_mask, g=g, reverse=True)
1065
- z_in = (z * y_mask)[:, :, :max_len]
1066
-
1067
- torch.onnx.export(
1068
- self.dec,
1069
- (z_in, g),
1070
- f"onnx/{path}/{path}_dec.onnx",
1071
- input_names=["z_in", "g"],
1072
- output_names=["o"],
1073
- dynamic_axes={"z_in": [0, 2], "o": [0, 2]},
1074
- verbose=True,
1075
- )
1076
- o = self.dec((z * y_mask)[:, :, :max_len], g=g)

onnx_modules/V220/text/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .symbols import *

onnx_modules/V220/text/symbols.py DELETED
@@ -1,187 +0,0 @@
1
- punctuation = ["!", "?", "…", ",", ".", "'", "-"]
2
- pu_symbols = punctuation + ["SP", "UNK"]
3
- pad = "_"
4
-
5
- # chinese
6
- zh_symbols = [
7
- "E",
8
- "En",
9
- "a",
10
- "ai",
11
- "an",
12
- "ang",
13
- "ao",
14
- "b",
15
- "c",
16
- "ch",
17
- "d",
18
- "e",
19
- "ei",
20
- "en",
21
- "eng",
22
- "er",
23
- "f",
24
- "g",
25
- "h",
26
- "i",
27
- "i0",
28
- "ia",
29
- "ian",
30
- "iang",
31
- "iao",
32
- "ie",
33
- "in",
34
- "ing",
35
- "iong",
36
- "ir",
37
- "iu",
38
- "j",
39
- "k",
40
- "l",
41
- "m",
42
- "n",
43
- "o",
44
- "ong",
45
- "ou",
46
- "p",
47
- "q",
48
- "r",
49
- "s",
50
- "sh",
51
- "t",
52
- "u",
53
- "ua",
54
- "uai",
55
- "uan",
56
- "uang",
57
- "ui",
58
- "un",
59
- "uo",
60
- "v",
61
- "van",
62
- "ve",
63
- "vn",
64
- "w",
65
- "x",
66
- "y",
67
- "z",
68
- "zh",
69
- "AA",
70
- "EE",
71
- "OO",
72
- ]
73
- num_zh_tones = 6
74
-
75
- # japanese
76
- ja_symbols = [
77
- "N",
78
- "a",
79
- "a:",
80
- "b",
81
- "by",
82
- "ch",
83
- "d",
84
- "dy",
85
- "e",
86
- "e:",
87
- "f",
88
- "g",
89
- "gy",
90
- "h",
91
- "hy",
92
- "i",
93
- "i:",
94
- "j",
95
- "k",
96
- "ky",
97
- "m",
98
- "my",
99
- "n",
100
- "ny",
101
- "o",
102
- "o:",
103
- "p",
104
- "py",
105
- "q",
106
- "r",
107
- "ry",
108
- "s",
109
- "sh",
110
- "t",
111
- "ts",
112
- "ty",
113
- "u",
114
- "u:",
115
- "w",
116
- "y",
117
- "z",
118
- "zy",
119
- ]
120
- num_ja_tones = 2
121
-
122
- # English
123
- en_symbols = [
124
- "aa",
125
- "ae",
126
- "ah",
127
- "ao",
128
- "aw",
129
- "ay",
130
- "b",
131
- "ch",
132
- "d",
133
- "dh",
134
- "eh",
135
- "er",
136
- "ey",
137
- "f",
138
- "g",
139
- "hh",
140
- "ih",
141
- "iy",
142
- "jh",
143
- "k",
144
- "l",
145
- "m",
146
- "n",
147
- "ng",
148
- "ow",
149
- "oy",
150
- "p",
151
- "r",
152
- "s",
153
- "sh",
154
- "t",
155
- "th",
156
- "uh",
157
- "uw",
158
- "V",
159
- "w",
160
- "y",
161
- "z",
162
- "zh",
163
- ]
164
- num_en_tones = 4
165
-
166
- # combine all symbols
167
- normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
168
- symbols = [pad] + normal_symbols + pu_symbols
169
- sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
170
-
171
- # combine all tones
172
- num_tones = num_zh_tones + num_ja_tones + num_en_tones
173
-
174
- # language maps
175
- language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
176
- num_languages = len(language_id_map.keys())
177
-
178
- language_tone_start_map = {
179
- "ZH": 0,
180
- "JP": num_zh_tones,
181
- "EN": num_zh_tones + num_ja_tones,
182
- }
183
-
184
- if __name__ == "__main__":
185
- a = set(zh_symbols)
186
- b = set(en_symbols)
187
- print(sorted(a & b))

onnx_modules/__init__.py DELETED
@@ -1,50 +0,0 @@
1
- from utils import get_hparams_from_file, load_checkpoint
2
- import json
3
-
4
-
5
- def export_onnx(export_path, model_path, config_path):
6
- hps = get_hparams_from_file(config_path)
7
- version = hps.version[0:3]
8
- if version == "2.0":
9
- from .V200 import SynthesizerTrn, symbols
10
- elif version == "2.1":
11
- from .V210 import SynthesizerTrn, symbols
12
- elif version == "2.2":
13
- from .V220 import SynthesizerTrn, symbols
14
- net_g = SynthesizerTrn(
15
- len(symbols),
16
- hps.data.filter_length // 2 + 1,
17
- hps.train.segment_size // hps.data.hop_length,
18
- n_speakers=hps.data.n_speakers,
19
- **hps.model,
20
- )
21
- _ = net_g.eval()
22
- _ = load_checkpoint(model_path, net_g, None, skip_optimizer=True)
23
- net_g.cpu()
24
- net_g.export_onnx(export_path)
25
-
26
- spklist = []
27
- for key in hps.data.spk2id.keys():
28
- spklist.append(key)
29
-
30
- MoeVSConf = {
31
- "Folder": f"{export_path}",
32
- "Name": f"{export_path}",
33
- "Type": "BertVits",
34
- "Symbol": symbols,
35
- "Cleaner": "",
36
- "Rate": hps.data.sampling_rate,
37
- "CharaMix": True,
38
- "Characters": spklist,
39
- "LanguageMap": {"ZH": [0, 0], "JP": [1, 6], "EN": [2, 8]},
40
- "Dict": "BasicDict",
41
- "BertPath": [
42
- "chinese-roberta-wwm-ext-large",
43
- "deberta-v2-large-japanese",
44
- "bert-base-japanese-v3",
45
- ],
46
- "Clap": "clap-htsat-fused",
47
- }
48
-
49
- with open(f"onnx/{export_path}.json", "w") as MoeVsConfFile:
50
- json.dump(MoeVSConf, MoeVsConfFile, indent=4)