RUI-LONG committed
Commit 5c904c4 · 0 Parent(s)
.gitattributes ADDED
@@ -0,0 +1,39 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.json filter=lfs diff=lfs merge=lfs -text
+ *.index filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ venv/
+ __pycache__/
+ *.mp3
+ hubert_base.pt
+ rmvpe.pt
+ memo.txt
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 litagin02
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: Test Rvc
+ emoji: 📉
+ colorFrom: red
+ colorTo: gray
+ sdk: gradio
+ pinned: false
+ app_file: app.py
+ startup_duration_timeout: 45m
+ ---
app.py ADDED
@@ -0,0 +1,264 @@
+ import asyncio
+ import datetime
+ import logging
+ import os
+ import time
+ import traceback
+
+ import edge_tts
+ import gradio as gr
+ import librosa
+
+ from src.rmvpe import RMVPE
+ from model_loader import ModelLoader
+
+ logging.getLogger("fairseq").setLevel(logging.WARNING)
+ logging.getLogger("numba").setLevel(logging.WARNING)
+ logging.getLogger("markdown_it").setLevel(logging.WARNING)
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
+
+ limitation = os.getenv("SYSTEM") == "spaces"
+
+ edge_output_filename = "edge_output.mp3"
+ tts_voice_list = asyncio.get_event_loop().run_until_complete(edge_tts.list_voices())
+ tts_voices = [f"{v['ShortName']}-{v['Gender']}" for v in tts_voice_list]
+
+ model_root = "weights"
+
+ print("Loading...")
+ model_loader = ModelLoader()
+ gpu_config = model_loader.config
+ hubert_model = model_loader.load_hubert()
+
+ rmvpe_model = RMVPE(
+     os.path.join(os.getcwd(), "weights", "rmvpe.pt"),
+     gpu_config.is_half,
+     gpu_config.device,
+ )
+
+ model_loader.load("char1")
+
+
+ def tts(
+     speed,
+     tts_text,
+     tts_voice,
+     f0_up_key,
+     f0_method="rmvpe",
+     index_rate=1,
+     protect=0.5,
+     filter_radius=3,
+     resample_sr=0,
+     rms_mix_rate=0.25,
+ ):
+     print("------------------")
+     print(datetime.datetime.now())
+     print("tts_text:")
+     print(tts_text)
+     print(f"tts_voice: {tts_voice}")
+     print(f"F0: {f0_method}, Key: {f0_up_key}, Index: {index_rate}, Protect: {protect}")
+     try:
+         if limitation and len(tts_text) > 280:
+             print("Error: Text too long")
+             # The UI binds two outputs, so every path returns an (info, audio) pair.
+             return (
+                 f"Text must be at most 280 characters in this Hugging Face Space, but got {len(tts_text)} characters.",
+                 None,
+             )
+         tgt_sr, net_g, vc, version, index_file, if_f0 = (
+             model_loader.tgt_sr,
+             model_loader.net_g,
+             model_loader.vc,
+             model_loader.version,
+             model_loader.index_file,
+             model_loader.if_f0,
+         )
+         t0 = time.time()
+         if speed >= 0:
+             speed_str = f"+{speed}%"
+         else:
+             speed_str = f"{speed}%"
+         asyncio.run(
+             edge_tts.Communicate(
+                 tts_text, "-".join(tts_voice.split("-")[:-1]), rate=speed_str
+             ).save(edge_output_filename)
+         )
+         t1 = time.time()
+         edge_time = t1 - t0
+         audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
+         duration = len(audio) / sr
+         print(f"Audio duration: {duration}s")
+         if limitation and duration >= 20:
+             print("Error: Audio too long")
+             return (
+                 f"Audio must be shorter than 20 seconds in this Hugging Face Space, but got {duration}s.",
+                 None,
+             )
+
+         f0_up_key = int(f0_up_key)
+
+         if f0_method == "rmvpe":
+             vc.model_rmvpe = rmvpe_model
+         times = [0, 0, 0]
+         audio_opt = vc.pipeline(
+             hubert_model,
+             net_g,
+             0,
+             audio,
+             edge_output_filename,
+             times,
+             f0_up_key,
+             f0_method,
+             index_file,
+             # file_big_npy,
+             index_rate,
+             if_f0,
+             filter_radius,
+             tgt_sr,
+             resample_sr,
+             rms_mix_rate,
+             version,
+             protect,
+             None,
+         )
+         # Chained comparison: tgt_sr != resample_sr and resample_sr >= 16000.
+         if tgt_sr != resample_sr >= 16000:
+             tgt_sr = resample_sr
+         info = f"Success. Time: edge-tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
+         print(info)
+         return (
+             info,
+             (tgt_sr, audio_opt),
+         )
+     except EOFError:
+         info = (
+             "It seems that the edge-tts output is not valid. "
+             "This can happen when the input text and the speaker do not match. "
+             "For example, you may have entered Japanese text but chosen a non-Japanese speaker."
+         )
+         print(info)
+         return info, None
+     except Exception:
+         info = traceback.format_exc()
+         print(info)
+         return info, None
+
+
+ initial_md = """
+ # Text-to-speech webui
+
+ This is a text-to-speech webui for RVC models.
+ """
+
+ app = gr.Blocks()
+ with app:
+     gr.Markdown(initial_md)
+     with gr.Row():
+         with gr.Column():
+             f0_key_up = gr.Number(
+                 label="Transpose (the best value depends on the model and speaker)",
+                 value=0,
+             )
+     with gr.Row():
+         with gr.Column():
+             tts_voice = gr.Dropdown(
+                 label="Speaker (format: language-Country-Name-Gender)",
+                 choices=tts_voices,
+                 allow_custom_value=False,
+                 value="en-US-MichelleNeural-Female",
+             )
+             speed = gr.Slider(
+                 minimum=-100,
+                 maximum=100,
+                 label="Speech speed (%)",
+                 value=0,
+                 step=10,
+                 interactive=True,
+             )
+             tts_text = gr.Textbox(
+                 label="Input Text",
+                 value="Nova says: Happy New Year, yaaaaaaaaa",
+             )
+         with gr.Column():
+             but0 = gr.Button("Convert", variant="primary")
+             info_text = gr.Textbox(label="Output info")
+         with gr.Column():
+             tts_output = gr.Audio(label="Result")
+         but0.click(
+             tts,
+             [
+                 speed,
+                 tts_text,
+                 tts_voice,
+                 f0_key_up,
+             ],
+             [info_text, tts_output],
+         )
+     with gr.Row():
+         examples = gr.Examples(
+             examples_per_page=10,
+             examples=[
+                 [
+                     "これは日本語テキストから音声への変換デモです。",
+                     "ja-JP-NanamiNeural-Female",
+                 ],
+                 [
+                     "This is an English text to speech conversion demo.",
+                     "en-US-AriaNeural-Female",
+                 ],
+                 ["這是用來測試的demo啦", "zh-TW-HsiaoChenNeural-Female"],
+                 ["这是一个中文文本到语音的转换演示。", "zh-CN-XiaoxiaoNeural-Female"],
+                 [
+                     "한국어 텍스트에서 음성으로 변환하는 데모입니다.",
+                     "ko-KR-SunHiNeural-Female",
+                 ],
+                 [
+                     "Il s'agit d'une démo de conversion du texte français à la parole.",
+                     "fr-FR-DeniseNeural-Female",
+                 ],
+                 [
+                     "Dies ist eine Demo zur Umwandlung von Deutsch in Sprache.",
+                     "de-DE-AmalaNeural-Female",
+                 ],
+                 [
+                     "Tämä on suomenkielinen tekstistä puheeksi -esittely.",
+                     "fi-FI-NooraNeural-Female",
+                 ],
+                 [
+                     "Это демонстрационный пример преобразования русского текста в речь.",
+                     "ru-RU-SvetlanaNeural-Female",
+                 ],
+                 [
+                     "Αυτή είναι μια επίδειξη μετατροπής ελληνικού κειμένου σε ομιλία.",
+                     "el-GR-AthinaNeural-Female",
+                 ],
+                 [
+                     "Esta es una demostración de conversión de texto a voz en español.",
+                     "es-ES-ElviraNeural-Female",
+                 ],
+                 [
+                     "Questa è una dimostrazione di sintesi vocale in italiano.",
+                     "it-IT-ElsaNeural-Female",
+                 ],
+                 [
+                     "Esta é uma demonstração de conversão de texto em fala em português.",
+                     "pt-PT-RaquelNeural-Female",
+                 ],
+                 [
+                     "Це демонстрація тексту до мовлення українською мовою.",
+                     "uk-UA-PolinaNeural-Female",
+                 ],
+                 [
+                     "هذا عرض توضيحي عربي لتحويل النص إلى كلام.",
+                     "ar-EG-SalmaNeural-Female",
+                 ],
+                 [
+                     "இது தமிழ் உரையிலிருந்து பேச்சு மாற்ற டெமோ.",
+                     "ta-IN-PallaviNeural-Female",
+                 ],
+             ],
+             inputs=[tts_text, tts_voice],
+         )
+
+ app.launch(inbrowser=True)
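
Note: since the Blocks UI binds exactly two outputs (info_text, tts_output), tts returns an (info, audio) pair on every path above. For local debugging it can be handy to call tts directly, bypassing Gradio. A minimal sketch (not part of the commit) follows; it assumes app.py's module-level setup succeeded (weights/rmvpe.pt, the HuBERT checkpoint, and the "char1" model present), and soundfile is an assumed extra dependency used only to save the result.

# Hypothetical smoke test for app.py's tts(); run from the repo root after the
# module-level loading above has completed.
import soundfile as sf  # assumed available (it is a common librosa companion)

info, result = tts(
    speed=0,
    tts_text="Hello from a quick smoke test.",
    tts_voice="en-US-AriaNeural-Female",
    f0_up_key=0,
)
print(info)
if result is not None:
    tgt_sr, audio_opt = result  # audio_opt is typically an int16 numpy array
    sf.write("smoke_test.wav", audio_opt, tgt_sr)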
lib/attentions.py ADDED
@@ -0,0 +1,416 @@
+ import copy
+ import math
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from lib import commons
+ from lib.modules import LayerNorm
+
+
+ class Encoder(nn.Module):
+     def __init__(
+         self,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size=1,
+         p_dropout=0.0,
+         window_size=10,
+         **kwargs
+     ):
+         super().__init__()
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.window_size = window_size
+
+         self.drop = nn.Dropout(p_dropout)
+         self.attn_layers = nn.ModuleList()
+         self.norm_layers_1 = nn.ModuleList()
+         self.ffn_layers = nn.ModuleList()
+         self.norm_layers_2 = nn.ModuleList()
+         for i in range(self.n_layers):
+             self.attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels,
+                     hidden_channels,
+                     n_heads,
+                     p_dropout=p_dropout,
+                     window_size=window_size,
+                 )
+             )
+             self.norm_layers_1.append(LayerNorm(hidden_channels))
+             self.ffn_layers.append(
+                 FFN(
+                     hidden_channels,
+                     hidden_channels,
+                     filter_channels,
+                     kernel_size,
+                     p_dropout=p_dropout,
+                 )
+             )
+             self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+     def forward(self, x, x_mask):
+         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+         x = x * x_mask
+         for i in range(self.n_layers):
+             y = self.attn_layers[i](x, x, attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_1[i](x + y)
+
+             y = self.ffn_layers[i](x, x_mask)
+             y = self.drop(y)
+             x = self.norm_layers_2[i](x + y)
+         x = x * x_mask
+         return x
+
+
+ class Decoder(nn.Module):
+     def __init__(
+         self,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size=1,
+         p_dropout=0.0,
+         proximal_bias=False,
+         proximal_init=True,
+         **kwargs
+     ):
+         super().__init__()
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.proximal_bias = proximal_bias
+         self.proximal_init = proximal_init
+
+         self.drop = nn.Dropout(p_dropout)
+         self.self_attn_layers = nn.ModuleList()
+         self.norm_layers_0 = nn.ModuleList()
+         self.encdec_attn_layers = nn.ModuleList()
+         self.norm_layers_1 = nn.ModuleList()
+         self.ffn_layers = nn.ModuleList()
+         self.norm_layers_2 = nn.ModuleList()
+         for i in range(self.n_layers):
+             self.self_attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels,
+                     hidden_channels,
+                     n_heads,
+                     p_dropout=p_dropout,
+                     proximal_bias=proximal_bias,
+                     proximal_init=proximal_init,
+                 )
+             )
+             self.norm_layers_0.append(LayerNorm(hidden_channels))
+             self.encdec_attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+                 )
+             )
+             self.norm_layers_1.append(LayerNorm(hidden_channels))
+             self.ffn_layers.append(
+                 FFN(
+                     hidden_channels,
+                     hidden_channels,
+                     filter_channels,
+                     kernel_size,
+                     p_dropout=p_dropout,
+                     causal=True,
+                 )
+             )
+             self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+     def forward(self, x, x_mask, h, h_mask):
+         """
+         x: decoder input
+         h: encoder output
+         """
+         self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+             device=x.device, dtype=x.dtype
+         )
+         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+         x = x * x_mask
+         for i in range(self.n_layers):
+             y = self.self_attn_layers[i](x, x, self_attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_0[i](x + y)
+
+             y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_1[i](x + y)
+
+             y = self.ffn_layers[i](x, x_mask)
+             y = self.drop(y)
+             x = self.norm_layers_2[i](x + y)
+         x = x * x_mask
+         return x
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(
+         self,
+         channels,
+         out_channels,
+         n_heads,
+         p_dropout=0.0,
+         window_size=None,
+         heads_share=True,
+         block_length=None,
+         proximal_bias=False,
+         proximal_init=False,
+     ):
+         super().__init__()
+         assert channels % n_heads == 0
+
+         self.channels = channels
+         self.out_channels = out_channels
+         self.n_heads = n_heads
+         self.p_dropout = p_dropout
+         self.window_size = window_size
+         self.heads_share = heads_share
+         self.block_length = block_length
+         self.proximal_bias = proximal_bias
+         self.proximal_init = proximal_init
+         self.attn = None
+
+         self.k_channels = channels // n_heads
+         self.conv_q = nn.Conv1d(channels, channels, 1)
+         self.conv_k = nn.Conv1d(channels, channels, 1)
+         self.conv_v = nn.Conv1d(channels, channels, 1)
+         self.conv_o = nn.Conv1d(channels, out_channels, 1)
+         self.drop = nn.Dropout(p_dropout)
+
+         if window_size is not None:
+             n_heads_rel = 1 if heads_share else n_heads
+             rel_stddev = self.k_channels**-0.5
+             self.emb_rel_k = nn.Parameter(
+                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                 * rel_stddev
+             )
+             self.emb_rel_v = nn.Parameter(
+                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                 * rel_stddev
+             )
+
+         nn.init.xavier_uniform_(self.conv_q.weight)
+         nn.init.xavier_uniform_(self.conv_k.weight)
+         nn.init.xavier_uniform_(self.conv_v.weight)
+         if proximal_init:
+             with torch.no_grad():
+                 self.conv_k.weight.copy_(self.conv_q.weight)
+                 self.conv_k.bias.copy_(self.conv_q.bias)
+
+     def forward(self, x, c, attn_mask=None):
+         q = self.conv_q(x)
+         k = self.conv_k(c)
+         v = self.conv_v(c)
+
+         x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+         x = self.conv_o(x)
+         return x
+
+     def attention(self, query, key, value, mask=None):
+         # reshape [b, d, t] -> [b, n_h, t, d_k]
+         b, d, t_s, t_t = (*key.size(), query.size(2))
+         query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+         key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+         value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+         scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+         if self.window_size is not None:
+             assert (
+                 t_s == t_t
+             ), "Relative attention is only available for self-attention."
+             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+             rel_logits = self._matmul_with_relative_keys(
+                 query / math.sqrt(self.k_channels), key_relative_embeddings
+             )
+             scores_local = self._relative_position_to_absolute_position(rel_logits)
+             scores = scores + scores_local
+         if self.proximal_bias:
+             assert t_s == t_t, "Proximal bias is only available for self-attention."
+             scores = scores + self._attention_bias_proximal(t_s).to(
+                 device=scores.device, dtype=scores.dtype
+             )
+         if mask is not None:
+             scores = scores.masked_fill(mask == 0, -1e4)
+             if self.block_length is not None:
+                 assert (
+                     t_s == t_t
+                 ), "Local attention is only available for self-attention."
+                 block_mask = (
+                     torch.ones_like(scores)
+                     .triu(-self.block_length)
+                     .tril(self.block_length)
+                 )
+                 scores = scores.masked_fill(block_mask == 0, -1e4)
+         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+         p_attn = self.drop(p_attn)
+         output = torch.matmul(p_attn, value)
+         if self.window_size is not None:
+             relative_weights = self._absolute_position_to_relative_position(p_attn)
+             value_relative_embeddings = self._get_relative_embeddings(
+                 self.emb_rel_v, t_s
+             )
+             output = output + self._matmul_with_relative_values(
+                 relative_weights, value_relative_embeddings
+             )
+         output = (
+             output.transpose(2, 3).contiguous().view(b, d, t_t)
+         )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+         return output, p_attn
+
+     def _matmul_with_relative_values(self, x, y):
+         """
+         x: [b, h, l, m]
+         y: [h or 1, m, d]
+         ret: [b, h, l, d]
+         """
+         ret = torch.matmul(x, y.unsqueeze(0))
+         return ret
+
+     def _matmul_with_relative_keys(self, x, y):
+         """
+         x: [b, h, l, d]
+         y: [h or 1, m, d]
+         ret: [b, h, l, m]
+         """
+         ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+         return ret
+
+     def _get_relative_embeddings(self, relative_embeddings, length):
+         max_relative_position = 2 * self.window_size + 1
+         # Pad first before slice to avoid using cond ops.
+         pad_length = max(length - (self.window_size + 1), 0)
+         slice_start_position = max((self.window_size + 1) - length, 0)
+         slice_end_position = slice_start_position + 2 * length - 1
+         if pad_length > 0:
+             padded_relative_embeddings = F.pad(
+                 relative_embeddings,
+                 commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+             )
+         else:
+             padded_relative_embeddings = relative_embeddings
+         used_relative_embeddings = padded_relative_embeddings[
+             :, slice_start_position:slice_end_position
+         ]
+         return used_relative_embeddings
+
+     def _relative_position_to_absolute_position(self, x):
+         """
+         x: [b, h, l, 2*l-1]
+         ret: [b, h, l, l]
+         """
+         batch, heads, length, _ = x.size()
+         # Concat columns of pad to shift from relative to absolute indexing.
+         x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+         # Concat extra elements so as to add up to shape (len+1, 2*len-1).
+         x_flat = x.view([batch, heads, length * 2 * length])
+         x_flat = F.pad(
+             x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+         )
+
+         # Reshape and slice out the padded elements.
+         x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+             :, :, :length, length - 1 :
+         ]
+         return x_final
+
+     def _absolute_position_to_relative_position(self, x):
+         """
+         x: [b, h, l, l]
+         ret: [b, h, l, 2*l-1]
+         """
+         batch, heads, length, _ = x.size()
+         # pad along column
+         x = F.pad(
+             x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+         )
+         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+         # add 0's at the beginning that will skew the elements after reshape
+         x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+         return x_final
+
+     def _attention_bias_proximal(self, length):
+         """Bias for self-attention to encourage attention to close positions.
+         Args:
+             length: an integer scalar.
+         Returns:
+             a Tensor with shape [1, 1, length, length]
+         """
+         r = torch.arange(length, dtype=torch.float32)
+         diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+         return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+ class FFN(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         filter_channels,
+         kernel_size,
+         p_dropout=0.0,
+         activation=None,
+         causal=False,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.filter_channels = filter_channels
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.activation = activation
+         self.causal = causal
+
+         if causal:
+             self.padding = self._causal_padding
+         else:
+             self.padding = self._same_padding
+
+         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+         self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+         self.drop = nn.Dropout(p_dropout)
+
+     def forward(self, x, x_mask):
+         x = self.conv_1(self.padding(x * x_mask))
+         if self.activation == "gelu":
+             x = x * torch.sigmoid(1.702 * x)
+         else:
+             x = torch.relu(x)
+         x = self.drop(x)
+         x = self.conv_2(self.padding(x * x_mask))
+         return x * x_mask
+
+     def _causal_padding(self, x):
+         if self.kernel_size == 1:
+             return x
+         pad_l = self.kernel_size - 1
+         pad_r = 0
+         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+         x = F.pad(x, commons.convert_pad_shape(padding))
+         return x
+
+     def _same_padding(self, x):
+         if self.kernel_size == 1:
+             return x
+         pad_l = (self.kernel_size - 1) // 2
+         pad_r = self.kernel_size // 2
+         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+         x = F.pad(x, commons.convert_pad_shape(padding))
+         return x
lib/commons.py ADDED
@@ -0,0 +1,164 @@
+ import math
+ import torch
+ from torch.nn import functional as F
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def convert_pad_shape(pad_shape):
+     l = pad_shape[::-1]
+     pad_shape = [item for sublist in l for item in sublist]
+     return pad_shape
+
+
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
+     """KL(P||Q)"""
+     kl = (logs_q - logs_p) - 0.5
+     kl += (
+         0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+     )
+     return kl
+
+
+ def rand_gumbel(shape):
+     """Sample from the Gumbel distribution, protect from overflows."""
+     uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+     return -torch.log(-torch.log(uniform_samples))
+
+
+ def rand_gumbel_like(x):
+     g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+     return g
+
+
+ def slice_segments(x, ids_str, segment_size=4):
+     ret = torch.zeros_like(x[:, :, :segment_size])
+     for i in range(x.size(0)):
+         idx_str = ids_str[i]
+         idx_end = idx_str + segment_size
+         ret[i] = x[i, :, idx_str:idx_end]
+     return ret
+
+
+ def slice_segments2(x, ids_str, segment_size=4):
+     ret = torch.zeros_like(x[:, :segment_size])
+     for i in range(x.size(0)):
+         idx_str = ids_str[i]
+         idx_end = idx_str + segment_size
+         ret[i] = x[i, idx_str:idx_end]
+     return ret
+
+
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
+     b, d, t = x.size()
+     if x_lengths is None:
+         x_lengths = t
+     ids_str_max = x_lengths - segment_size + 1
+     ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+     ret = slice_segments(x, ids_str, segment_size)
+     return ret, ids_str
+
+
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+     position = torch.arange(length, dtype=torch.float)
+     num_timescales = channels // 2
+     log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+         num_timescales - 1
+     )
+     inv_timescales = min_timescale * torch.exp(
+         torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+     )
+     scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+     signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+     signal = F.pad(signal, [0, 0, 0, channels % 2])
+     signal = signal.view(1, channels, length)
+     return signal
+
+
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+     b, channels, length = x.size()
+     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+     return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+     b, channels, length = x.size()
+     signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+     return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+ def subsequent_mask(length):
+     mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+     return mask
+
+
+ @torch.jit.script
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+     n_channels_int = n_channels[0]
+     in_act = input_a + input_b
+     t_act = torch.tanh(in_act[:, :n_channels_int, :])
+     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+     acts = t_act * s_act
+     return acts
+
+
+ def convert_pad_shape(pad_shape):
+     l = pad_shape[::-1]
+     pad_shape = [item for sublist in l for item in sublist]
+     return pad_shape
+
+
+ def shift_1d(x):
+     x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+     return x
+
+
+ def sequence_mask(length, max_length=None):
+     if max_length is None:
+         max_length = length.max()
+     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+     return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+ def generate_path(duration, mask):
+     """
+     duration: [b, 1, t_x]
+     mask: [b, 1, t_y, t_x]
+     """
+     device = duration.device
+
+     b, _, t_y, t_x = mask.shape
+     cum_duration = torch.cumsum(duration, -1)
+
+     cum_duration_flat = cum_duration.view(b * t_x)
+     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+     path = path.view(b, t_x, t_y)
+     path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+     path = path.unsqueeze(1).transpose(2, 3) * mask
+     return path
+
+
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     parameters = list(filter(lambda p: p.grad is not None, parameters))
+     norm_type = float(norm_type)
+     if clip_value is not None:
+         clip_value = float(clip_value)
+
+     total_norm = 0
+     for p in parameters:
+         param_norm = p.grad.data.norm(norm_type)
+         total_norm += param_norm.item() ** norm_type
+         if clip_value is not None:
+             p.grad.data.clamp_(min=-clip_value, max=clip_value)
+     total_norm = total_norm ** (1.0 / norm_type)
+     return total_norm
@@ -0,0 +1,1145 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from lib.commons import init_weights, get_padding
6
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
7
+ from torch.nn.utils import remove_weight_norm, spectral_norm
8
+ from torch.nn.utils.parametrizations import weight_norm
9
+ from lib.commons import init_weights
10
+ import numpy as np
11
+ from lib import modules
12
+ from lib import attentions
13
+ from lib import commons
14
+
15
+
16
+ class TextEncoder256(nn.Module):
17
+ def __init__(
18
+ self,
19
+ out_channels,
20
+ hidden_channels,
21
+ filter_channels,
22
+ n_heads,
23
+ n_layers,
24
+ kernel_size,
25
+ p_dropout,
26
+ f0=True,
27
+ ):
28
+ super().__init__()
29
+ self.out_channels = out_channels
30
+ self.hidden_channels = hidden_channels
31
+ self.filter_channels = filter_channels
32
+ self.n_heads = n_heads
33
+ self.n_layers = n_layers
34
+ self.kernel_size = kernel_size
35
+ self.p_dropout = p_dropout
36
+ self.emb_phone = nn.Linear(256, hidden_channels)
37
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
38
+ if f0 == True:
39
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
40
+ self.encoder = attentions.Encoder(
41
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
42
+ )
43
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
44
+
45
+ def forward(self, phone, pitch, lengths):
46
+ if pitch == None:
47
+ x = self.emb_phone(phone)
48
+ else:
49
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
50
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
51
+ x = self.lrelu(x)
52
+ x = torch.transpose(x, 1, -1) # [b, h, t]
53
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
54
+ x.dtype
55
+ )
56
+ x = self.encoder(x * x_mask, x_mask)
57
+ stats = self.proj(x) * x_mask
58
+
59
+ m, logs = torch.split(stats, self.out_channels, dim=1)
60
+ return m, logs, x_mask
61
+
62
+
63
+ class TextEncoder768(nn.Module):
64
+ def __init__(
65
+ self,
66
+ out_channels,
67
+ hidden_channels,
68
+ filter_channels,
69
+ n_heads,
70
+ n_layers,
71
+ kernel_size,
72
+ p_dropout,
73
+ f0=True,
74
+ ):
75
+ super().__init__()
76
+ self.out_channels = out_channels
77
+ self.hidden_channels = hidden_channels
78
+ self.filter_channels = filter_channels
79
+ self.n_heads = n_heads
80
+ self.n_layers = n_layers
81
+ self.kernel_size = kernel_size
82
+ self.p_dropout = p_dropout
83
+ self.emb_phone = nn.Linear(768, hidden_channels)
84
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
85
+ if f0 == True:
86
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
87
+ self.encoder = attentions.Encoder(
88
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
89
+ )
90
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
91
+
92
+ def forward(self, phone, pitch, lengths):
93
+ if pitch == None:
94
+ x = self.emb_phone(phone)
95
+ else:
96
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
97
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
98
+ x = self.lrelu(x)
99
+ x = torch.transpose(x, 1, -1) # [b, h, t]
100
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
101
+ x.dtype
102
+ )
103
+ x = self.encoder(x * x_mask, x_mask)
104
+ stats = self.proj(x) * x_mask
105
+
106
+ m, logs = torch.split(stats, self.out_channels, dim=1)
107
+ return m, logs, x_mask
108
+
109
+
110
+ class ResidualCouplingBlock(nn.Module):
111
+ def __init__(
112
+ self,
113
+ channels,
114
+ hidden_channels,
115
+ kernel_size,
116
+ dilation_rate,
117
+ n_layers,
118
+ n_flows=4,
119
+ gin_channels=0,
120
+ ):
121
+ super().__init__()
122
+ self.channels = channels
123
+ self.hidden_channels = hidden_channels
124
+ self.kernel_size = kernel_size
125
+ self.dilation_rate = dilation_rate
126
+ self.n_layers = n_layers
127
+ self.n_flows = n_flows
128
+ self.gin_channels = gin_channels
129
+
130
+ self.flows = nn.ModuleList()
131
+ for i in range(n_flows):
132
+ self.flows.append(
133
+ modules.ResidualCouplingLayer(
134
+ channels,
135
+ hidden_channels,
136
+ kernel_size,
137
+ dilation_rate,
138
+ n_layers,
139
+ gin_channels=gin_channels,
140
+ mean_only=True,
141
+ )
142
+ )
143
+ self.flows.append(modules.Flip())
144
+
145
+ def forward(self, x, x_mask, g=None, reverse=False):
146
+ if not reverse:
147
+ for flow in self.flows:
148
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
149
+ else:
150
+ for flow in reversed(self.flows):
151
+ x = flow(x, x_mask, g=g, reverse=reverse)
152
+ return x
153
+
154
+ def remove_weight_norm(self):
155
+ for i in range(self.n_flows):
156
+ self.flows[i * 2].remove_weight_norm()
157
+
158
+
159
+ class PosteriorEncoder(nn.Module):
160
+ def __init__(
161
+ self,
162
+ in_channels,
163
+ out_channels,
164
+ hidden_channels,
165
+ kernel_size,
166
+ dilation_rate,
167
+ n_layers,
168
+ gin_channels=0,
169
+ ):
170
+ super().__init__()
171
+ self.in_channels = in_channels
172
+ self.out_channels = out_channels
173
+ self.hidden_channels = hidden_channels
174
+ self.kernel_size = kernel_size
175
+ self.dilation_rate = dilation_rate
176
+ self.n_layers = n_layers
177
+ self.gin_channels = gin_channels
178
+
179
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
180
+ self.enc = modules.WN(
181
+ hidden_channels,
182
+ kernel_size,
183
+ dilation_rate,
184
+ n_layers,
185
+ gin_channels=gin_channels,
186
+ )
187
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
188
+
189
+ def forward(self, x, x_lengths, g=None):
190
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
191
+ x.dtype
192
+ )
193
+ x = self.pre(x) * x_mask
194
+ x = self.enc(x, x_mask, g=g)
195
+ stats = self.proj(x) * x_mask
196
+ m, logs = torch.split(stats, self.out_channels, dim=1)
197
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
198
+ return z, m, logs, x_mask
199
+
200
+ def remove_weight_norm(self):
201
+ self.enc.remove_weight_norm()
202
+
203
+
204
+ class Generator(torch.nn.Module):
205
+ def __init__(
206
+ self,
207
+ initial_channel,
208
+ resblock,
209
+ resblock_kernel_sizes,
210
+ resblock_dilation_sizes,
211
+ upsample_rates,
212
+ upsample_initial_channel,
213
+ upsample_kernel_sizes,
214
+ gin_channels=0,
215
+ ):
216
+ super(Generator, self).__init__()
217
+ self.num_kernels = len(resblock_kernel_sizes)
218
+ self.num_upsamples = len(upsample_rates)
219
+ self.conv_pre = Conv1d(
220
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
221
+ )
222
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
223
+
224
+ self.ups = nn.ModuleList()
225
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
226
+ self.ups.append(
227
+ weight_norm(
228
+ ConvTranspose1d(
229
+ upsample_initial_channel // (2**i),
230
+ upsample_initial_channel // (2 ** (i + 1)),
231
+ k,
232
+ u,
233
+ padding=(k - u) // 2,
234
+ )
235
+ )
236
+ )
237
+
238
+ self.resblocks = nn.ModuleList()
239
+ for i in range(len(self.ups)):
240
+ ch = upsample_initial_channel // (2 ** (i + 1))
241
+ for j, (k, d) in enumerate(
242
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
243
+ ):
244
+ self.resblocks.append(resblock(ch, k, d))
245
+
246
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
247
+ self.ups.apply(init_weights)
248
+
249
+ if gin_channels != 0:
250
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
251
+
252
+ def forward(self, x, g=None):
253
+ x = self.conv_pre(x)
254
+ if g is not None:
255
+ x = x + self.cond(g)
256
+
257
+ for i in range(self.num_upsamples):
258
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
259
+ x = self.ups[i](x)
260
+ xs = None
261
+ for j in range(self.num_kernels):
262
+ if xs is None:
263
+ xs = self.resblocks[i * self.num_kernels + j](x)
264
+ else:
265
+ xs += self.resblocks[i * self.num_kernels + j](x)
266
+ x = xs / self.num_kernels
267
+ x = F.leaky_relu(x)
268
+ x = self.conv_post(x)
269
+ x = torch.tanh(x)
270
+
271
+ return x
272
+
273
+ def remove_weight_norm(self):
274
+ for l in self.ups:
275
+ remove_weight_norm(l)
276
+ for l in self.resblocks:
277
+ l.remove_weight_norm()
278
+
279
+
280
+ class SineGen(torch.nn.Module):
281
+ """Definition of sine generator
282
+ SineGen(samp_rate, harmonic_num = 0,
283
+ sine_amp = 0.1, noise_std = 0.003,
284
+ voiced_threshold = 0,
285
+ flag_for_pulse=False)
286
+ samp_rate: sampling rate in Hz
287
+ harmonic_num: number of harmonic overtones (default 0)
288
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
289
+ noise_std: std of Gaussian noise (default 0.003)
290
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
291
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
292
+ Note: when flag_for_pulse is True, the first time step of a voiced
293
+ segment is always sin(np.pi) or cos(0)
294
+ """
295
+
296
+ def __init__(
297
+ self,
298
+ samp_rate,
299
+ harmonic_num=0,
300
+ sine_amp=0.1,
301
+ noise_std=0.003,
302
+ voiced_threshold=0,
303
+ flag_for_pulse=False,
304
+ ):
305
+ super(SineGen, self).__init__()
306
+ self.sine_amp = sine_amp
307
+ self.noise_std = noise_std
308
+ self.harmonic_num = harmonic_num
309
+ self.dim = self.harmonic_num + 1
310
+ self.sampling_rate = samp_rate
311
+ self.voiced_threshold = voiced_threshold
312
+
313
+ def _f02uv(self, f0):
314
+ # generate uv signal
315
+ uv = torch.ones_like(f0)
316
+ uv = uv * (f0 > self.voiced_threshold)
317
+ return uv
318
+
319
+ def forward(self, f0, upp):
320
+ """sine_tensor, uv = forward(f0)
321
+ input F0: tensor(batchsize=1, length, dim=1)
322
+ f0 for unvoiced steps should be 0
323
+ output sine_tensor: tensor(batchsize=1, length, dim)
324
+ output uv: tensor(batchsize=1, length, 1)
325
+ """
326
+ with torch.no_grad():
327
+ f0 = f0[:, None].transpose(1, 2)
328
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
329
+ # fundamental component
330
+ f0_buf[:, :, 0] = f0[:, :, 0]
331
+ for idx in np.arange(self.harmonic_num):
332
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
333
+ idx + 2
334
+ ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
335
+ rad_values = (
336
+ f0_buf / self.sampling_rate
337
+ ) % 1 ###%1意味着n_har的乘积无法后处理优化
338
+ rand_ini = torch.rand(
339
+ f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
340
+ )
341
+ rand_ini[:, 0] = 0
342
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
343
+ tmp_over_one = torch.cumsum(
344
+ rad_values, 1
345
+ ) # % 1 #####%1意味着后面的cumsum无法再优化
346
+ tmp_over_one *= upp
347
+ tmp_over_one = F.interpolate(
348
+ tmp_over_one.transpose(2, 1),
349
+ scale_factor=upp,
350
+ mode="linear",
351
+ align_corners=True,
352
+ ).transpose(2, 1)
353
+ rad_values = F.interpolate(
354
+ rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
355
+ ).transpose(
356
+ 2, 1
357
+ ) #######
358
+ tmp_over_one %= 1
359
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
360
+ cumsum_shift = torch.zeros_like(rad_values)
361
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
362
+ sine_waves = torch.sin(
363
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
364
+ )
365
+ sine_waves = sine_waves * self.sine_amp
366
+ uv = self._f02uv(f0)
367
+ uv = F.interpolate(
368
+ uv.transpose(2, 1), scale_factor=upp, mode="nearest"
369
+ ).transpose(2, 1)
370
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
371
+ noise = noise_amp * torch.randn_like(sine_waves)
372
+ sine_waves = sine_waves * uv + noise
373
+ return sine_waves, uv, noise
374
+
375
+
376
+ class SourceModuleHnNSF(torch.nn.Module):
377
+ """SourceModule for hn-nsf
378
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
379
+ add_noise_std=0.003, voiced_threshod=0)
380
+ sampling_rate: sampling_rate in Hz
381
+ harmonic_num: number of harmonic above F0 (default: 0)
382
+ sine_amp: amplitude of sine source signal (default: 0.1)
383
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
384
+ note that amplitude of noise in unvoiced is decided
385
+ by sine_amp
386
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
387
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
388
+ F0_sampled (batchsize, length, 1)
389
+ Sine_source (batchsize, length, 1)
390
+ noise_source (batchsize, length 1)
391
+ uv (batchsize, length, 1)
392
+ """
393
+
394
+ def __init__(
395
+ self,
396
+ sampling_rate,
397
+ harmonic_num=0,
398
+ sine_amp=0.1,
399
+ add_noise_std=0.003,
400
+ voiced_threshod=0,
401
+ is_half=True,
402
+ ):
403
+ super(SourceModuleHnNSF, self).__init__()
404
+
405
+ self.sine_amp = sine_amp
406
+ self.noise_std = add_noise_std
407
+ self.is_half = is_half
408
+ # to produce sine waveforms
409
+ self.l_sin_gen = SineGen(
410
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
411
+ )
412
+
413
+ # to merge source harmonics into a single excitation
414
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
415
+ self.l_tanh = torch.nn.Tanh()
416
+
417
+ def forward(self, x, upp=None):
418
+ sine_wavs, uv, _ = self.l_sin_gen(x, upp)
419
+ if self.is_half:
420
+ sine_wavs = sine_wavs.half()
421
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
422
+ return sine_merge, None, None # noise, uv
423
+
424
+
425
+ class GeneratorNSF(torch.nn.Module):
426
+ def __init__(
427
+ self,
428
+ initial_channel,
429
+ resblock,
430
+ resblock_kernel_sizes,
431
+ resblock_dilation_sizes,
432
+ upsample_rates,
433
+ upsample_initial_channel,
434
+ upsample_kernel_sizes,
435
+ gin_channels,
436
+ sr,
437
+ is_half=False,
438
+ ):
439
+ super(GeneratorNSF, self).__init__()
440
+ self.num_kernels = len(resblock_kernel_sizes)
441
+ self.num_upsamples = len(upsample_rates)
442
+
443
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
444
+ self.m_source = SourceModuleHnNSF(
445
+ sampling_rate=sr, harmonic_num=0, is_half=is_half
446
+ )
447
+ self.noise_convs = nn.ModuleList()
448
+ self.conv_pre = Conv1d(
449
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
450
+ )
451
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
452
+
453
+ self.ups = nn.ModuleList()
454
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
455
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
456
+ self.ups.append(
457
+ weight_norm(
458
+ ConvTranspose1d(
459
+ upsample_initial_channel // (2**i),
460
+ upsample_initial_channel // (2 ** (i + 1)),
461
+ k,
462
+ u,
463
+ padding=(k - u) // 2,
464
+ )
465
+ )
466
+ )
467
+ if i + 1 < len(upsample_rates):
468
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
469
+ self.noise_convs.append(
470
+ Conv1d(
471
+ 1,
472
+ c_cur,
473
+ kernel_size=stride_f0 * 2,
474
+ stride=stride_f0,
475
+ padding=stride_f0 // 2,
476
+ )
477
+ )
478
+ else:
479
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
480
+
481
+ self.resblocks = nn.ModuleList()
482
+ for i in range(len(self.ups)):
483
+ ch = upsample_initial_channel // (2 ** (i + 1))
484
+ for j, (k, d) in enumerate(
485
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
486
+ ):
487
+ self.resblocks.append(resblock(ch, k, d))
488
+
489
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
490
+ self.ups.apply(init_weights)
491
+
492
+ if gin_channels != 0:
493
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
494
+
495
+ self.upp = np.prod(upsample_rates)
496
+
497
+ def forward(self, x, f0, g=None):
498
+ har_source, noi_source, uv = self.m_source(f0, self.upp)
499
+ har_source = har_source.transpose(1, 2)
500
+ x = self.conv_pre(x)
501
+ if g is not None:
502
+ x = x + self.cond(g)
503
+
504
+ for i in range(self.num_upsamples):
505
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
506
+ x = self.ups[i](x)
507
+ x_source = self.noise_convs[i](har_source)
508
+ x = x + x_source
509
+ xs = None
510
+ for j in range(self.num_kernels):
511
+ if xs is None:
512
+ xs = self.resblocks[i * self.num_kernels + j](x)
513
+ else:
514
+ xs += self.resblocks[i * self.num_kernels + j](x)
515
+ x = xs / self.num_kernels
516
+ x = F.leaky_relu(x)
517
+ x = self.conv_post(x)
518
+ x = torch.tanh(x)
519
+ return x
520
+
521
+ def remove_weight_norm(self):
522
+ for l in self.ups:
523
+ remove_weight_norm(l)
524
+ for l in self.resblocks:
525
+ l.remove_weight_norm()
526
+
527
+
528
+ sr2sr = {
529
+ "32k": 32000,
530
+ "40k": 40000,
531
+ "48k": 48000,
532
+ }
533
+
534
+
535
+ class SynthesizerTrnMs256NSFsid(nn.Module):
536
+ def __init__(
537
+ self,
538
+ spec_channels,
539
+ segment_size,
540
+ inter_channels,
541
+ hidden_channels,
542
+ filter_channels,
543
+ n_heads,
544
+ n_layers,
545
+ kernel_size,
546
+ p_dropout,
547
+ resblock,
548
+ resblock_kernel_sizes,
549
+ resblock_dilation_sizes,
550
+ upsample_rates,
551
+ upsample_initial_channel,
552
+ upsample_kernel_sizes,
553
+ spk_embed_dim,
554
+ gin_channels,
555
+ sr,
556
+ **kwargs
557
+ ):
558
+ super().__init__()
559
+ if type(sr) == type("strr"):
560
+ sr = sr2sr[sr]
561
+ self.spec_channels = spec_channels
562
+ self.inter_channels = inter_channels
563
+ self.hidden_channels = hidden_channels
564
+ self.filter_channels = filter_channels
565
+ self.n_heads = n_heads
566
+ self.n_layers = n_layers
567
+ self.kernel_size = kernel_size
568
+ self.p_dropout = p_dropout
569
+ self.resblock = resblock
570
+ self.resblock_kernel_sizes = resblock_kernel_sizes
571
+ self.resblock_dilation_sizes = resblock_dilation_sizes
572
+ self.upsample_rates = upsample_rates
573
+ self.upsample_initial_channel = upsample_initial_channel
574
+ self.upsample_kernel_sizes = upsample_kernel_sizes
575
+ self.segment_size = segment_size
576
+ self.gin_channels = gin_channels
577
+ # self.hop_length = hop_length#
578
+ self.spk_embed_dim = spk_embed_dim
579
+ self.enc_p = TextEncoder256(
580
+ inter_channels,
581
+ hidden_channels,
582
+ filter_channels,
583
+ n_heads,
584
+ n_layers,
585
+ kernel_size,
586
+ p_dropout,
587
+ )
588
+ self.dec = GeneratorNSF(
589
+ inter_channels,
590
+ resblock,
591
+ resblock_kernel_sizes,
592
+ resblock_dilation_sizes,
593
+ upsample_rates,
594
+ upsample_initial_channel,
595
+ upsample_kernel_sizes,
596
+ gin_channels=gin_channels,
597
+ sr=sr,
598
+ is_half=kwargs["is_half"],
599
+ )
600
+ self.enc_q = PosteriorEncoder(
601
+ spec_channels,
602
+ inter_channels,
603
+ hidden_channels,
604
+ 5,
605
+ 1,
606
+ 16,
607
+ gin_channels=gin_channels,
608
+ )
609
+ self.flow = ResidualCouplingBlock(
610
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
611
+ )
612
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
613
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
614
+
615
+ def remove_weight_norm(self):
616
+ self.dec.remove_weight_norm()
617
+ self.flow.remove_weight_norm()
618
+ self.enc_q.remove_weight_norm()
619
+
620
+ def forward(
621
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
622
+ ): # 这里ds是id,[bs,1]
623
+ # print(1,pitch.shape)#[bs,t]
624
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
625
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
626
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
627
+ z_p = self.flow(z, y_mask, g=g)
628
+ z_slice, ids_slice = commons.rand_slice_segments(
629
+ z, y_lengths, self.segment_size
630
+ )
631
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
632
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
633
+ # print(-2,pitchf.shape,z_slice.shape)
634
+ o = self.dec(z_slice, pitchf, g=g)
635
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
636
+
637
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
638
+ g = self.emb_g(sid).unsqueeze(-1)
639
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
640
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
641
+ if rate:
642
+ head = int(z_p.shape[2] * rate)
643
+ z_p = z_p[:, :, -head:]
644
+ x_mask = x_mask[:, :, -head:]
645
+ nsff0 = nsff0[:, -head:]
646
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
647
+ o = self.dec(z * x_mask, nsff0, g=g)
648
+ return o, x_mask, (z, z_p, m_p, logs_p)
649
+
650
+
651
+ class SynthesizerTrnMs768NSFsid(nn.Module):
652
+ def __init__(
653
+ self,
654
+ spec_channels,
655
+ segment_size,
656
+ inter_channels,
657
+ hidden_channels,
658
+ filter_channels,
659
+ n_heads,
660
+ n_layers,
661
+ kernel_size,
662
+ p_dropout,
663
+ resblock,
664
+ resblock_kernel_sizes,
665
+ resblock_dilation_sizes,
666
+ upsample_rates,
667
+ upsample_initial_channel,
668
+ upsample_kernel_sizes,
669
+ spk_embed_dim,
670
+ gin_channels,
671
+ sr,
672
+ **kwargs
673
+ ):
674
+ super().__init__()
675
+ if type(sr) == type("strr"):
676
+ sr = sr2sr[sr]
677
+ self.spec_channels = spec_channels
678
+ self.inter_channels = inter_channels
679
+ self.hidden_channels = hidden_channels
680
+ self.filter_channels = filter_channels
681
+ self.n_heads = n_heads
682
+ self.n_layers = n_layers
683
+ self.kernel_size = kernel_size
684
+ self.p_dropout = p_dropout
685
+ self.resblock = resblock
686
+ self.resblock_kernel_sizes = resblock_kernel_sizes
687
+ self.resblock_dilation_sizes = resblock_dilation_sizes
688
+ self.upsample_rates = upsample_rates
689
+ self.upsample_initial_channel = upsample_initial_channel
690
+ self.upsample_kernel_sizes = upsample_kernel_sizes
691
+ self.segment_size = segment_size
692
+ self.gin_channels = gin_channels
693
+ # self.hop_length = hop_length#
694
+ self.spk_embed_dim = spk_embed_dim
695
+ self.enc_p = TextEncoder768(
696
+ inter_channels,
697
+ hidden_channels,
698
+ filter_channels,
699
+ n_heads,
700
+ n_layers,
701
+ kernel_size,
702
+ p_dropout,
703
+ )
704
+ self.dec = GeneratorNSF(
705
+ inter_channels,
706
+ resblock,
707
+ resblock_kernel_sizes,
708
+ resblock_dilation_sizes,
709
+ upsample_rates,
710
+ upsample_initial_channel,
711
+ upsample_kernel_sizes,
712
+ gin_channels=gin_channels,
713
+ sr=sr,
714
+ is_half=kwargs["is_half"],
715
+ )
716
+ self.enc_q = PosteriorEncoder(
717
+ spec_channels,
718
+ inter_channels,
719
+ hidden_channels,
720
+ 5,
721
+ 1,
722
+ 16,
723
+ gin_channels=gin_channels,
724
+ )
725
+ self.flow = ResidualCouplingBlock(
726
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
727
+ )
728
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
729
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
730
+
731
+ def remove_weight_norm(self):
732
+ self.dec.remove_weight_norm()
733
+ self.flow.remove_weight_norm()
734
+ self.enc_q.remove_weight_norm()
735
+
736
+ def forward(
737
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
738
+ ): # 这里ds是id,[bs,1]
739
+ # print(1,pitch.shape)#[bs,t]
740
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
741
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
742
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
743
+ z_p = self.flow(z, y_mask, g=g)
744
+ z_slice, ids_slice = commons.rand_slice_segments(
745
+ z, y_lengths, self.segment_size
746
+ )
747
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
748
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
749
+ # print(-2,pitchf.shape,z_slice.shape)
750
+ o = self.dec(z_slice, pitchf, g=g)
751
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
752
+
753
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
754
+ g = self.emb_g(sid).unsqueeze(-1)
755
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
756
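+ # Sample z_p from the prior; the 0.66666 factor tempers the noise (a common VITS-style inference setting).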
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
757
+ if rate:
758
+ head = int(z_p.shape[2] * rate)
759
+ z_p = z_p[:, :, -head:]
760
+ x_mask = x_mask[:, :, -head:]
761
+ nsff0 = nsff0[:, -head:]
762
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
763
+ o = self.dec(z * x_mask, nsff0, g=g)
764
+ return o, x_mask, (z, z_p, m_p, logs_p)
765
+
766
+
767
+ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
768
+ def __init__(
769
+ self,
770
+ spec_channels,
771
+ segment_size,
772
+ inter_channels,
773
+ hidden_channels,
774
+ filter_channels,
775
+ n_heads,
776
+ n_layers,
777
+ kernel_size,
778
+ p_dropout,
779
+ resblock,
780
+ resblock_kernel_sizes,
781
+ resblock_dilation_sizes,
782
+ upsample_rates,
783
+ upsample_initial_channel,
784
+ upsample_kernel_sizes,
785
+ spk_embed_dim,
786
+ gin_channels,
787
+ sr=None,
788
+ **kwargs
789
+ ):
790
+ super().__init__()
791
+ self.spec_channels = spec_channels
792
+ self.inter_channels = inter_channels
793
+ self.hidden_channels = hidden_channels
794
+ self.filter_channels = filter_channels
795
+ self.n_heads = n_heads
796
+ self.n_layers = n_layers
797
+ self.kernel_size = kernel_size
798
+ self.p_dropout = p_dropout
799
+ self.resblock = resblock
800
+ self.resblock_kernel_sizes = resblock_kernel_sizes
801
+ self.resblock_dilation_sizes = resblock_dilation_sizes
802
+ self.upsample_rates = upsample_rates
803
+ self.upsample_initial_channel = upsample_initial_channel
804
+ self.upsample_kernel_sizes = upsample_kernel_sizes
805
+ self.segment_size = segment_size
806
+ self.gin_channels = gin_channels
807
+ # self.hop_length = hop_length#
808
+ self.spk_embed_dim = spk_embed_dim
809
+ self.enc_p = TextEncoder256(
810
+ inter_channels,
811
+ hidden_channels,
812
+ filter_channels,
813
+ n_heads,
814
+ n_layers,
815
+ kernel_size,
816
+ p_dropout,
817
+ f0=False,
818
+ )
819
+ self.dec = Generator(
820
+ inter_channels,
821
+ resblock,
822
+ resblock_kernel_sizes,
823
+ resblock_dilation_sizes,
824
+ upsample_rates,
825
+ upsample_initial_channel,
826
+ upsample_kernel_sizes,
827
+ gin_channels=gin_channels,
828
+ )
829
+ self.enc_q = PosteriorEncoder(
830
+ spec_channels,
831
+ inter_channels,
832
+ hidden_channels,
833
+ 5,
834
+ 1,
835
+ 16,
836
+ gin_channels=gin_channels,
837
+ )
838
+ self.flow = ResidualCouplingBlock(
839
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
840
+ )
841
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
842
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
843
+
844
+ def remove_weight_norm(self):
845
+ self.dec.remove_weight_norm()
846
+ self.flow.remove_weight_norm()
847
+ self.enc_q.remove_weight_norm()
848
+
849
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # ds is the speaker id, [bs, 1]
850
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]; the 1 is the time axis, broadcast over t
851
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
852
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
853
+ z_p = self.flow(z, y_mask, g=g)
854
+ z_slice, ids_slice = commons.rand_slice_segments(
855
+ z, y_lengths, self.segment_size
856
+ )
857
+ o = self.dec(z_slice, g=g)
858
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
859
+
860
+ def infer(self, phone, phone_lengths, sid, rate=None):
861
+ g = self.emb_g(sid).unsqueeze(-1)
862
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
863
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
864
+ if rate:
865
+ head = int(z_p.shape[2] * rate)
866
+ z_p = z_p[:, :, -head:]
867
+ x_mask = x_mask[:, :, -head:]
868
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
869
+ o = self.dec(z * x_mask, g=g)
870
+ return o, x_mask, (z, z_p, m_p, logs_p)
871
+
872
+
873
+ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
874
+ def __init__(
875
+ self,
876
+ spec_channels,
877
+ segment_size,
878
+ inter_channels,
879
+ hidden_channels,
880
+ filter_channels,
881
+ n_heads,
882
+ n_layers,
883
+ kernel_size,
884
+ p_dropout,
885
+ resblock,
886
+ resblock_kernel_sizes,
887
+ resblock_dilation_sizes,
888
+ upsample_rates,
889
+ upsample_initial_channel,
890
+ upsample_kernel_sizes,
891
+ spk_embed_dim,
892
+ gin_channels,
893
+ sr=None,
894
+ **kwargs
895
+ ):
896
+ super().__init__()
897
+ self.spec_channels = spec_channels
898
+ self.inter_channels = inter_channels
899
+ self.hidden_channels = hidden_channels
900
+ self.filter_channels = filter_channels
901
+ self.n_heads = n_heads
902
+ self.n_layers = n_layers
903
+ self.kernel_size = kernel_size
904
+ self.p_dropout = p_dropout
905
+ self.resblock = resblock
906
+ self.resblock_kernel_sizes = resblock_kernel_sizes
907
+ self.resblock_dilation_sizes = resblock_dilation_sizes
908
+ self.upsample_rates = upsample_rates
909
+ self.upsample_initial_channel = upsample_initial_channel
910
+ self.upsample_kernel_sizes = upsample_kernel_sizes
911
+ self.segment_size = segment_size
912
+ self.gin_channels = gin_channels
913
+ # self.hop_length = hop_length#
914
+ self.spk_embed_dim = spk_embed_dim
915
+ self.enc_p = TextEncoder768(
916
+ inter_channels,
917
+ hidden_channels,
918
+ filter_channels,
919
+ n_heads,
920
+ n_layers,
921
+ kernel_size,
922
+ p_dropout,
923
+ f0=False,
924
+ )
925
+ self.dec = Generator(
926
+ inter_channels,
927
+ resblock,
928
+ resblock_kernel_sizes,
929
+ resblock_dilation_sizes,
930
+ upsample_rates,
931
+ upsample_initial_channel,
932
+ upsample_kernel_sizes,
933
+ gin_channels=gin_channels,
934
+ )
935
+ self.enc_q = PosteriorEncoder(
936
+ spec_channels,
937
+ inter_channels,
938
+ hidden_channels,
939
+ 5,
940
+ 1,
941
+ 16,
942
+ gin_channels=gin_channels,
943
+ )
944
+ self.flow = ResidualCouplingBlock(
945
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
946
+ )
947
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
948
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
949
+
950
+ def remove_weight_norm(self):
951
+ self.dec.remove_weight_norm()
952
+ self.flow.remove_weight_norm()
953
+ self.enc_q.remove_weight_norm()
954
+
955
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # ds is the speaker id, [bs, 1]
956
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]; the 1 is the time axis, broadcast over t
957
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
958
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
959
+ z_p = self.flow(z, y_mask, g=g)
960
+ z_slice, ids_slice = commons.rand_slice_segments(
961
+ z, y_lengths, self.segment_size
962
+ )
963
+ o = self.dec(z_slice, g=g)
964
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
965
+
966
+ def infer(self, phone, phone_lengths, sid, rate=None):
967
+ g = self.emb_g(sid).unsqueeze(-1)
968
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
969
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
970
+ if rate:
971
+ head = int(z_p.shape[2] * rate)
972
+ z_p = z_p[:, :, -head:]
973
+ x_mask = x_mask[:, :, -head:]
974
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
975
+ o = self.dec(z * x_mask, g=g)
976
+ return o, x_mask, (z, z_p, m_p, logs_p)
977
+
978
+
979
+ class MultiPeriodDiscriminator(torch.nn.Module):
980
+ def __init__(self, use_spectral_norm=False):
981
+ super(MultiPeriodDiscriminator, self).__init__()
982
+ periods = [2, 3, 5, 7, 11, 17]
983
+ # periods = [3, 5, 7, 11, 17, 23, 37]
984
+
985
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
986
+ discs = discs + [
987
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
988
+ ]
989
+ self.discriminators = nn.ModuleList(discs)
990
+
991
+ def forward(self, y, y_hat):
992
+ y_d_rs = []
993
+ y_d_gs = []
994
+ fmap_rs = []
995
+ fmap_gs = []
996
+ for i, d in enumerate(self.discriminators):
997
+ y_d_r, fmap_r = d(y)
998
+ y_d_g, fmap_g = d(y_hat)
999
+ # for j in range(len(fmap_r)):
1000
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
1001
+ y_d_rs.append(y_d_r)
1002
+ y_d_gs.append(y_d_g)
1003
+ fmap_rs.append(fmap_r)
1004
+ fmap_gs.append(fmap_g)
1005
+
1006
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1007
+
1008
+
1009
+ class MultiPeriodDiscriminatorV2(torch.nn.Module):
1010
+ def __init__(self, use_spectral_norm=False):
1011
+ super(MultiPeriodDiscriminatorV2, self).__init__()
1012
+ # periods = [2, 3, 5, 7, 11, 17]
1013
+ periods = [2, 3, 5, 7, 11, 17, 23, 37]
1014
+
1015
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
1016
+ discs = discs + [
1017
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
1018
+ ]
1019
+ self.discriminators = nn.ModuleList(discs)
1020
+
1021
+ def forward(self, y, y_hat):
1022
+ y_d_rs = []
1023
+ y_d_gs = []
1024
+ fmap_rs = []
1025
+ fmap_gs = []
1026
+ for i, d in enumerate(self.discriminators):
1027
+ y_d_r, fmap_r = d(y)
1028
+ y_d_g, fmap_g = d(y_hat)
1029
+ # for j in range(len(fmap_r)):
1030
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
1031
+ y_d_rs.append(y_d_r)
1032
+ y_d_gs.append(y_d_g)
1033
+ fmap_rs.append(fmap_r)
1034
+ fmap_gs.append(fmap_g)
1035
+
1036
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1037
+
1038
+
1039
+ class DiscriminatorS(torch.nn.Module):
1040
+ def __init__(self, use_spectral_norm=False):
1041
+ super(DiscriminatorS, self).__init__()
1042
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
1043
+ self.convs = nn.ModuleList(
1044
+ [
1045
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
1046
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
1047
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
1048
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
1049
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
1050
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
1051
+ ]
1052
+ )
1053
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
1054
+
1055
+ def forward(self, x):
1056
+ fmap = []
1057
+
1058
+ for l in self.convs:
1059
+ x = l(x)
1060
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
1061
+ fmap.append(x)
1062
+ x = self.conv_post(x)
1063
+ fmap.append(x)
1064
+ x = torch.flatten(x, 1, -1)
1065
+
1066
+ return x, fmap
1067
+
1068
+
1069
+ class DiscriminatorP(torch.nn.Module):
1070
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
1071
+ super(DiscriminatorP, self).__init__()
1072
+ self.period = period
1073
+ self.use_spectral_norm = use_spectral_norm
1074
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
1075
+ self.convs = nn.ModuleList(
1076
+ [
1077
+ norm_f(
1078
+ Conv2d(
1079
+ 1,
1080
+ 32,
1081
+ (kernel_size, 1),
1082
+ (stride, 1),
1083
+ padding=(get_padding(kernel_size, 1), 0),
1084
+ )
1085
+ ),
1086
+ norm_f(
1087
+ Conv2d(
1088
+ 32,
1089
+ 128,
1090
+ (kernel_size, 1),
1091
+ (stride, 1),
1092
+ padding=(get_padding(kernel_size, 1), 0),
1093
+ )
1094
+ ),
1095
+ norm_f(
1096
+ Conv2d(
1097
+ 128,
1098
+ 512,
1099
+ (kernel_size, 1),
1100
+ (stride, 1),
1101
+ padding=(get_padding(kernel_size, 1), 0),
1102
+ )
1103
+ ),
1104
+ norm_f(
1105
+ Conv2d(
1106
+ 512,
1107
+ 1024,
1108
+ (kernel_size, 1),
1109
+ (stride, 1),
1110
+ padding=(get_padding(kernel_size, 1), 0),
1111
+ )
1112
+ ),
1113
+ norm_f(
1114
+ Conv2d(
1115
+ 1024,
1116
+ 1024,
1117
+ (kernel_size, 1),
1118
+ 1,
1119
+ padding=(get_padding(kernel_size, 1), 0),
1120
+ )
1121
+ ),
1122
+ ]
1123
+ )
1124
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
1125
+
1126
+ def forward(self, x):
1127
+ fmap = []
1128
+
1129
+ # 1d to 2d
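+ # Fold the waveform into (b, c, t // period, period) so samples one period apart line up along one axis.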
1130
+ b, c, t = x.shape
1131
+ if t % self.period != 0: # pad first
1132
+ n_pad = self.period - (t % self.period)
1133
+ x = F.pad(x, (0, n_pad), "reflect")
1134
+ t = t + n_pad
1135
+ x = x.view(b, c, t // self.period, self.period)
1136
+
1137
+ for l in self.convs:
1138
+ x = l(x)
1139
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
1140
+ fmap.append(x)
1141
+ x = self.conv_post(x)
1142
+ fmap.append(x)
1143
+ x = torch.flatten(x, 1, -1)
1144
+
1145
+ return x, fmap
lib/models_dml.py ADDED
@@ -0,0 +1,1127 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from lib import modules
6
+ from lib import attentions
7
+ from lib import commons
8
+ from lib.commons import init_weights, get_padding
9
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
10
+ from torch.nn.utils import remove_weight_norm, spectral_norm
11
+ from torch.nn.utils.parametrizations import weight_norm
13
+ import numpy as np
14
+
15
+
16
+ class TextEncoder256(nn.Module):
17
+ def __init__(
18
+ self,
19
+ out_channels,
20
+ hidden_channels,
21
+ filter_channels,
22
+ n_heads,
23
+ n_layers,
24
+ kernel_size,
25
+ p_dropout,
26
+ f0=True,
27
+ ):
28
+ super().__init__()
29
+ self.out_channels = out_channels
30
+ self.hidden_channels = hidden_channels
31
+ self.filter_channels = filter_channels
32
+ self.n_heads = n_heads
33
+ self.n_layers = n_layers
34
+ self.kernel_size = kernel_size
35
+ self.p_dropout = p_dropout
36
+ self.emb_phone = nn.Linear(256, hidden_channels)
37
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
38
+ if f0:
39
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
40
+ self.encoder = attentions.Encoder(
41
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
42
+ )
43
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
44
+
45
+ def forward(self, phone, pitch, lengths):
46
+ if pitch is None:
47
+ x = self.emb_phone(phone)
48
+ else:
49
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
50
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
51
+ x = self.lrelu(x)
52
+ x = torch.transpose(x, 1, -1) # [b, h, t]
53
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
54
+ x.dtype
55
+ )
56
+ x = self.encoder(x * x_mask, x_mask)
57
+ stats = self.proj(x) * x_mask
58
+
59
+ m, logs = torch.split(stats, self.out_channels, dim=1)
60
+ return m, logs, x_mask
61
+
62
+
63
+ class TextEncoder768(nn.Module):
64
+ def __init__(
65
+ self,
66
+ out_channels,
67
+ hidden_channels,
68
+ filter_channels,
69
+ n_heads,
70
+ n_layers,
71
+ kernel_size,
72
+ p_dropout,
73
+ f0=True,
74
+ ):
75
+ super().__init__()
76
+ self.out_channels = out_channels
77
+ self.hidden_channels = hidden_channels
78
+ self.filter_channels = filter_channels
79
+ self.n_heads = n_heads
80
+ self.n_layers = n_layers
81
+ self.kernel_size = kernel_size
82
+ self.p_dropout = p_dropout
83
+ self.emb_phone = nn.Linear(768, hidden_channels)
84
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
85
+ if f0:
86
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
87
+ self.encoder = attentions.Encoder(
88
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
89
+ )
90
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
91
+
92
+ def forward(self, phone, pitch, lengths):
93
+ if pitch is None:
94
+ x = self.emb_phone(phone)
95
+ else:
96
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
97
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
98
+ x = self.lrelu(x)
99
+ x = torch.transpose(x, 1, -1) # [b, h, t]
100
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
101
+ x.dtype
102
+ )
103
+ x = self.encoder(x * x_mask, x_mask)
104
+ stats = self.proj(x) * x_mask
105
+
106
+ m, logs = torch.split(stats, self.out_channels, dim=1)
107
+ return m, logs, x_mask
108
+
109
+
110
+ class ResidualCouplingBlock(nn.Module):
111
+ def __init__(
112
+ self,
113
+ channels,
114
+ hidden_channels,
115
+ kernel_size,
116
+ dilation_rate,
117
+ n_layers,
118
+ n_flows=4,
119
+ gin_channels=0,
120
+ ):
121
+ super().__init__()
122
+ self.channels = channels
123
+ self.hidden_channels = hidden_channels
124
+ self.kernel_size = kernel_size
125
+ self.dilation_rate = dilation_rate
126
+ self.n_layers = n_layers
127
+ self.n_flows = n_flows
128
+ self.gin_channels = gin_channels
129
+
130
+ self.flows = nn.ModuleList()
131
+ for i in range(n_flows):
132
+ self.flows.append(
133
+ modules.ResidualCouplingLayer(
134
+ channels,
135
+ hidden_channels,
136
+ kernel_size,
137
+ dilation_rate,
138
+ n_layers,
139
+ gin_channels=gin_channels,
140
+ mean_only=True,
141
+ )
142
+ )
143
+ self.flows.append(modules.Flip())
144
+
145
+ def forward(self, x, x_mask, g=None, reverse=False):
146
+ if not reverse:
147
+ for flow in self.flows:
148
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
149
+ else:
150
+ for flow in reversed(self.flows):
151
+ x = flow(x, x_mask, g=g, reverse=reverse)
152
+ return x
153
+
154
+ def remove_weight_norm(self):
155
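+ # Even indices hold coupling layers; odd indices are parameter-free Flip layers, hence the stride of 2.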
+ for i in range(self.n_flows):
156
+ self.flows[i * 2].remove_weight_norm()
157
+
158
+
159
+ class PosteriorEncoder(nn.Module):
160
+ def __init__(
161
+ self,
162
+ in_channels,
163
+ out_channels,
164
+ hidden_channels,
165
+ kernel_size,
166
+ dilation_rate,
167
+ n_layers,
168
+ gin_channels=0,
169
+ ):
170
+ super().__init__()
171
+ self.in_channels = in_channels
172
+ self.out_channels = out_channels
173
+ self.hidden_channels = hidden_channels
174
+ self.kernel_size = kernel_size
175
+ self.dilation_rate = dilation_rate
176
+ self.n_layers = n_layers
177
+ self.gin_channels = gin_channels
178
+
179
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
180
+ self.enc = modules.WN(
181
+ hidden_channels,
182
+ kernel_size,
183
+ dilation_rate,
184
+ n_layers,
185
+ gin_channels=gin_channels,
186
+ )
187
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
188
+
189
+ def forward(self, x, x_lengths, g=None):
190
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
191
+ x.dtype
192
+ )
193
+ x = self.pre(x) * x_mask
194
+ x = self.enc(x, x_mask, g=g)
195
+ stats = self.proj(x) * x_mask
196
+ m, logs = torch.split(stats, self.out_channels, dim=1)
197
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
198
+ return z, m, logs, x_mask
199
+
200
+ def remove_weight_norm(self):
201
+ self.enc.remove_weight_norm()
202
+
203
+
204
+ class Generator(torch.nn.Module):
205
+ def __init__(
206
+ self,
207
+ initial_channel,
208
+ resblock,
209
+ resblock_kernel_sizes,
210
+ resblock_dilation_sizes,
211
+ upsample_rates,
212
+ upsample_initial_channel,
213
+ upsample_kernel_sizes,
214
+ gin_channels=0,
215
+ ):
216
+ super(Generator, self).__init__()
217
+ self.num_kernels = len(resblock_kernel_sizes)
218
+ self.num_upsamples = len(upsample_rates)
219
+ self.conv_pre = Conv1d(
220
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
221
+ )
222
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
223
+
224
+ self.ups = nn.ModuleList()
225
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
226
+ self.ups.append(
227
+ weight_norm(
228
+ ConvTranspose1d(
229
+ upsample_initial_channel // (2**i),
230
+ upsample_initial_channel // (2 ** (i + 1)),
231
+ k,
232
+ u,
233
+ padding=(k - u) // 2,
234
+ )
235
+ )
236
+ )
237
+
238
+ self.resblocks = nn.ModuleList()
239
+ for i in range(len(self.ups)):
240
+ ch = upsample_initial_channel // (2 ** (i + 1))
241
+ for j, (k, d) in enumerate(
242
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
243
+ ):
244
+ self.resblocks.append(resblock(ch, k, d))
245
+
246
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
247
+ self.ups.apply(init_weights)
248
+
249
+ if gin_channels != 0:
250
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
251
+
252
+ def forward(self, x, g=None):
253
+ x = self.conv_pre(x)
254
+ if g is not None:
255
+ x = x + self.cond(g)
256
+
257
+ for i in range(self.num_upsamples):
258
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
259
+ x = self.ups[i](x)
260
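+ # Sum the parallel multi-receptive-field resblocks for this stage, then average by num_kernels below.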
+ xs = None
261
+ for j in range(self.num_kernels):
262
+ if xs is None:
263
+ xs = self.resblocks[i * self.num_kernels + j](x)
264
+ else:
265
+ xs += self.resblocks[i * self.num_kernels + j](x)
266
+ x = xs / self.num_kernels
267
+ x = F.leaky_relu(x)
268
+ x = self.conv_post(x)
269
+ x = torch.tanh(x)
270
+
271
+ return x
272
+
273
+ def remove_weight_norm(self):
274
+ for l in self.ups:
275
+ remove_weight_norm(l)
276
+ for l in self.resblocks:
277
+ l.remove_weight_norm()
278
+
279
+
280
+ class SineGen(torch.nn.Module):
281
+ """Definition of sine generator
282
+ SineGen(samp_rate, harmonic_num = 0,
283
+ sine_amp = 0.1, noise_std = 0.003,
284
+ voiced_threshold = 0,
285
+ flag_for_pulse=False)
286
+ samp_rate: sampling rate in Hz
287
+ harmonic_num: number of harmonic overtones (default 0)
288
+ sine_amp: amplitude of sine-waveform (default 0.1)
289
+ noise_std: std of Gaussian noise (default 0.003)
290
+ voiced_threshold: F0 threshold for U/V classification (default 0)
291
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
292
+ Note: when flag_for_pulse is True, the first time step of a voiced
293
+ segment is always sin(np.pi) or cos(0)
294
+ """
295
+
296
+ def __init__(
297
+ self,
298
+ samp_rate,
299
+ harmonic_num=0,
300
+ sine_amp=0.1,
301
+ noise_std=0.003,
302
+ voiced_threshold=0,
303
+ flag_for_pulse=False,
304
+ ):
305
+ super(SineGen, self).__init__()
306
+ self.sine_amp = sine_amp
307
+ self.noise_std = noise_std
308
+ self.harmonic_num = harmonic_num
309
+ self.dim = self.harmonic_num + 1
310
+ self.sampling_rate = samp_rate
311
+ self.voiced_threshold = voiced_threshold
312
+
313
+ def _f02uv(self, f0):
314
+ # generate uv signal
315
+ uv = torch.ones_like(f0)
316
+ uv = uv * (f0 > self.voiced_threshold)
317
+ return uv.float()
318
+
319
+ def forward(self, f0, upp):
320
+ """sine_tensor, uv = forward(f0)
321
+ input F0: tensor(batchsize=1, length, dim=1)
322
+ f0 for unvoiced steps should be 0
323
+ output sine_tensor: tensor(batchsize=1, length, dim)
324
+ output uv: tensor(batchsize=1, length, 1)
325
+ """
326
+ with torch.no_grad():
327
+ f0 = f0[:, None].transpose(1, 2)
328
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
329
+ # fundamental component
330
+ f0_buf[:, :, 0] = f0[:, :, 0]
331
+ for idx in np.arange(self.harmonic_num):
332
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
333
+ idx + 2
334
+ ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
335
+ rad_values = (
336
+ f0_buf / self.sampling_rate
337
+ ) % 1 # % 1 means the n_har products cannot be optimized in post-processing
338
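+ # Random initial phase per harmonic; the fundamental (index 0) is pinned to phase 0 below.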
+ rand_ini = torch.rand(
339
+ f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
340
+ )
341
+ rand_ini[:, 0] = 0
342
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
343
+ tmp_over_one = torch.cumsum(
344
+ rad_values, 1
345
+ ) # a % 1 here would mean the cumsum below could no longer be optimized
346
+ tmp_over_one *= upp
347
+ tmp_over_one = F.interpolate(
348
+ tmp_over_one.transpose(2, 1),
349
+ scale_factor=upp,
350
+ mode="linear",
351
+ align_corners=True,
352
+ ).transpose(2, 1)
353
+ rad_values = F.interpolate(
354
+ rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
355
+ ).transpose(
356
+ 2, 1
357
+ )
358
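+ # Detect wrap-arounds of the fractional phase and shift by -1 there so the cumulative phase stays continuous.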
+ tmp_over_one %= 1
359
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
360
+ cumsum_shift = torch.zeros_like(rad_values)
361
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
362
+ sine_waves = torch.sin(
363
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
364
+ )
365
+ sine_waves = sine_waves * self.sine_amp
366
+ uv = self._f02uv(f0)
367
+ uv = F.interpolate(
368
+ uv.transpose(2, 1), scale_factor=upp, mode="nearest"
369
+ ).transpose(2, 1)
370
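+ # Voiced frames get small Gaussian noise (noise_std); unvoiced frames are noise-only at sine_amp / 3.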
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
371
+ noise = noise_amp * torch.randn_like(sine_waves)
372
+ sine_waves = sine_waves * uv + noise
373
+ return sine_waves, uv, noise
374
+
375
+
376
+ class SourceModuleHnNSF(torch.nn.Module):
377
+ """SourceModule for hn-nsf
378
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
379
+ add_noise_std=0.003, voiced_threshod=0)
380
+ sampling_rate: sampling_rate in Hz
381
+ harmonic_num: number of harmonic above F0 (default: 0)
382
+ sine_amp: amplitude of sine source signal (default: 0.1)
383
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
384
+ note that amplitude of noise in unvoiced is decided
385
+ by sine_amp
386
+ voiced_threshod: threshold to set U/V given F0 (default: 0)
387
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
388
+ F0_sampled (batchsize, length, 1)
389
+ Sine_source (batchsize, length, 1)
390
+ noise_source (batchsize, length, 1)
391
+ uv (batchsize, length, 1)
392
+ """
393
+
394
+ def __init__(
395
+ self,
396
+ sampling_rate,
397
+ harmonic_num=0,
398
+ sine_amp=0.1,
399
+ add_noise_std=0.003,
400
+ voiced_threshod=0,
401
+ is_half=True,
402
+ ):
403
+ super(SourceModuleHnNSF, self).__init__()
404
+
405
+ self.sine_amp = sine_amp
406
+ self.noise_std = add_noise_std
407
+ self.is_half = is_half
408
+ # to produce sine waveforms
409
+ self.l_sin_gen = SineGen(
410
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
411
+ )
412
+
413
+ # to merge source harmonics into a single excitation
414
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
415
+ self.l_tanh = torch.nn.Tanh()
416
+
417
+ def forward(self, x, upp=None):
418
+ sine_wavs, uv, _ = self.l_sin_gen(x, upp)
419
+ if self.is_half:
420
+ sine_wavs = sine_wavs.half()
421
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
422
+ return sine_merge, None, None # noise, uv
423
+
424
+
425
+ class GeneratorNSF(torch.nn.Module):
426
+ def __init__(
427
+ self,
428
+ initial_channel,
429
+ resblock,
430
+ resblock_kernel_sizes,
431
+ resblock_dilation_sizes,
432
+ upsample_rates,
433
+ upsample_initial_channel,
434
+ upsample_kernel_sizes,
435
+ gin_channels,
436
+ sr,
437
+ is_half=False,
438
+ ):
439
+ super(GeneratorNSF, self).__init__()
440
+ self.num_kernels = len(resblock_kernel_sizes)
441
+ self.num_upsamples = len(upsample_rates)
442
+
443
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
444
+ self.m_source = SourceModuleHnNSF(
445
+ sampling_rate=sr, harmonic_num=0, is_half=is_half
446
+ )
447
+ self.noise_convs = nn.ModuleList()
448
+ self.conv_pre = Conv1d(
449
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
450
+ )
451
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
452
+
453
+ self.ups = nn.ModuleList()
454
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
455
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
456
+ self.ups.append(
457
+ weight_norm(
458
+ ConvTranspose1d(
459
+ upsample_initial_channel // (2**i),
460
+ upsample_initial_channel // (2 ** (i + 1)),
461
+ k,
462
+ u,
463
+ padding=(k - u) // 2,
464
+ )
465
+ )
466
+ )
467
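+ # Downsample the harmonic source to this stage's resolution: stride is the product of the remaining upsample rates.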
+ if i + 1 < len(upsample_rates):
468
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
469
+ self.noise_convs.append(
470
+ Conv1d(
471
+ 1,
472
+ c_cur,
473
+ kernel_size=stride_f0 * 2,
474
+ stride=stride_f0,
475
+ padding=stride_f0 // 2,
476
+ )
477
+ )
478
+ else:
479
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
480
+
481
+ self.resblocks = nn.ModuleList()
482
+ for i in range(len(self.ups)):
483
+ ch = upsample_initial_channel // (2 ** (i + 1))
484
+ for j, (k, d) in enumerate(
485
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
486
+ ):
487
+ self.resblocks.append(resblock(ch, k, d))
488
+
489
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
490
+ self.ups.apply(init_weights)
491
+
492
+ if gin_channels != 0:
493
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
494
+
495
+ self.upp = np.prod(upsample_rates)
496
+
497
+ def forward(self, x, f0, g=None):
498
+ har_source, noi_source, uv = self.m_source(f0, self.upp)
499
+ har_source = har_source.transpose(1, 2)
500
+ x = self.conv_pre(x)
501
+ if g is not None:
502
+ x = x + self.cond(g)
503
+
504
+ for i in range(self.num_upsamples):
505
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
506
+ x = self.ups[i](x)
507
+ x_source = self.noise_convs[i](har_source)
508
+ x = x + x_source
509
+ xs = None
510
+ for j in range(self.num_kernels):
511
+ if xs is None:
512
+ xs = self.resblocks[i * self.num_kernels + j](x)
513
+ else:
514
+ xs += self.resblocks[i * self.num_kernels + j](x)
515
+ x = xs / self.num_kernels
516
+ x = F.leaky_relu(x)
517
+ x = self.conv_post(x)
518
+ x = torch.tanh(x)
519
+ return x
520
+
521
+ def remove_weight_norm(self):
522
+ for l in self.ups:
523
+ remove_weight_norm(l)
524
+ for l in self.resblocks:
525
+ l.remove_weight_norm()
526
+
527
+
528
+ sr2sr = {
529
+ "32k": 32000,
530
+ "40k": 40000,
531
+ "48k": 48000,
532
+ }
533
+
534
+
535
+ class SynthesizerTrnMs256NSFsid(nn.Module):
536
+ def __init__(
537
+ self,
538
+ spec_channels,
539
+ segment_size,
540
+ inter_channels,
541
+ hidden_channels,
542
+ filter_channels,
543
+ n_heads,
544
+ n_layers,
545
+ kernel_size,
546
+ p_dropout,
547
+ resblock,
548
+ resblock_kernel_sizes,
549
+ resblock_dilation_sizes,
550
+ upsample_rates,
551
+ upsample_initial_channel,
552
+ upsample_kernel_sizes,
553
+ spk_embed_dim,
554
+ gin_channels,
555
+ sr,
556
+ **kwargs
557
+ ):
558
+ super().__init__()
559
+ if isinstance(sr, str):
560
+ sr = sr2sr[sr]
561
+ self.spec_channels = spec_channels
562
+ self.inter_channels = inter_channels
563
+ self.hidden_channels = hidden_channels
564
+ self.filter_channels = filter_channels
565
+ self.n_heads = n_heads
566
+ self.n_layers = n_layers
567
+ self.kernel_size = kernel_size
568
+ self.p_dropout = p_dropout
569
+ self.resblock = resblock
570
+ self.resblock_kernel_sizes = resblock_kernel_sizes
571
+ self.resblock_dilation_sizes = resblock_dilation_sizes
572
+ self.upsample_rates = upsample_rates
573
+ self.upsample_initial_channel = upsample_initial_channel
574
+ self.upsample_kernel_sizes = upsample_kernel_sizes
575
+ self.segment_size = segment_size
576
+ self.gin_channels = gin_channels
577
+ # self.hop_length = hop_length#
578
+ self.spk_embed_dim = spk_embed_dim
579
+ self.enc_p = TextEncoder256(
580
+ inter_channels,
581
+ hidden_channels,
582
+ filter_channels,
583
+ n_heads,
584
+ n_layers,
585
+ kernel_size,
586
+ p_dropout,
587
+ )
588
+ self.dec = GeneratorNSF(
589
+ inter_channels,
590
+ resblock,
591
+ resblock_kernel_sizes,
592
+ resblock_dilation_sizes,
593
+ upsample_rates,
594
+ upsample_initial_channel,
595
+ upsample_kernel_sizes,
596
+ gin_channels=gin_channels,
597
+ sr=sr,
598
+ is_half=kwargs["is_half"],
599
+ )
600
+ self.enc_q = PosteriorEncoder(
601
+ spec_channels,
602
+ inter_channels,
603
+ hidden_channels,
604
+ 5,
605
+ 1,
606
+ 16,
607
+ gin_channels=gin_channels,
608
+ )
609
+ self.flow = ResidualCouplingBlock(
610
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
611
+ )
612
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
613
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
614
+
615
+ def remove_weight_norm(self):
616
+ self.dec.remove_weight_norm()
617
+ self.flow.remove_weight_norm()
618
+ self.enc_q.remove_weight_norm()
619
+
620
+ def forward(
621
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
622
+ ): # ds is the speaker id, shape [bs, 1]
623
+ # print(1,pitch.shape)#[bs,t]
624
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]; the 1 is the time axis, broadcast over t
625
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
626
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
627
+ z_p = self.flow(z, y_mask, g=g)
628
+ z_slice, ids_slice = commons.rand_slice_segments(
629
+ z, y_lengths, self.segment_size
630
+ )
631
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
632
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
633
+ # print(-2,pitchf.shape,z_slice.shape)
634
+ o = self.dec(z_slice, pitchf, g=g)
635
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
636
+
637
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
638
+ g = self.emb_g(sid).unsqueeze(-1)
639
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
640
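+ # Sample z_p from the prior; the 0.66666 factor tempers the noise (a common VITS-style inference setting).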
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
641
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
642
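+ # max_len, if given, truncates the decoded frames before vocoding.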
+ o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
643
+ return o, x_mask, (z, z_p, m_p, logs_p)
644
+
645
+
646
+ class SynthesizerTrnMs768NSFsid(nn.Module):
647
+ def __init__(
648
+ self,
649
+ spec_channels,
650
+ segment_size,
651
+ inter_channels,
652
+ hidden_channels,
653
+ filter_channels,
654
+ n_heads,
655
+ n_layers,
656
+ kernel_size,
657
+ p_dropout,
658
+ resblock,
659
+ resblock_kernel_sizes,
660
+ resblock_dilation_sizes,
661
+ upsample_rates,
662
+ upsample_initial_channel,
663
+ upsample_kernel_sizes,
664
+ spk_embed_dim,
665
+ gin_channels,
666
+ sr,
667
+ **kwargs
668
+ ):
669
+ super().__init__()
670
+ if isinstance(sr, str):
671
+ sr = sr2sr[sr]
672
+ self.spec_channels = spec_channels
673
+ self.inter_channels = inter_channels
674
+ self.hidden_channels = hidden_channels
675
+ self.filter_channels = filter_channels
676
+ self.n_heads = n_heads
677
+ self.n_layers = n_layers
678
+ self.kernel_size = kernel_size
679
+ self.p_dropout = p_dropout
680
+ self.resblock = resblock
681
+ self.resblock_kernel_sizes = resblock_kernel_sizes
682
+ self.resblock_dilation_sizes = resblock_dilation_sizes
683
+ self.upsample_rates = upsample_rates
684
+ self.upsample_initial_channel = upsample_initial_channel
685
+ self.upsample_kernel_sizes = upsample_kernel_sizes
686
+ self.segment_size = segment_size
687
+ self.gin_channels = gin_channels
688
+ # self.hop_length = hop_length#
689
+ self.spk_embed_dim = spk_embed_dim
690
+ self.enc_p = TextEncoder768(
691
+ inter_channels,
692
+ hidden_channels,
693
+ filter_channels,
694
+ n_heads,
695
+ n_layers,
696
+ kernel_size,
697
+ p_dropout,
698
+ )
699
+ self.dec = GeneratorNSF(
700
+ inter_channels,
701
+ resblock,
702
+ resblock_kernel_sizes,
703
+ resblock_dilation_sizes,
704
+ upsample_rates,
705
+ upsample_initial_channel,
706
+ upsample_kernel_sizes,
707
+ gin_channels=gin_channels,
708
+ sr=sr,
709
+ is_half=kwargs["is_half"],
710
+ )
711
+ self.enc_q = PosteriorEncoder(
712
+ spec_channels,
713
+ inter_channels,
714
+ hidden_channels,
715
+ 5,
716
+ 1,
717
+ 16,
718
+ gin_channels=gin_channels,
719
+ )
720
+ self.flow = ResidualCouplingBlock(
721
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
722
+ )
723
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
724
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
725
+
726
+ def remove_weight_norm(self):
727
+ self.dec.remove_weight_norm()
728
+ self.flow.remove_weight_norm()
729
+ self.enc_q.remove_weight_norm()
730
+
731
+ def forward(
732
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
733
+ ): # ds is the speaker id, shape [bs, 1]
734
+ # print(1,pitch.shape)#[bs,t]
735
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]; the 1 is the time axis, broadcast over t
736
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
737
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
738
+ z_p = self.flow(z, y_mask, g=g)
739
+ z_slice, ids_slice = commons.rand_slice_segments(
740
+ z, y_lengths, self.segment_size
741
+ )
742
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
743
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
744
+ # print(-2,pitchf.shape,z_slice.shape)
745
+ o = self.dec(z_slice, pitchf, g=g)
746
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
747
+
748
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
749
+ g = self.emb_g(sid).unsqueeze(-1)
750
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
751
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
752
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
753
+ o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
754
+ return o, x_mask, (z, z_p, m_p, logs_p)
755
+
756
+
757
+ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
758
+ def __init__(
759
+ self,
760
+ spec_channels,
761
+ segment_size,
762
+ inter_channels,
763
+ hidden_channels,
764
+ filter_channels,
765
+ n_heads,
766
+ n_layers,
767
+ kernel_size,
768
+ p_dropout,
769
+ resblock,
770
+ resblock_kernel_sizes,
771
+ resblock_dilation_sizes,
772
+ upsample_rates,
773
+ upsample_initial_channel,
774
+ upsample_kernel_sizes,
775
+ spk_embed_dim,
776
+ gin_channels,
777
+ sr=None,
778
+ **kwargs
779
+ ):
780
+ super().__init__()
781
+ self.spec_channels = spec_channels
782
+ self.inter_channels = inter_channels
783
+ self.hidden_channels = hidden_channels
784
+ self.filter_channels = filter_channels
785
+ self.n_heads = n_heads
786
+ self.n_layers = n_layers
787
+ self.kernel_size = kernel_size
788
+ self.p_dropout = p_dropout
789
+ self.resblock = resblock
790
+ self.resblock_kernel_sizes = resblock_kernel_sizes
791
+ self.resblock_dilation_sizes = resblock_dilation_sizes
792
+ self.upsample_rates = upsample_rates
793
+ self.upsample_initial_channel = upsample_initial_channel
794
+ self.upsample_kernel_sizes = upsample_kernel_sizes
795
+ self.segment_size = segment_size
796
+ self.gin_channels = gin_channels
797
+ # self.hop_length = hop_length#
798
+ self.spk_embed_dim = spk_embed_dim
799
+ self.enc_p = TextEncoder256(
800
+ inter_channels,
801
+ hidden_channels,
802
+ filter_channels,
803
+ n_heads,
804
+ n_layers,
805
+ kernel_size,
806
+ p_dropout,
807
+ f0=False,
808
+ )
809
+ self.dec = Generator(
810
+ inter_channels,
811
+ resblock,
812
+ resblock_kernel_sizes,
813
+ resblock_dilation_sizes,
814
+ upsample_rates,
815
+ upsample_initial_channel,
816
+ upsample_kernel_sizes,
817
+ gin_channels=gin_channels,
818
+ )
819
+ self.enc_q = PosteriorEncoder(
820
+ spec_channels,
821
+ inter_channels,
822
+ hidden_channels,
823
+ 5,
824
+ 1,
825
+ 16,
826
+ gin_channels=gin_channels,
827
+ )
828
+ self.flow = ResidualCouplingBlock(
829
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
830
+ )
831
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
832
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
833
+
834
+ def remove_weight_norm(self):
835
+ self.dec.remove_weight_norm()
836
+ self.flow.remove_weight_norm()
837
+ self.enc_q.remove_weight_norm()
838
+
839
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # ds is the speaker id, [bs, 1]
840
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]; the 1 is the time axis, broadcast over t
841
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
842
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
843
+ z_p = self.flow(z, y_mask, g=g)
844
+ z_slice, ids_slice = commons.rand_slice_segments(
845
+ z, y_lengths, self.segment_size
846
+ )
847
+ o = self.dec(z_slice, g=g)
848
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
849
+
850
+ def infer(self, phone, phone_lengths, sid, max_len=None):
851
+ g = self.emb_g(sid).unsqueeze(-1)
852
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
853
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
854
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
855
+ o = self.dec((z * x_mask)[:, :, :max_len], g=g)
856
+ return o, x_mask, (z, z_p, m_p, logs_p)
857
+
858
+
859
+ class SynthesizerTrnMs768NSFsid_nono(nn.Module):
860
+ def __init__(
861
+ self,
862
+ spec_channels,
863
+ segment_size,
864
+ inter_channels,
865
+ hidden_channels,
866
+ filter_channels,
867
+ n_heads,
868
+ n_layers,
869
+ kernel_size,
870
+ p_dropout,
871
+ resblock,
872
+ resblock_kernel_sizes,
873
+ resblock_dilation_sizes,
874
+ upsample_rates,
875
+ upsample_initial_channel,
876
+ upsample_kernel_sizes,
877
+ spk_embed_dim,
878
+ gin_channels,
879
+ sr=None,
880
+ **kwargs
881
+ ):
882
+ super().__init__()
883
+ self.spec_channels = spec_channels
884
+ self.inter_channels = inter_channels
885
+ self.hidden_channels = hidden_channels
886
+ self.filter_channels = filter_channels
887
+ self.n_heads = n_heads
888
+ self.n_layers = n_layers
889
+ self.kernel_size = kernel_size
890
+ self.p_dropout = p_dropout
891
+ self.resblock = resblock
892
+ self.resblock_kernel_sizes = resblock_kernel_sizes
893
+ self.resblock_dilation_sizes = resblock_dilation_sizes
894
+ self.upsample_rates = upsample_rates
895
+ self.upsample_initial_channel = upsample_initial_channel
896
+ self.upsample_kernel_sizes = upsample_kernel_sizes
897
+ self.segment_size = segment_size
898
+ self.gin_channels = gin_channels
899
+ # self.hop_length = hop_length#
900
+ self.spk_embed_dim = spk_embed_dim
901
+ self.enc_p = TextEncoder768(
902
+ inter_channels,
903
+ hidden_channels,
904
+ filter_channels,
905
+ n_heads,
906
+ n_layers,
907
+ kernel_size,
908
+ p_dropout,
909
+ f0=False,
910
+ )
911
+ self.dec = Generator(
912
+ inter_channels,
913
+ resblock,
914
+ resblock_kernel_sizes,
915
+ resblock_dilation_sizes,
916
+ upsample_rates,
917
+ upsample_initial_channel,
918
+ upsample_kernel_sizes,
919
+ gin_channels=gin_channels,
920
+ )
921
+ self.enc_q = PosteriorEncoder(
922
+ spec_channels,
923
+ inter_channels,
924
+ hidden_channels,
925
+ 5,
926
+ 1,
927
+ 16,
928
+ gin_channels=gin_channels,
929
+ )
930
+ self.flow = ResidualCouplingBlock(
931
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
932
+ )
933
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
934
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
935
+
936
+ def remove_weight_norm(self):
937
+ self.dec.remove_weight_norm()
938
+ self.flow.remove_weight_norm()
939
+ self.enc_q.remove_weight_norm()
940
+
941
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # ds is the speaker id, [bs, 1]
942
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]; the 1 is the time axis, broadcast over t
943
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
944
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
945
+ z_p = self.flow(z, y_mask, g=g)
946
+ z_slice, ids_slice = commons.rand_slice_segments(
947
+ z, y_lengths, self.segment_size
948
+ )
949
+ o = self.dec(z_slice, g=g)
950
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
951
+
952
+ def infer(self, phone, phone_lengths, sid, max_len=None):
953
+ g = self.emb_g(sid).unsqueeze(-1)
954
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
955
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
956
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
957
+ o = self.dec((z * x_mask)[:, :, :max_len], g=g)
958
+ return o, x_mask, (z, z_p, m_p, logs_p)
959
+
960
+
961
+ class MultiPeriodDiscriminator(torch.nn.Module):
962
+ def __init__(self, use_spectral_norm=False):
963
+ super(MultiPeriodDiscriminator, self).__init__()
964
+ periods = [2, 3, 5, 7, 11, 17]
965
+ # periods = [3, 5, 7, 11, 17, 23, 37]
966
+
967
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
968
+ discs = discs + [
969
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
970
+ ]
971
+ self.discriminators = nn.ModuleList(discs)
972
+
973
+ def forward(self, y, y_hat):
974
+ y_d_rs = []
975
+ y_d_gs = []
976
+ fmap_rs = []
977
+ fmap_gs = []
978
+ for i, d in enumerate(self.discriminators):
979
+ y_d_r, fmap_r = d(y)
980
+ y_d_g, fmap_g = d(y_hat)
981
+ # for j in range(len(fmap_r)):
982
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
983
+ y_d_rs.append(y_d_r)
984
+ y_d_gs.append(y_d_g)
985
+ fmap_rs.append(fmap_r)
986
+ fmap_gs.append(fmap_g)
987
+
988
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
989
+
990
+
991
+ class MultiPeriodDiscriminatorV2(torch.nn.Module):
992
+ def __init__(self, use_spectral_norm=False):
993
+ super(MultiPeriodDiscriminatorV2, self).__init__()
994
+ # periods = [2, 3, 5, 7, 11, 17]
995
+ periods = [2, 3, 5, 7, 11, 17, 23, 37]
996
+
997
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
998
+ discs = discs + [
999
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
1000
+ ]
1001
+ self.discriminators = nn.ModuleList(discs)
1002
+
1003
+ def forward(self, y, y_hat):
1004
+ y_d_rs = []
1005
+ y_d_gs = []
1006
+ fmap_rs = []
1007
+ fmap_gs = []
1008
+ for i, d in enumerate(self.discriminators):
1009
+ y_d_r, fmap_r = d(y)
1010
+ y_d_g, fmap_g = d(y_hat)
1011
+ # for j in range(len(fmap_r)):
1012
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
1013
+ y_d_rs.append(y_d_r)
1014
+ y_d_gs.append(y_d_g)
1015
+ fmap_rs.append(fmap_r)
1016
+ fmap_gs.append(fmap_g)
1017
+
1018
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1019
+
1020
+
1021
+ class DiscriminatorS(torch.nn.Module):
1022
+ def __init__(self, use_spectral_norm=False):
1023
+ super(DiscriminatorS, self).__init__()
1024
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
1025
+ self.convs = nn.ModuleList(
1026
+ [
1027
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
1028
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
1029
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
1030
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
1031
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
1032
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
1033
+ ]
1034
+ )
1035
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
1036
+
1037
+ def forward(self, x):
1038
+ fmap = []
1039
+
1040
+ for l in self.convs:
1041
+ x = l(x)
1042
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
1043
+ fmap.append(x)
1044
+ x = self.conv_post(x)
1045
+ fmap.append(x)
1046
+ x = torch.flatten(x, 1, -1)
1047
+
1048
+ return x, fmap
1049
+
1050
+
1051
+ class DiscriminatorP(torch.nn.Module):
1052
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
1053
+ super(DiscriminatorP, self).__init__()
1054
+ self.period = period
1055
+ self.use_spectral_norm = use_spectral_norm
1056
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
1057
+ self.convs = nn.ModuleList(
1058
+ [
1059
+ norm_f(
1060
+ Conv2d(
1061
+ 1,
1062
+ 32,
1063
+ (kernel_size, 1),
1064
+ (stride, 1),
1065
+ padding=(get_padding(kernel_size, 1), 0),
1066
+ )
1067
+ ),
1068
+ norm_f(
1069
+ Conv2d(
1070
+ 32,
1071
+ 128,
1072
+ (kernel_size, 1),
1073
+ (stride, 1),
1074
+ padding=(get_padding(kernel_size, 1), 0),
1075
+ )
1076
+ ),
1077
+ norm_f(
1078
+ Conv2d(
1079
+ 128,
1080
+ 512,
1081
+ (kernel_size, 1),
1082
+ (stride, 1),
1083
+ padding=(get_padding(kernel_size, 1), 0),
1084
+ )
1085
+ ),
1086
+ norm_f(
1087
+ Conv2d(
1088
+ 512,
1089
+ 1024,
1090
+ (kernel_size, 1),
1091
+ (stride, 1),
1092
+ padding=(get_padding(kernel_size, 1), 0),
1093
+ )
1094
+ ),
1095
+ norm_f(
1096
+ Conv2d(
1097
+ 1024,
1098
+ 1024,
1099
+ (kernel_size, 1),
1100
+ 1,
1101
+ padding=(get_padding(kernel_size, 1), 0),
1102
+ )
1103
+ ),
1104
+ ]
1105
+ )
1106
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
1107
+
1108
+ def forward(self, x):
1109
+ fmap = []
1110
+
1111
+ # 1d to 2d
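+ # Fold the waveform into (b, c, t // period, period) so samples one period apart line up along one axis.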
1112
+ b, c, t = x.shape
1113
+ if t % self.period != 0: # pad first
1114
+ n_pad = self.period - (t % self.period)
1115
+ x = F.pad(x, (0, n_pad), "reflect")
1116
+ t = t + n_pad
1117
+ x = x.view(b, c, t // self.period, self.period)
1118
+
1119
+ for l in self.convs:
1120
+ x = l(x)
1121
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
1122
+ fmap.append(x)
1123
+ x = self.conv_post(x)
1124
+ fmap.append(x)
1125
+ x = torch.flatten(x, 1, -1)
1126
+
1127
+ return x, fmap
lib/modules.py ADDED
@@ -0,0 +1,528 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from torch.nn import Conv1d
7
+ from torch.nn.utils import remove_weight_norm
8
+ from torch.nn.utils.parametrizations import weight_norm
9
+
10
+ from lib.commons import init_weights, get_padding, fused_add_tanh_sigmoid_multiply
11
+ from lib.transforms import piecewise_rational_quadratic_transform
12
+
13
+
14
+ LRELU_SLOPE = 0.1
15
+
16
+
17
+ class LayerNorm(nn.Module):
18
+ def __init__(self, channels, eps=1e-5):
19
+ super().__init__()
20
+ self.channels = channels
21
+ self.eps = eps
22
+
23
+ self.gamma = nn.Parameter(torch.ones(channels))
24
+ self.beta = nn.Parameter(torch.zeros(channels))
25
+
26
+ def forward(self, x):
27
+ x = x.transpose(1, -1)
28
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
29
+ return x.transpose(1, -1)
30
+
31
+
32
+ class ConvReluNorm(nn.Module):
33
+ def __init__(
34
+ self,
35
+ in_channels,
36
+ hidden_channels,
37
+ out_channels,
38
+ kernel_size,
39
+ n_layers,
40
+ p_dropout,
41
+ ):
42
+ super().__init__()
43
+ self.in_channels = in_channels
44
+ self.hidden_channels = hidden_channels
45
+ self.out_channels = out_channels
46
+ self.kernel_size = kernel_size
47
+ self.n_layers = n_layers
48
+ self.p_dropout = p_dropout
49
+ assert n_layers > 1, "Number of layers should be larger than 1."
50
+
51
+ self.conv_layers = nn.ModuleList()
52
+ self.norm_layers = nn.ModuleList()
53
+ self.conv_layers.append(
54
+ nn.Conv1d(
55
+ in_channels,
56
+ hidden_channels,
57
+ kernel_size,
58
+ padding=kernel_size // 2,
59
+ )
60
+ )
61
+ self.norm_layers.append(LayerNorm(hidden_channels))
62
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63
+ for _ in range(n_layers - 1):
64
+ self.conv_layers.append(
65
+ nn.Conv1d(
66
+ hidden_channels,
67
+ hidden_channels,
68
+ kernel_size,
69
+ padding=kernel_size // 2,
70
+ )
71
+ )
72
+ self.norm_layers.append(LayerNorm(hidden_channels))
73
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74
+ self.proj.weight.data.zero_()
75
+ self.proj.bias.data.zero_()
76
+
77
+ def forward(self, x, x_mask):
78
+ x_org = x
79
+ for i in range(self.n_layers):
80
+ x = self.conv_layers[i](x * x_mask)
81
+ x = self.norm_layers[i](x)
82
+ x = self.relu_drop(x)
83
+ x = x_org + self.proj(x)
84
+ return x * x_mask
85
+
86
+
87
+ class DDSConv(nn.Module):
88
+ """
89
+ Dialted and Depth-Separable Convolution
90
+ """
91
+
92
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93
+ super().__init__()
94
+ self.channels = channels
95
+ self.kernel_size = kernel_size
96
+ self.n_layers = n_layers
97
+ self.p_dropout = p_dropout
98
+
99
+ self.drop = nn.Dropout(p_dropout)
100
+ self.convs_sep = nn.ModuleList()
101
+ self.convs_1x1 = nn.ModuleList()
102
+ self.norms_1 = nn.ModuleList()
103
+ self.norms_2 = nn.ModuleList()
104
+ for i in range(n_layers):
105
+ dilation = kernel_size**i
106
+ padding = (kernel_size * dilation - dilation) // 2
107
+ self.convs_sep.append(
108
+ nn.Conv1d(
109
+ channels,
110
+ channels,
111
+ kernel_size,
112
+ groups=channels,
113
+ dilation=dilation,
114
+ padding=padding,
115
+ )
116
+ )
117
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
118
+ self.norms_1.append(LayerNorm(channels))
119
+ self.norms_2.append(LayerNorm(channels))
120
+
121
+ def forward(self, x, x_mask, g=None):
122
+ if g is not None:
123
+ x = x + g
124
+ for i in range(self.n_layers):
125
+ y = self.convs_sep[i](x * x_mask)
126
+ y = self.norms_1[i](y)
127
+ y = F.gelu(y)
128
+ y = self.convs_1x1[i](y)
129
+ y = self.norms_2[i](y)
130
+ y = F.gelu(y)
131
+ y = self.drop(y)
132
+ x = x + y
133
+ return x * x_mask
134
+
135
+
136
+ class WN(torch.nn.Module):
137
+ def __init__(
138
+ self,
139
+ hidden_channels,
140
+ kernel_size,
141
+ dilation_rate,
142
+ n_layers,
143
+ gin_channels=0,
144
+ p_dropout=0,
145
+ ):
146
+ super(WN, self).__init__()
147
+ assert kernel_size % 2 == 1
148
+ self.hidden_channels = hidden_channels
149
+ self.kernel_size = kernel_size
150
+ self.dilation_rate = dilation_rate
151
+ self.n_layers = n_layers
152
+ self.gin_channels = gin_channels
153
+ self.p_dropout = p_dropout
154
+
155
+ self.in_layers = torch.nn.ModuleList()
156
+ self.res_skip_layers = torch.nn.ModuleList()
157
+ self.drop = nn.Dropout(p_dropout)
158
+
159
+ if gin_channels != 0:
160
+ cond_layer = torch.nn.Conv1d(
161
+ gin_channels, 2 * hidden_channels * n_layers, 1
162
+ )
163
+ self.cond_layer = torch.nn.utils.parametrizations.weight_norm(
164
+ cond_layer, name="weight"
165
+ )
166
+
167
+ for i in range(n_layers):
168
+ dilation = dilation_rate**i
169
+ padding = int((kernel_size * dilation - dilation) / 2)
170
+ in_layer = torch.nn.Conv1d(
171
+ hidden_channels,
172
+ 2 * hidden_channels,
173
+ kernel_size,
174
+ dilation=dilation,
175
+ padding=padding,
176
+ )
177
+ in_layer = torch.nn.utils.parametrizations.weight_norm(
178
+ in_layer, name="weight"
179
+ )
180
+ self.in_layers.append(in_layer)
181
+
182
+ # last one is not necessary
183
+ if i < n_layers - 1:
184
+ res_skip_channels = 2 * hidden_channels
185
+ else:
186
+ res_skip_channels = hidden_channels
187
+
188
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
189
+ res_skip_layer = torch.nn.utils.parametrizations.weight_norm(
190
+ res_skip_layer, name="weight"
191
+ )
192
+ self.res_skip_layers.append(res_skip_layer)
193
+
194
+ def forward(self, x, x_mask, g=None, **kwargs):
195
+ output = torch.zeros_like(x)
196
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
197
+
198
+ if g is not None:
199
+ g = self.cond_layer(g)
200
+
201
+ for i in range(self.n_layers):
202
+ x_in = self.in_layers[i](x)
203
+ if g is not None:
204
+ cond_offset = i * 2 * self.hidden_channels
205
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
206
+ else:
207
+ g_l = torch.zeros_like(x_in)
208
+
209
+ acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
210
+ acts = self.drop(acts)
211
+
212
+ res_skip_acts = self.res_skip_layers[i](acts)
213
+ if i < self.n_layers - 1:
214
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
215
+ x = (x + res_acts) * x_mask
216
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
217
+ else:
218
+ output = output + res_skip_acts
219
+ return output * x_mask
220
+
221
+ def remove_weight_norm(self):
222
+ if self.gin_channels != 0:
223
+ torch.nn.utils.parametrize.remove_parametrizations(self.cond_layer, "weight")
224
+ for l in self.in_layers:
225
+ torch.nn.utils.parametrize.remove_parametrizations(l, "weight")
226
+ for l in self.res_skip_layers:
227
+ torch.nn.utils.parametrize.remove_parametrizations(l, "weight")
228
+
229
+
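WN's forward relies on fused_add_tanh_sigmoid_multiply, which in the upstream VITS code adds the conditioning tensor and then applies the WaveNet gated unit: tanh over the first hidden_channels channels, sigmoid over the second, multiplied together. A reference sketch of that computation (an illustration of the gated unit, not this repo's helper):

import torch

def gated_activation(x_in, g_l, n_channels):
    # add the (possibly zero) condition, then split into filter/gate halves
    acts = x_in + g_l
    t_act = torch.tanh(acts[:, :n_channels, :])     # filter
    s_act = torch.sigmoid(acts[:, n_channels:, :])  # gate
    return t_act * s_act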
230
+ class ResBlock1(torch.nn.Module):
231
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
232
+ super(ResBlock1, self).__init__()
233
+ self.convs1 = nn.ModuleList(
234
+ [
235
+ weight_norm(
236
+ Conv1d(
237
+ channels,
238
+ channels,
239
+ kernel_size,
240
+ 1,
241
+ dilation=dilation[0],
242
+ padding=get_padding(kernel_size, dilation[0]),
243
+ )
244
+ ),
245
+ weight_norm(
246
+ Conv1d(
247
+ channels,
248
+ channels,
249
+ kernel_size,
250
+ 1,
251
+ dilation=dilation[1],
252
+ padding=get_padding(kernel_size, dilation[1]),
253
+ )
254
+ ),
255
+ weight_norm(
256
+ Conv1d(
257
+ channels,
258
+ channels,
259
+ kernel_size,
260
+ 1,
261
+ dilation=dilation[2],
262
+ padding=get_padding(kernel_size, dilation[2]),
263
+ )
264
+ ),
265
+ ]
266
+ )
267
+ self.convs1.apply(init_weights)
268
+
269
+ self.convs2 = nn.ModuleList(
270
+ [
271
+ weight_norm(
272
+ Conv1d(
273
+ channels,
274
+ channels,
275
+ kernel_size,
276
+ 1,
277
+ dilation=1,
278
+ padding=get_padding(kernel_size, 1),
279
+ )
280
+ ),
281
+ weight_norm(
282
+ Conv1d(
283
+ channels,
284
+ channels,
285
+ kernel_size,
286
+ 1,
287
+ dilation=1,
288
+ padding=get_padding(kernel_size, 1),
289
+ )
290
+ ),
291
+ weight_norm(
292
+ Conv1d(
293
+ channels,
294
+ channels,
295
+ kernel_size,
296
+ 1,
297
+ dilation=1,
298
+ padding=get_padding(kernel_size, 1),
299
+ )
300
+ ),
301
+ ]
302
+ )
303
+ self.convs2.apply(init_weights)
304
+
305
+ def forward(self, x, x_mask=None):
306
+ for c1, c2 in zip(self.convs1, self.convs2):
307
+ xt = F.leaky_relu(x, LRELU_SLOPE)
308
+ if x_mask is not None:
309
+ xt = xt * x_mask
310
+ xt = c1(xt)
311
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
312
+ if x_mask is not None:
313
+ xt = xt * x_mask
314
+ xt = c2(xt)
315
+ x = xt + x
316
+ if x_mask is not None:
317
+ x = x * x_mask
318
+ return x
319
+
320
+ def remove_weight_norm(self):
321
+ for l in self.convs1:
322
+ remove_weight_norm(l)
323
+ for l in self.convs2:
324
+ remove_weight_norm(l)
325
+
326
+
327
+ class ResBlock2(torch.nn.Module):
328
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
329
+ super(ResBlock2, self).__init__()
330
+ self.convs = nn.ModuleList(
331
+ [
332
+ weight_norm(
333
+ Conv1d(
334
+ channels,
335
+ channels,
336
+ kernel_size,
337
+ 1,
338
+ dilation=dilation[0],
339
+ padding=get_padding(kernel_size, dilation[0]),
340
+ )
341
+ ),
342
+ weight_norm(
343
+ Conv1d(
344
+ channels,
345
+ channels,
346
+ kernel_size,
347
+ 1,
348
+ dilation=dilation[1],
349
+ padding=get_padding(kernel_size, dilation[1]),
350
+ )
351
+ ),
352
+ ]
353
+ )
354
+ self.convs.apply(init_weights)
355
+
356
+ def forward(self, x, x_mask=None):
357
+ for c in self.convs:
358
+ xt = F.leaky_relu(x, LRELU_SLOPE)
359
+ if x_mask is not None:
360
+ xt = xt * x_mask
361
+ xt = c(xt)
362
+ x = xt + x
363
+ if x_mask is not None:
364
+ x = x * x_mask
365
+ return x
366
+
367
+ def remove_weight_norm(self):
368
+ for l in self.convs:
369
+ remove_weight_norm(l)
370
+
371
+
372
+ class Log(nn.Module):
373
+ def forward(self, x, x_mask, reverse=False, **kwargs):
374
+ if not reverse:
375
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
376
+ logdet = torch.sum(-y, [1, 2])
377
+ return y, logdet
378
+ else:
379
+ x = torch.exp(x) * x_mask
380
+ return x
381
+
382
+
383
+ class Flip(nn.Module):
384
+ def forward(self, x, *args, reverse=False, **kwargs):
385
+ x = torch.flip(x, [1])
386
+ if not reverse:
387
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
388
+ return x, logdet
389
+ else:
390
+ return x
391
+
392
+
393
+ class ElementwiseAffine(nn.Module):
394
+ def __init__(self, channels):
395
+ super().__init__()
396
+ self.channels = channels
397
+ self.m = nn.Parameter(torch.zeros(channels, 1))
398
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
399
+
400
+ def forward(self, x, x_mask, reverse=False, **kwargs):
401
+ if not reverse:
402
+ y = self.m + torch.exp(self.logs) * x
403
+ y = y * x_mask
404
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
405
+ return y, logdet
406
+ else:
407
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
408
+ return x
409
+
410
+
411
+ class ResidualCouplingLayer(nn.Module):
412
+ def __init__(
413
+ self,
414
+ channels,
415
+ hidden_channels,
416
+ kernel_size,
417
+ dilation_rate,
418
+ n_layers,
419
+ p_dropout=0,
420
+ gin_channels=0,
421
+ mean_only=False,
422
+ ):
423
+ assert channels % 2 == 0, "channels should be divisible by 2"
424
+ super().__init__()
425
+ self.channels = channels
426
+ self.hidden_channels = hidden_channels
427
+ self.kernel_size = kernel_size
428
+ self.dilation_rate = dilation_rate
429
+ self.n_layers = n_layers
430
+ self.half_channels = channels // 2
431
+ self.mean_only = mean_only
432
+
433
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
434
+ self.enc = WN(
435
+ hidden_channels,
436
+ kernel_size,
437
+ dilation_rate,
438
+ n_layers,
439
+ p_dropout=p_dropout,
440
+ gin_channels=gin_channels,
441
+ )
442
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
443
+ self.post.weight.data.zero_()
444
+ self.post.bias.data.zero_()
445
+
446
+ def forward(self, x, x_mask, g=None, reverse=False):
447
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
448
+ h = self.pre(x0) * x_mask
449
+ h = self.enc(h, x_mask, g=g)
450
+ stats = self.post(h) * x_mask
451
+ if not self.mean_only:
452
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
453
+ else:
454
+ m = stats
455
+ logs = torch.zeros_like(m)
456
+
457
+ if not reverse:
458
+ x1 = m + x1 * torch.exp(logs) * x_mask
459
+ x = torch.cat([x0, x1], 1)
460
+ logdet = torch.sum(logs, [1, 2])
461
+ return x, logdet
462
+ else:
463
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
464
+ x = torch.cat([x0, x1], 1)
465
+ return x
466
+
467
+ def remove_weight_norm(self):
468
+ self.enc.remove_weight_norm()
469
+
470
+
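Because x0 passes through untouched and x1 only undergoes an affine map whose parameters are computed from x0, the coupling layer is exactly invertible; the reverse branch simply undoes the forward one. A roundtrip sketch (assuming lib.models is importable; parameter values are illustrative):

import torch
from lib.models import ResidualCouplingLayer

layer = ResidualCouplingLayer(
    channels=192, hidden_channels=192, kernel_size=5,
    dilation_rate=1, n_layers=3, mean_only=True,
)
x = torch.randn(1, 192, 50)
x_mask = torch.ones(1, 1, 50)
z, logdet = layer(x, x_mask)            # forward direction
x_rec = layer(z, x_mask, reverse=True)  # inverse direction
assert torch.allclose(x, x_rec, atol=1e-5)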
471
+ class ConvFlow(nn.Module):
472
+ def __init__(
473
+ self,
474
+ in_channels,
475
+ filter_channels,
476
+ kernel_size,
477
+ n_layers,
478
+ num_bins=10,
479
+ tail_bound=5.0,
480
+ ):
481
+ super().__init__()
482
+ self.in_channels = in_channels
483
+ self.filter_channels = filter_channels
484
+ self.kernel_size = kernel_size
485
+ self.n_layers = n_layers
486
+ self.num_bins = num_bins
487
+ self.tail_bound = tail_bound
488
+ self.half_channels = in_channels // 2
489
+
490
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
491
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
492
+ self.proj = nn.Conv1d(
493
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
494
+ )
495
+ self.proj.weight.data.zero_()
496
+ self.proj.bias.data.zero_()
497
+
498
+ def forward(self, x, x_mask, g=None, reverse=False):
499
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
500
+ h = self.pre(x0)
501
+ h = self.convs(h, x_mask, g=g)
502
+ h = self.proj(h) * x_mask
503
+
504
+ b, c, t = x0.shape
505
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, c*(3*num_bins - 1), t] -> [b, c, t, 3*num_bins - 1]
506
+
507
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
508
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
509
+ self.filter_channels
510
+ )
511
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
512
+
513
+ x1, logabsdet = piecewise_rational_quadratic_transform(
514
+ x1,
515
+ unnormalized_widths,
516
+ unnormalized_heights,
517
+ unnormalized_derivatives,
518
+ inverse=reverse,
519
+ tails="linear",
520
+ tail_bound=self.tail_bound,
521
+ )
522
+
523
+ x = torch.cat([x0, x1], 1) * x_mask
524
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
525
+ if not reverse:
526
+ return x, logdet
527
+ else:
528
+ return x
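The projection width half_channels * (num_bins * 3 - 1) is not arbitrary: each transformed dimension needs num_bins unnormalized widths, num_bins unnormalized heights, and num_bins - 1 interior knot derivatives for the spline; the two boundary derivatives are pinned by the linear tails inside lib/transforms.py below.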
lib/transforms.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(
13
+ inputs,
14
+ unnormalized_widths,
15
+ unnormalized_heights,
16
+ unnormalized_derivatives,
17
+ inverse=False,
18
+ tails=None,
19
+ tail_bound=1.0,
20
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
23
+ ):
24
+ if tails is None:
25
+ spline_fn = rational_quadratic_spline
26
+ spline_kwargs = {}
27
+ else:
28
+ spline_fn = unconstrained_rational_quadratic_spline
29
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
+
31
+ outputs, logabsdet = spline_fn(
32
+ inputs=inputs,
33
+ unnormalized_widths=unnormalized_widths,
34
+ unnormalized_heights=unnormalized_heights,
35
+ unnormalized_derivatives=unnormalized_derivatives,
36
+ inverse=inverse,
37
+ min_bin_width=min_bin_width,
38
+ min_bin_height=min_bin_height,
39
+ min_derivative=min_derivative,
40
+ **spline_kwargs
41
+ )
42
+ return outputs, logabsdet
43
+
44
+
45
+ def searchsorted(bin_locations, inputs, eps=1e-6):
46
+ bin_locations[..., -1] += eps
47
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
+
49
+
50
+ def unconstrained_rational_quadratic_spline(
51
+ inputs,
52
+ unnormalized_widths,
53
+ unnormalized_heights,
54
+ unnormalized_derivatives,
55
+ inverse=False,
56
+ tails="linear",
57
+ tail_bound=1.0,
58
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
61
+ ):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == "linear":
69
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ unnormalized_derivatives[..., -1] = constant
73
+
74
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
+ logabsdet[outside_interval_mask] = 0
76
+ else:
77
+ raise RuntimeError("{} tails are not implemented.".format(tails))
78
+
79
+ (
80
+ outputs[inside_interval_mask],
81
+ logabsdet[inside_interval_mask],
82
+ ) = rational_quadratic_spline(
83
+ inputs=inputs[inside_interval_mask],
84
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
+ inverse=inverse,
88
+ left=-tail_bound,
89
+ right=tail_bound,
90
+ bottom=-tail_bound,
91
+ top=tail_bound,
92
+ min_bin_width=min_bin_width,
93
+ min_bin_height=min_bin_height,
94
+ min_derivative=min_derivative,
95
+ )
96
+
97
+ return outputs, logabsdet
98
+
99
+
100
+ def rational_quadratic_spline(
101
+ inputs,
102
+ unnormalized_widths,
103
+ unnormalized_heights,
104
+ unnormalized_derivatives,
105
+ inverse=False,
106
+ left=0.0,
107
+ right=1.0,
108
+ bottom=0.0,
109
+ top=1.0,
110
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
113
+ ):
114
+ if torch.min(inputs) < left or torch.max(inputs) > right:
115
+ raise ValueError("Input to a transform is not within its domain")
116
+
117
+ num_bins = unnormalized_widths.shape[-1]
118
+
119
+ if min_bin_width * num_bins > 1.0:
120
+ raise ValueError("Minimal bin width too large for the number of bins")
121
+ if min_bin_height * num_bins > 1.0:
122
+ raise ValueError("Minimal bin height too large for the number of bins")
123
+
124
+ widths = F.softmax(unnormalized_widths, dim=-1)
125
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
+ cumwidths = torch.cumsum(widths, dim=-1)
127
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
+ cumwidths = (right - left) * cumwidths + left
129
+ cumwidths[..., 0] = left
130
+ cumwidths[..., -1] = right
131
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
+
133
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
+
135
+ heights = F.softmax(unnormalized_heights, dim=-1)
136
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
+ cumheights = torch.cumsum(heights, dim=-1)
138
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
+ cumheights = (top - bottom) * cumheights + bottom
140
+ cumheights[..., 0] = bottom
141
+ cumheights[..., -1] = top
142
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
143
+
144
+ if inverse:
145
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
146
+ else:
147
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
+
149
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
+
152
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
+ delta = heights / widths
154
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
155
+
156
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
+
159
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
160
+
161
+ if inverse:
162
+ a = (inputs - input_cumheights) * (
163
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
+ ) + input_heights * (input_delta - input_derivatives)
165
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
+ )
168
+ c = -input_delta * (inputs - input_cumheights)
169
+
170
+ discriminant = b.pow(2) - 4 * a * c
171
+ assert (discriminant >= 0).all()
172
+
173
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
174
+ outputs = root * input_bin_widths + input_cumwidths
175
+
176
+ theta_one_minus_theta = root * (1 - root)
177
+ denominator = input_delta + (
178
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179
+ * theta_one_minus_theta
180
+ )
181
+ derivative_numerator = input_delta.pow(2) * (
182
+ input_derivatives_plus_one * root.pow(2)
183
+ + 2 * input_delta * theta_one_minus_theta
184
+ + input_derivatives * (1 - root).pow(2)
185
+ )
186
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187
+
188
+ return outputs, -logabsdet
189
+ else:
190
+ theta = (inputs - input_cumwidths) / input_bin_widths
191
+ theta_one_minus_theta = theta * (1 - theta)
192
+
193
+ numerator = input_heights * (
194
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195
+ )
196
+ denominator = input_delta + (
197
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198
+ * theta_one_minus_theta
199
+ )
200
+ outputs = input_cumheights + numerator / denominator
201
+
202
+ derivative_numerator = input_delta.pow(2) * (
203
+ input_derivatives_plus_one * theta.pow(2)
204
+ + 2 * input_delta * theta_one_minus_theta
205
+ + input_derivatives * (1 - theta).pow(2)
206
+ )
207
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208
+
209
+ return outputs, logabsdet
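The spline is invertible by construction: inside ±tail_bound every bin is a monotone rational-quadratic segment (the inverse branch solves its quadratic in closed form), and outside it the map is the identity. A roundtrip sketch under those assumptions (shapes illustrative; assumes this file is importable as lib.transforms):

import torch
from lib.transforms import piecewise_rational_quadratic_transform

B, T, K = 2, 50, 10                        # batch, frames, num_bins
x = torch.empty(B, T).uniform_(-4.0, 4.0)  # all inputs inside the tail bound
w = torch.randn(B, T, K)                   # unnormalized widths
h = torch.randn(B, T, K)                   # unnormalized heights
d = torch.randn(B, T, K - 1)               # interior derivatives; tails pad the ends
y, logdet = piecewise_rational_quadratic_transform(
    x, w, h, d, inverse=False, tails="linear", tail_bound=5.0
)
x_rec, inv_logdet = piecewise_rational_quadratic_transform(
    y, w, h, d, inverse=True, tails="linear", tail_bound=5.0
)
assert torch.allclose(x, x_rec, atol=1e-4)             # inverse recovers the input
assert torch.allclose(logdet, -inv_logdet, atol=1e-4)  # log-determinants cancel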
model_loader.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import logging
4
+ import requests
5
+ import zipfile
6
+
7
+ from lib.models import (
8
+ SynthesizerTrnMs256NSFsid,
9
+ SynthesizerTrnMs256NSFsid_nono,
10
+ SynthesizerTrnMs768NSFsid,
11
+ SynthesizerTrnMs768NSFsid_nono,
12
+ )
13
+ from src.config import Config
14
+ from src.vc_infer_pipeline import VC
15
+
16
+ logging.getLogger("fairseq").setLevel(logging.WARNING)
17
+
18
+
19
+ class ModelLoader:
20
+ def __init__(self):
21
+ self.model_root = os.path.join(os.getcwd(), "weights")
22
+ self.config = Config()
23
+ self.model_list = [
24
+ d
25
+ for d in os.listdir(self.model_root)
26
+ if os.path.isdir(os.path.join(self.model_root, d))
27
+ ]
28
+ if len(self.model_list) == 0:
29
+ raise ValueError("No model found in `weights` folder")
30
+
31
+ self.model_name = ""
32
+ self.model_list.sort()
33
+
34
+ self.tgt_sr = None
35
+ self.net_g = None
36
+ self.vc = None
37
+ self.version = None
38
+ self.index_file = None
39
+ self.if_f0 = None
40
+
41
+ def _load_from_zip_url(self, url):
42
+ response = requests.get(url)
43
+ file_name = os.path.join(
44
+ self.model_root, os.path.basename(url[: url.index(".zip") + 4])
45
+ )
46
+ model_name = os.path.basename(file_name).replace(".zip", "")
47
+ print(f"Extraacting Model: {model_name}")
48
+
49
+ if response.status_code == 200:
50
+ with open(file_name, "wb") as file:
51
+ file.write(response.content)
52
+
53
+ with zipfile.ZipFile(file_name, "r") as zip_ref:
54
+ zip_ref.extractall(os.path.join(self.model_root, model_name))
55
+ os.remove(file_name)
56
+ else:
57
+ print("Could not download model: {model_name}")
58
+ return model_name
59
+
60
+ def load(self, model_name):
61
+ if "http" in model_name:
62
+ model_name = self._load_from_zip_url(model_name)
63
+
64
+ pth_files = [
65
+ os.path.join(self.model_root, model_name, f)
66
+ for f in os.listdir(os.path.join(self.model_root, model_name))
67
+ if f.endswith(".pth")
68
+ ]
69
+ if len(pth_files) == 0:
70
+ raise ValueError(f"No pth file found in {self.model_root}/{model_name}")
71
+
72
+ self.model_name = model_name
73
+ pth_path = pth_files[0]
74
+ print(f"Loading {pth_path}, model: {model_name}")
75
+
76
+ cpt = torch.load(pth_path, map_location="cpu")
77
+ self.tgt_sr = cpt["config"][-1]
78
+ cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
79
+ self.if_f0 = cpt.get("f0", 1)
80
+ self.version = cpt.get("version", "v1")
81
+
82
+ if self.version == "v1":
83
+ if self.if_f0 == 1:
84
+ self.net_g = SynthesizerTrnMs256NSFsid(
85
+ *cpt["config"], is_half=self.config.is_half
86
+ )
87
+ else:
88
+ self.net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
89
+ elif self.version == "v2":
90
+ if self.if_f0 == 1:
91
+ self.net_g = SynthesizerTrnMs768NSFsid(
92
+ *cpt["config"], is_half=self.config.is_half
93
+ )
94
+ else:
95
+ self.net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
96
+ else:
97
+ raise ValueError("Unknown version")
98
+
99
+ del self.net_g.enc_q
100
+ self.net_g.load_state_dict(cpt["weight"], strict=False)
101
+ print("Model loaded")
102
+ self.net_g.eval().to(self.config.device)
103
+
104
+ if self.config.is_half:
105
+ self.net_g = self.net_g.half()
106
+ else:
107
+ self.net_g = self.net_g.float()
108
+
109
+ self.vc = VC(self.tgt_sr, self.config)
110
+
111
+ index_files = [
112
+ os.path.join(self.model_root, model_name, f)
113
+ for f in os.listdir(os.path.join(self.model_root, model_name))
114
+ if f.endswith(".index")
115
+ ]
116
+
117
+ if len(index_files) == 0:
118
+ print("No index file found")
119
+ self.index_file = ""
120
+ else:
121
+ self.index_file = index_files[0]
122
+ print(f"Index file found: {self.index_file}")
123
+
124
+ def load_hubert(self):
125
+ from fairseq import checkpoint_utils
126
+
127
+ models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
128
+ ["weights/hubert_base.pt"],
129
+ suffix="",
130
+ )
131
+ self.hubert_model = models[0]
132
+ self.hubert_model = self.hubert_model.to(self.config.device)
133
+
134
+ if self.config.is_half:
135
+ self.hubert_model = self.hubert_model.half()
136
+ else:
137
+ self.hubert_model = self.hubert_model.float()
138
+
139
+ return self.hubert_model.eval()
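Typical usage, following the weights/ layout committed below (each model directory holds one .pth plus an optional .index; hubert_base.pt must be placed in weights/ separately):

# usage sketch; "char1" matches the weights/char1/ directory in this repo
loader = ModelLoader()
loader.load("char1")           # also accepts a direct URL to a model .zip
hubert = loader.load_hubert()  # reads weights/hubert_base.pt
print(loader.tgt_sr, loader.version, loader.index_file)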
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ edge_tts==6.1.7
2
+ fairseq==0.12.2
3
+ faiss_cpu==1.7.4
4
+ gradio==3.38.0
5
+ librosa==0.9.1
6
+ numpy==1.22.4
7
+ praat-parselmouth==0.4.3
8
+ pyworld==0.3.4
9
+ torchcrepe==0.0.21
10
+ scikit-learn==1.3.0
11
+ gradio_client==0.8.1
src/config.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+
5
+ class Config:
6
+ def __init__(self):
7
+ self.device = "cuda:0"
8
+ self.is_half = True
9
+ self.n_cpu = 0
10
+ self.gpu_name = None
11
+ self.gpu_mem = None
12
+ self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
13
+
14
+ # has_mps is only available in nightly pytorch (for now) and macOS 12.3+;
15
+ # probe it in a try block for compatibility
16
+ @staticmethod
17
+ def has_mps() -> bool:
18
+ if not torch.backends.mps.is_available():
19
+ return False
20
+ try:
21
+ torch.zeros(1).to(torch.device("mps"))
22
+ return True
23
+ except Exception:
24
+ return False
25
+
26
+ def device_config(self) -> tuple:
27
+ if torch.cuda.is_available():
28
+ i_device = int(self.device.split(":")[-1])
29
+ self.gpu_name = torch.cuda.get_device_name(i_device)
30
+ if (
31
+ ("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
32
+ or "P40" in self.gpu_name.upper()
33
+ or "1060" in self.gpu_name
34
+ or "1070" in self.gpu_name
35
+ or "1080" in self.gpu_name
36
+ ):
37
+ print("Found GPU", self.gpu_name, ", force to fp32")
38
+ self.is_half = False
39
+ else:
40
+ print("Found GPU", self.gpu_name)
41
+ self.gpu_mem = int(
42
+ torch.cuda.get_device_properties(i_device).total_memory
43
+ / 1024
44
+ / 1024
45
+ / 1024
46
+ + 0.4
47
+ )
48
+ elif self.has_mps():
49
+ print("No supported Nvidia GPU found, use MPS instead")
50
+ self.device = "mps"
51
+ self.is_half = False
52
+ else:
53
+ print("No supported Nvidia GPU found, use CPU instead")
54
+ self.device = "cpu"
55
+ self.is_half = False
56
+
57
+ if self.n_cpu == 0:
58
+ self.n_cpu = os.cpu_count()
59
+
60
+ if self.is_half:
61
+ # 6G GPU Memory
62
+ x_pad = 3
63
+ x_query = 10
64
+ x_center = 60
65
+ x_max = 65
66
+ else:
67
+ # 5G GPU Memory
68
+ x_pad = 1
69
+ x_query = 6
70
+ x_center = 38
71
+ x_max = 41
72
+
73
+ if self.gpu_mem is not None and self.gpu_mem <= 4:
74
+ x_pad = 1
75
+ x_query = 5
76
+ x_center = 30
77
+ x_max = 32
78
+
79
+ return x_pad, x_query, x_center, x_max
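The four returned values are durations in seconds: src/vc_infer_pipeline.py multiplies them by the 16 kHz HuBERT rate (t_pad = sr * x_pad and so on), so the fp16 profile pads 3 s around each chunk and looks for a cut roughly every 60 s. A quick probe of what gets picked on the current machine:

# usage sketch: device and precision are decided when Config is constructed
cfg = Config()
print(cfg.device, cfg.is_half)
print(cfg.x_pad, cfg.x_query, cfg.x_center, cfg.x_max)  # all in seconds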
src/rmvpe.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class BiGRU(nn.Module):
8
+ def __init__(self, input_features, hidden_features, num_layers):
9
+ super(BiGRU, self).__init__()
10
+ self.gru = nn.GRU(
11
+ input_features,
12
+ hidden_features,
13
+ num_layers=num_layers,
14
+ batch_first=True,
15
+ bidirectional=True,
16
+ )
17
+
18
+ def forward(self, x):
19
+ return self.gru(x)[0]
20
+
21
+
22
+ class ConvBlockRes(nn.Module):
23
+ def __init__(self, in_channels, out_channels, momentum=0.01):
24
+ super(ConvBlockRes, self).__init__()
25
+ self.conv = nn.Sequential(
26
+ nn.Conv2d(
27
+ in_channels=in_channels,
28
+ out_channels=out_channels,
29
+ kernel_size=(3, 3),
30
+ stride=(1, 1),
31
+ padding=(1, 1),
32
+ bias=False,
33
+ ),
34
+ nn.BatchNorm2d(out_channels, momentum=momentum),
35
+ nn.ReLU(),
36
+ nn.Conv2d(
37
+ in_channels=out_channels,
38
+ out_channels=out_channels,
39
+ kernel_size=(3, 3),
40
+ stride=(1, 1),
41
+ padding=(1, 1),
42
+ bias=False,
43
+ ),
44
+ nn.BatchNorm2d(out_channels, momentum=momentum),
45
+ nn.ReLU(),
46
+ )
47
+ if in_channels != out_channels:
48
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
49
+ self.is_shortcut = True
50
+ else:
51
+ self.is_shortcut = False
52
+
53
+ def forward(self, x):
54
+ if self.is_shortcut:
55
+ return self.conv(x) + self.shortcut(x)
56
+ else:
57
+ return self.conv(x) + x
58
+
59
+
60
+ class Encoder(nn.Module):
61
+ def __init__(
62
+ self,
63
+ in_channels,
64
+ in_size,
65
+ n_encoders,
66
+ kernel_size,
67
+ n_blocks,
68
+ out_channels=16,
69
+ momentum=0.01,
70
+ ):
71
+ super(Encoder, self).__init__()
72
+ self.n_encoders = n_encoders
73
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
74
+ self.layers = nn.ModuleList()
75
+ self.latent_channels = []
76
+ for i in range(self.n_encoders):
77
+ self.layers.append(
78
+ ResEncoderBlock(
79
+ in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
80
+ )
81
+ )
82
+ self.latent_channels.append([out_channels, in_size])
83
+ in_channels = out_channels
84
+ out_channels *= 2
85
+ in_size //= 2
86
+ self.out_size = in_size
87
+ self.out_channel = out_channels
88
+
89
+ def forward(self, x):
90
+ concat_tensors = []
91
+ x = self.bn(x)
92
+ for i in range(self.n_encoders):
93
+ _, x = self.layers[i](x)
94
+ concat_tensors.append(_)
95
+ return x, concat_tensors
96
+
97
+
98
+ class ResEncoderBlock(nn.Module):
99
+ def __init__(
100
+ self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
101
+ ):
102
+ super(ResEncoderBlock, self).__init__()
103
+ self.n_blocks = n_blocks
104
+ self.conv = nn.ModuleList()
105
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
106
+ for i in range(n_blocks - 1):
107
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
108
+ self.kernel_size = kernel_size
109
+ if self.kernel_size is not None:
110
+ self.pool = nn.AvgPool2d(kernel_size=kernel_size)
111
+
112
+ def forward(self, x):
113
+ for i in range(self.n_blocks):
114
+ x = self.conv[i](x)
115
+ if self.kernel_size is not None:
116
+ return x, self.pool(x)
117
+ else:
118
+ return x
119
+
120
+
121
+ class Intermediate(nn.Module):
122
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
123
+ super(Intermediate, self).__init__()
124
+ self.n_inters = n_inters
125
+ self.layers = nn.ModuleList()
126
+ self.layers.append(
127
+ ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
128
+ )
129
+ for i in range(self.n_inters - 1):
130
+ self.layers.append(
131
+ ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
132
+ )
133
+
134
+ def forward(self, x):
135
+ for i in range(self.n_inters):
136
+ x = self.layers[i](x)
137
+ return x
138
+
139
+
140
+ class ResDecoderBlock(nn.Module):
141
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
142
+ super(ResDecoderBlock, self).__init__()
143
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
144
+ self.n_blocks = n_blocks
145
+ self.conv1 = nn.Sequential(
146
+ nn.ConvTranspose2d(
147
+ in_channels=in_channels,
148
+ out_channels=out_channels,
149
+ kernel_size=(3, 3),
150
+ stride=stride,
151
+ padding=(1, 1),
152
+ output_padding=out_padding,
153
+ bias=False,
154
+ ),
155
+ nn.BatchNorm2d(out_channels, momentum=momentum),
156
+ nn.ReLU(),
157
+ )
158
+ self.conv2 = nn.ModuleList()
159
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
160
+ for i in range(n_blocks - 1):
161
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
162
+
163
+ def forward(self, x, concat_tensor):
164
+ x = self.conv1(x)
165
+ x = torch.cat((x, concat_tensor), dim=1)
166
+ for i in range(self.n_blocks):
167
+ x = self.conv2[i](x)
168
+ return x
169
+
170
+
171
+ class Decoder(nn.Module):
172
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
173
+ super(Decoder, self).__init__()
174
+ self.layers = nn.ModuleList()
175
+ self.n_decoders = n_decoders
176
+ for i in range(self.n_decoders):
177
+ out_channels = in_channels // 2
178
+ self.layers.append(
179
+ ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
180
+ )
181
+ in_channels = out_channels
182
+
183
+ def forward(self, x, concat_tensors):
184
+ for i in range(self.n_decoders):
185
+ x = self.layers[i](x, concat_tensors[-1 - i])
186
+ return x
187
+
188
+
189
+ class DeepUnet(nn.Module):
190
+ def __init__(
191
+ self,
192
+ kernel_size,
193
+ n_blocks,
194
+ en_de_layers=5,
195
+ inter_layers=4,
196
+ in_channels=1,
197
+ en_out_channels=16,
198
+ ):
199
+ super(DeepUnet, self).__init__()
200
+ self.encoder = Encoder(
201
+ in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
202
+ )
203
+ self.intermediate = Intermediate(
204
+ self.encoder.out_channel // 2,
205
+ self.encoder.out_channel,
206
+ inter_layers,
207
+ n_blocks,
208
+ )
209
+ self.decoder = Decoder(
210
+ self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
211
+ )
212
+
213
+ def forward(self, x):
214
+ x, concat_tensors = self.encoder(x)
215
+ x = self.intermediate(x)
216
+ x = self.decoder(x, concat_tensors)
217
+ return x
218
+
219
+
220
+ class E2E(nn.Module):
221
+ def __init__(
222
+ self,
223
+ n_blocks,
224
+ n_gru,
225
+ kernel_size,
226
+ en_de_layers=5,
227
+ inter_layers=4,
228
+ in_channels=1,
229
+ en_out_channels=16,
230
+ ):
231
+ super(E2E, self).__init__()
232
+ self.unet = DeepUnet(
233
+ kernel_size,
234
+ n_blocks,
235
+ en_de_layers,
236
+ inter_layers,
237
+ in_channels,
238
+ en_out_channels,
239
+ )
240
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
241
+ self.fc = nn.Sequential(
242
+ BiGRU(3 * 128, 256, n_gru),
243
+ nn.Linear(512, 360),
244
+ nn.Dropout(0.25),
245
+ nn.Sigmoid(),
246
+ )
247
+
248
+ def forward(self, mel):
249
+ mel = mel.transpose(-1, -2).unsqueeze(1)
250
+ x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
251
+ x = self.fc(x)
252
+ return x
253
+
254
+
255
+ from librosa.filters import mel
256
+
257
+
258
+ class MelSpectrogram(torch.nn.Module):
259
+ def __init__(
260
+ self,
261
+ is_half,
262
+ n_mel_channels,
263
+ sampling_rate,
264
+ win_length,
265
+ hop_length,
266
+ n_fft=None,
267
+ mel_fmin=0,
268
+ mel_fmax=None,
269
+ clamp=1e-5,
270
+ ):
271
+ super().__init__()
272
+ n_fft = win_length if n_fft is None else n_fft
273
+ self.hann_window = {}
274
+ mel_basis = mel(
275
+ sr=sampling_rate,
276
+ n_fft=n_fft,
277
+ n_mels=n_mel_channels,
278
+ fmin=mel_fmin,
279
+ fmax=mel_fmax,
280
+ htk=True,
281
+ )
282
+ mel_basis = torch.from_numpy(mel_basis).float()
283
+ self.register_buffer("mel_basis", mel_basis)
284
+ self.n_fft = n_fft  # already resolved to win_length above when None
285
+ self.hop_length = hop_length
286
+ self.win_length = win_length
287
+ self.sampling_rate = sampling_rate
288
+ self.n_mel_channels = n_mel_channels
289
+ self.clamp = clamp
290
+ self.is_half = is_half
291
+
292
+ def forward(self, audio, keyshift=0, speed=1, center=True):
293
+ factor = 2 ** (keyshift / 12)
294
+ n_fft_new = int(np.round(self.n_fft * factor))
295
+ win_length_new = int(np.round(self.win_length * factor))
296
+ hop_length_new = int(np.round(self.hop_length * speed))
297
+ keyshift_key = str(keyshift) + "_" + str(audio.device)
298
+ if keyshift_key not in self.hann_window:
299
+ self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
300
+ audio.device
301
+ )
302
+ fft = torch.stft(
303
+ audio,
304
+ n_fft=n_fft_new,
305
+ hop_length=hop_length_new,
306
+ win_length=win_length_new,
307
+ window=self.hann_window[keyshift_key],
308
+ center=center,
309
+ return_complex=True,
310
+ )
311
+ magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
312
+ if keyshift != 0:
313
+ size = self.n_fft // 2 + 1
314
+ resize = magnitude.size(1)
315
+ if resize < size:
316
+ magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
317
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
318
+ mel_output = torch.matmul(self.mel_basis, magnitude)
319
+ if self.is_half:
320
+ mel_output = mel_output.half()
321
+ log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
322
+ return log_mel_spec
323
+
324
+
325
+ class RMVPE:
326
+ def __init__(self, model_path, is_half, device=None):
327
+ self.resample_kernel = {}
328
+ model = E2E(4, 1, (2, 2))
329
+ ckpt = torch.load(model_path, map_location="cpu")
330
+ model.load_state_dict(ckpt)
331
+ model.eval()
332
+ if is_half:
333
+ model = model.half()
334
+ self.model = model
335
+ self.resample_kernel = {}
336
+ self.is_half = is_half
337
+ if device is None:
338
+ device = "cuda" if torch.cuda.is_available() else "cpu"
339
+ self.device = device
340
+ self.mel_extractor = MelSpectrogram(
341
+ is_half, 128, 16000, 1024, 160, None, 30, 8000
342
+ ).to(device)
343
+ self.model = self.model.to(device)
344
+ cents_mapping = 20 * np.arange(360) + 1997.3794084376191
345
+ self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
346
+
347
+ def mel2hidden(self, mel):
348
+ with torch.no_grad():
349
+ n_frames = mel.shape[-1]
350
+ mel = F.pad(
351
+ mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect"
352
+ )
353
+ hidden = self.model(mel)
354
+ return hidden[:, :n_frames]
355
+
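The reflect padding up to the next multiple of 32 matches the network geometry: the five encoder stages each halve the time axis (2**5 = 32), so the U-Net only accepts frame counts divisible by 32, and the [:, :n_frames] slice crops the padding back off.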
356
+ def decode(self, hidden, thred=0.03):
357
+ cents_pred = self.to_local_average_cents(hidden, thred=thred)
358
+ f0 = 10 * (2 ** (cents_pred / 1200))
359
+ f0[f0 == 10] = 0
360
+ # f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
361
+ return f0
362
+
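decode converts cents to Hz via f0 = 10 * 2 ** (cents / 1200), i.e. a 10 Hz reference pitch. Frames where no bin clears the threshold come back as cents_pred = 0, hence f0 = 10, which the f0[f0 == 10] = 0 line then marks as unvoiced. The 360 bins are spaced 20 cents apart (five per semitone), starting at 1997.38 cents ≈ 31.7 Hz.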
363
+ def infer_from_audio(self, audio, thred=0.03):
364
+ audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
365
+ # torch.cuda.synchronize()
366
+ # t0=ttime()
367
+ mel = self.mel_extractor(audio, center=True)
368
+ # torch.cuda.synchronize()
369
+ # t1=ttime()
370
+ hidden = self.mel2hidden(mel)
371
+ # torch.cuda.synchronize()
372
+ # t2=ttime()
373
+ hidden = hidden.squeeze(0).cpu().numpy()
374
+ if self.is_half:
375
+ hidden = hidden.astype("float32")
376
+ f0 = self.decode(hidden, thred=thred)
377
+ # torch.cuda.synchronize()
378
+ # t3=ttime()
379
+ # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
380
+ return f0
381
+
382
+ def to_local_average_cents(self, salience, thred=0.05):
383
+ # t0 = ttime()
384
+ center = np.argmax(salience, axis=1)  # (n_frames,) index of the most salient bin
385
+ salience = np.pad(salience, ((0, 0), (4, 4)))  # (n_frames, 368)
386
+ # t1 = ttime()
387
+ center += 4
388
+ todo_salience = []
389
+ todo_cents_mapping = []
390
+ starts = center - 4
391
+ ends = center + 5
392
+ for idx in range(salience.shape[0]):
393
+ todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
394
+ todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
395
+ # t2 = ttime()
396
+ todo_salience = np.array(todo_salience)  # (n_frames, 9)
397
+ todo_cents_mapping = np.array(todo_cents_mapping)  # (n_frames, 9)
398
+ product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
399
+ weight_sum = np.sum(todo_salience, 1)  # (n_frames,)
400
+ divided = product_sum / weight_sum  # (n_frames,) salience-weighted average in cents
401
+ # t3 = ttime()
402
+ maxx = np.max(salience, axis=1)  # (n_frames,)
403
+ divided[maxx <= thred] = 0  # below-threshold frames count as unvoiced
404
+ # t4 = ttime()
405
+ # print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
406
+ return divided
407
+
408
+
409
+ # if __name__ == '__main__':
410
+ # audio, sampling_rate = sf.read("卢本伟语录~1.wav")
411
+ # if len(audio.shape) > 1:
412
+ # audio = librosa.to_mono(audio.transpose(1, 0))
413
+ # audio_bak = audio.copy()
414
+ # if sampling_rate != 16000:
415
+ # audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
416
+ # model_path = "/bili-coeus/jupyter/jupyterhub-liujing04/vits_ch/test-RMVPE/weights/rmvpe_llc_half.pt"
417
+ # thred = 0.03 # 0.01
418
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
419
+ # rmvpe = RMVPE(model_path,is_half=False, device=device)
420
+ # t0=ttime()
421
+ # f0 = rmvpe.infer_from_audio(audio, thred=thred)
422
+ # f0 = rmvpe.infer_from_audio(audio, thred=thred)
423
+ # f0 = rmvpe.infer_from_audio(audio, thred=thred)
424
+ # f0 = rmvpe.infer_from_audio(audio, thred=thred)
425
+ # f0 = rmvpe.infer_from_audio(audio, thred=thred)
426
+ # t1=ttime()
427
+ # print(f0.shape,t1-t0)
src/vc_infer_pipeline.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import sys
4
+ import parselmouth
5
+ import torch
6
+ from time import time as ttime
7
+ from scipy import signal
8
+ import pyworld, traceback, faiss, librosa, torchcrepe
9
+
10
+ from functools import lru_cache
11
+
12
+ now_dir = os.getcwd()
13
+ sys.path.append(now_dir)
14
+
15
+ bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
16
+
17
+ input_audio_path2wav = {}
18
+
19
+
20
+ @lru_cache
21
+ def cache_harvest_f0(input_audio_path, fs, f0max, f0min, frame_period):
22
+ audio = input_audio_path2wav[input_audio_path]
23
+ f0, t = pyworld.harvest(
24
+ audio,
25
+ fs=fs,
26
+ f0_ceil=f0max,
27
+ f0_floor=f0min,
28
+ frame_period=frame_period,
29
+ )
30
+ f0 = pyworld.stonemask(audio, f0, t, fs)
31
+ return f0
32
+
33
+
34
+ def change_rms(data1, sr1, data2, sr2, rate):  # 1 is the input audio, 2 the output audio; rate is the weight of 2
35
+ # print(data1.max(),data2.max())
36
+ rms1 = librosa.feature.rms(
37
+ y=data1, frame_length=sr1 // 2 * 2, hop_length=sr1 // 2
38
+ )  # one point every half second
39
+ rms2 = librosa.feature.rms(y=data2, frame_length=sr2 // 2 * 2, hop_length=sr2 // 2)
40
+ rms1 = torch.from_numpy(rms1)
41
+ rms1 = torch.nn.functional.interpolate(
42
+ rms1.unsqueeze(0), size=data2.shape[0], mode="linear"
43
+ ).squeeze()
44
+ rms2 = torch.from_numpy(rms2)
45
+ rms2 = torch.nn.functional.interpolate(
46
+ rms2.unsqueeze(0), size=data2.shape[0], mode="linear"
47
+ ).squeeze()
48
+ rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-6)
49
+ data2 *= (
50
+ torch.pow(rms1, torch.tensor(1 - rate))
51
+ * torch.pow(rms2, torch.tensor(rate - 1))
52
+ ).numpy()
53
+ return data2
54
+
55
+
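Note the exponents: the multiplier rms1 ** (1 - rate) * rms2 ** (rate - 1) equals (rms1 / rms2) ** (1 - rate), so rate=1 leaves the converted audio untouched while rate=0 fully imposes the input's half-second loudness envelope on the output (both envelopes are interpolated to per-sample resolution first).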
56
+ class VC(object):
57
+ def __init__(self, tgt_sr, config):
58
+ self.x_pad, self.x_query, self.x_center, self.x_max, self.is_half = (
59
+ config.x_pad,
60
+ config.x_query,
61
+ config.x_center,
62
+ config.x_max,
63
+ config.is_half,
64
+ )
65
+ self.sr = 16000  # hubert input sample rate
66
+ self.window = 160  # samples per frame
67
+ self.t_pad = self.sr * self.x_pad  # padding before/after each segment
68
+ self.t_pad_tgt = tgt_sr * self.x_pad
69
+ self.t_pad2 = self.t_pad * 2
70
+ self.t_query = self.sr * self.x_query  # search window around each cut point
71
+ self.t_center = self.sr * self.x_center  # spacing between cut-point queries
72
+ self.t_max = self.sr * self.x_max  # duration threshold below which no cutting is needed
73
+ self.device = config.device
74
+
75
+ def get_f0(
76
+ self,
77
+ input_audio_path,
78
+ x,
79
+ p_len,
80
+ f0_up_key,
81
+ f0_method,
82
+ filter_radius,
83
+ inp_f0=None,
84
+ ):
85
+ global input_audio_path2wav
86
+ time_step = self.window / self.sr * 1000
87
+ f0_min = 50
88
+ f0_max = 1100
89
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
90
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
91
+ if f0_method == "pm":
92
+ f0 = (
93
+ parselmouth.Sound(x, self.sr)
94
+ .to_pitch_ac(
95
+ time_step=time_step / 1000,
96
+ voicing_threshold=0.6,
97
+ pitch_floor=f0_min,
98
+ pitch_ceiling=f0_max,
99
+ )
100
+ .selected_array["frequency"]
101
+ )
102
+ pad_size = (p_len - len(f0) + 1) // 2
103
+ if pad_size > 0 or p_len - len(f0) - pad_size > 0:
104
+ f0 = np.pad(
105
+ f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant"
106
+ )
107
+ elif f0_method == "harvest":
108
+ input_audio_path2wav[input_audio_path] = x.astype(np.double)
109
+ f0 = cache_harvest_f0(input_audio_path, self.sr, f0_max, f0_min, 10)
110
+ if filter_radius > 2:
111
+ f0 = signal.medfilt(f0, 3)
112
+ elif f0_method == "crepe":
113
+ model = "full"
114
+ # Pick a batch size that doesn't cause memory errors on your gpu
115
+ batch_size = 512
116
+ # Compute pitch using first gpu
117
+ audio = torch.tensor(np.copy(x))[None].float()
118
+ f0, pd = torchcrepe.predict(
119
+ audio,
120
+ self.sr,
121
+ self.window,
122
+ f0_min,
123
+ f0_max,
124
+ model,
125
+ batch_size=batch_size,
126
+ device=self.device,
127
+ return_periodicity=True,
128
+ )
129
+ pd = torchcrepe.filter.median(pd, 3)
130
+ f0 = torchcrepe.filter.mean(f0, 3)
131
+ f0[pd < 0.1] = 0
132
+ f0 = f0[0].cpu().numpy()
133
+ elif f0_method == "rmvpe":
134
+ if hasattr(self, "model_rmvpe") == False:
135
+ from src.rmvpe import RMVPE
136
+
137
+ print("loading rmvpe model")
138
+ self.model_rmvpe = RMVPE(
139
+ "weights/rmvpe.pt", is_half=self.is_half, device=self.device
140
+ )
141
+ f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
142
+ f0 *= pow(2, f0_up_key / 12)
143
+ # with open("test.txt","w")as f:f.write("\n".join([str(i)for i in f0.tolist()]))
144
+ tf0 = self.sr // self.window  # f0 points per second
145
+ if inp_f0 is not None:
146
+ delta_t = np.round(
147
+ (inp_f0[:, 0].max() - inp_f0[:, 0].min()) * tf0 + 1
148
+ ).astype("int16")
149
+ replace_f0 = np.interp(
150
+ list(range(delta_t)), inp_f0[:, 0] * 100, inp_f0[:, 1]
151
+ )
152
+ shape = f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)].shape[0]
153
+ f0[self.x_pad * tf0 : self.x_pad * tf0 + len(replace_f0)] = replace_f0[
154
+ :shape
155
+ ]
156
+ # with open("test_opt.txt","w")as f:f.write("\n".join([str(i)for i in f0.tolist()]))
157
+ f0bak = f0.copy()
158
+ f0_mel = 1127 * np.log(1 + f0 / 700)
159
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (
160
+ f0_mel_max - f0_mel_min
161
+ ) + 1
162
+ f0_mel[f0_mel <= 1] = 1
163
+ f0_mel[f0_mel > 255] = 255
164
+ f0_coarse = np.rint(f0_mel).astype(np.int32)  # np.int was removed in NumPy >= 1.24
165
+ return f0_coarse, f0bak # 1-0
166
+
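As a worked example of the coarse quantization at the end of get_f0: for f0 = 440 Hz, f0_mel = 1127 * ln(1 + 440/700) ≈ 549.7; with f0_mel_min ≈ 77.8 (50 Hz) and f0_mel_max ≈ 1064.4 (1100 Hz), the rescaled value is (549.7 - 77.8) * 254 / 986.7 + 1 ≈ 122.5, so A4 lands near bin 122 of the 1..255 range returned as f0_coarse.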
167
+ def vc(
168
+ self,
169
+ model,
170
+ net_g,
171
+ sid,
172
+ audio0,
173
+ pitch,
174
+ pitchf,
175
+ times,
176
+ index,
177
+ big_npy,
178
+ index_rate,
179
+ version,
180
+ protect,
181
+ ): # ,file_index,file_big_npy
182
+ feats = torch.from_numpy(audio0)
183
+ if self.is_half:
184
+ feats = feats.half()
185
+ else:
186
+ feats = feats.float()
187
+ if feats.dim() == 2: # double channels
188
+ feats = feats.mean(-1)
189
+ assert feats.dim() == 1, feats.dim()
190
+ feats = feats.view(1, -1)
191
+ padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
192
+
193
+ inputs = {
194
+ "source": feats.to(self.device),
195
+ "padding_mask": padding_mask,
196
+ "output_layer": 9 if version == "v1" else 12,
197
+ }
198
+ t0 = ttime()
199
+ with torch.no_grad():
200
+ logits = model.extract_features(**inputs)
201
+ feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
202
+ if protect < 0.5 and pitch is not None and pitchf is not None:
203
+ feats0 = feats.clone()
204
+ if (
205
+ index is not None
206
+ and big_npy is not None
207
+ and index_rate != 0
208
+ ):
209
+ npy = feats[0].cpu().numpy()
210
+ if self.is_half:
211
+ npy = npy.astype("float32")
212
+
213
+ # _, I = index.search(npy, 1)
214
+ # npy = big_npy[I.squeeze()]
215
+
216
+ score, ix = index.search(npy, k=8)
217
+ weight = np.square(1 / score)
218
+ weight /= weight.sum(axis=1, keepdims=True)
219
+ npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
220
+
221
+ if self.is_half:
222
+ npy = npy.astype("float16")
223
+ feats = (
224
+ torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate
225
+ + (1 - index_rate) * feats
226
+ )
227
+
228
+ feats = torch.nn.functional.interpolate(
229
+ feats.permute(0, 2, 1), scale_factor=2
230
+ ).permute(0, 2, 1)
231
+ if protect < 0.5 and pitch is not None and pitchf is not None:
232
+ feats0 = torch.nn.functional.interpolate(
233
+ feats0.permute(0, 2, 1), scale_factor=2
234
+ ).permute(0, 2, 1)
235
+ t1 = ttime()
236
+ p_len = audio0.shape[0] // self.window
237
+ if feats.shape[1] < p_len:
238
+ p_len = feats.shape[1]
239
+ if pitch is not None and pitchf is not None:
240
+ pitch = pitch[:, :p_len]
241
+ pitchf = pitchf[:, :p_len]
242
+
243
+ if protect < 0.5 and pitch is not None and pitchf is not None:
244
+ pitchff = pitchf.clone()
245
+ pitchff[pitchf > 0] = 1
246
+ pitchff[pitchf < 1] = protect
247
+ pitchff = pitchff.unsqueeze(-1)
248
+ feats = feats * pitchff + feats0 * (1 - pitchff)
249
+ feats = feats.to(feats0.dtype)
250
+ p_len = torch.tensor([p_len], device=self.device).long()
251
+ with torch.no_grad():
252
+ if pitch is not None and pitchf is not None:
253
+ audio1 = (
254
+ (net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0])
255
+ .data.cpu()
256
+ .float()
257
+ .numpy()
258
+ )
259
+ else:
260
+ audio1 = (
261
+ (net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy()
262
+ )
263
+ del feats, p_len, padding_mask
264
+ if torch.cuda.is_available():
265
+ torch.cuda.empty_cache()
266
+ t2 = ttime()
267
+ times[0] += t1 - t0
268
+ times[2] += t2 - t1
269
+ return audio1
270
+
271
+ def pipeline(
272
+ self,
273
+ model,
274
+ net_g,
275
+ sid,
276
+ audio,
277
+ input_audio_path,
278
+ times,
279
+ f0_up_key,
280
+ f0_method,
281
+ file_index,
282
+ # file_big_npy,
283
+ index_rate,
284
+ if_f0,
285
+ filter_radius,
286
+ tgt_sr,
287
+ resample_sr,
288
+ rms_mix_rate,
289
+ version,
290
+ protect,
291
+ f0_file=None,
292
+ ):
293
+ if (
294
+ file_index != ""
295
+ # and file_big_npy != ""
296
+ # and os.path.exists(file_big_npy) == True
297
+ and os.path.exists(file_index)
298
+ and index_rate != 0
299
+ ):
300
+ try:
301
+ index = faiss.read_index(file_index)
302
+ # big_npy = np.load(file_big_npy)
303
+ big_npy = index.reconstruct_n(0, index.ntotal)
304
+ except Exception:
305
+ traceback.print_exc()
306
+ index = big_npy = None
307
+ else:
308
+ index = big_npy = None
309
+ audio = signal.filtfilt(bh, ah, audio)
310
+ audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
311
+ opt_ts = []
312
+ if audio_pad.shape[0] > self.t_max:
313
+ audio_sum = np.zeros_like(audio)
314
+ for i in range(self.window):
315
+ audio_sum += audio_pad[i : i - self.window]
316
+ for t in range(self.t_center, audio.shape[0], self.t_center):
317
+ opt_ts.append(
318
+ t
319
+ - self.t_query
320
+ + np.where(
321
+ np.abs(audio_sum[t - self.t_query : t + self.t_query])
322
+ == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min()
323
+ )[0][0]
324
+ )
325
+ s = 0
326
+ audio_opt = []
327
+ t = None
328
+ t1 = ttime()
329
+ audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
330
+ p_len = audio_pad.shape[0] // self.window
331
+ inp_f0 = None
332
+ if hasattr(f0_file, "name") == True:
333
+ try:
334
+ with open(f0_file.name, "r") as f:
335
+ lines = f.read().strip("\n").split("\n")
336
+ inp_f0 = []
337
+ for line in lines:
338
+ inp_f0.append([float(i) for i in line.split(",")])
339
+ inp_f0 = np.array(inp_f0, dtype="float32")
340
+ except Exception:
341
+ traceback.print_exc()
342
+ sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
343
+ pitch, pitchf = None, None
344
+ if if_f0 == 1:
345
+ pitch, pitchf = self.get_f0(
346
+ input_audio_path,
347
+ audio_pad,
348
+ p_len,
349
+ f0_up_key,
350
+ f0_method,
351
+ filter_radius,
352
+ inp_f0,
353
+ )
354
+ pitch = pitch[:p_len]
355
+ pitchf = pitchf[:p_len]
356
+ if self.device == "mps":
357
+ pitchf = pitchf.astype(np.float32)
358
+ pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
359
+ pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
360
+ t2 = ttime()
361
+ times[1] += t2 - t1
362
+ for t in opt_ts:
363
+ t = t // self.window * self.window
364
+ if if_f0 == 1:
365
+ audio_opt.append(
366
+ self.vc(
367
+ model,
368
+ net_g,
369
+ sid,
370
+ audio_pad[s : t + self.t_pad2 + self.window],
371
+ pitch[:, s // self.window : (t + self.t_pad2) // self.window],
372
+ pitchf[:, s // self.window : (t + self.t_pad2) // self.window],
373
+ times,
374
+ index,
375
+ big_npy,
376
+ index_rate,
377
+ version,
378
+ protect,
379
+ )[self.t_pad_tgt : -self.t_pad_tgt]
380
+ )
381
+ else:
382
+ audio_opt.append(
383
+ self.vc(
384
+ model,
385
+ net_g,
386
+ sid,
387
+ audio_pad[s : t + self.t_pad2 + self.window],
388
+ None,
389
+ None,
390
+ times,
391
+ index,
392
+ big_npy,
393
+ index_rate,
394
+ version,
395
+ protect,
396
+ )[self.t_pad_tgt : -self.t_pad_tgt]
397
+ )
398
+ s = t
399
+ if if_f0 == 1:
400
+ audio_opt.append(
401
+ self.vc(
402
+ model,
403
+ net_g,
404
+ sid,
405
+ audio_pad[t:],
406
+ pitch[:, t // self.window :] if t is not None else pitch,
407
+ pitchf[:, t // self.window :] if t is not None else pitchf,
408
+ times,
409
+ index,
410
+ big_npy,
411
+ index_rate,
412
+ version,
413
+ protect,
414
+ )[self.t_pad_tgt : -self.t_pad_tgt]
415
+ )
416
+ else:
417
+ audio_opt.append(
418
+ self.vc(
419
+ model,
420
+ net_g,
421
+ sid,
422
+ audio_pad[t:],
423
+ None,
424
+ None,
425
+ times,
426
+ index,
427
+ big_npy,
428
+ index_rate,
429
+ version,
430
+ protect,
431
+ )[self.t_pad_tgt : -self.t_pad_tgt]
432
+ )
433
+ audio_opt = np.concatenate(audio_opt)
434
+ if rms_mix_rate != 1:
435
+ audio_opt = change_rms(audio, 16000, audio_opt, tgt_sr, rms_mix_rate)
436
+ if resample_sr >= 16000 and tgt_sr != resample_sr:
437
+ audio_opt = librosa.resample(
438
+ audio_opt, orig_sr=tgt_sr, target_sr=resample_sr
439
+ )
440
+ audio_max = np.abs(audio_opt).max() / 0.99
441
+ max_int16 = 32768
442
+ if audio_max > 1:
443
+ max_int16 /= audio_max
444
+ audio_opt = (audio_opt * max_int16).astype(np.int16)
445
+ del pitch, pitchf, sid
446
+ if torch.cuda.is_available():
447
+ torch.cuda.empty_cache()
448
+ return audio_opt
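End to end, the pieces above compose roughly as follows (a hedged sketch: "input.wav" is a placeholder, the parameter values are illustrative defaults, and times is the three-slot accumulator the pipeline mutates with feature/f0/synthesis seconds):

import librosa
from model_loader import ModelLoader

loader = ModelLoader()
loader.load("char1")
hubert = loader.load_hubert()

audio, _ = librosa.load("input.wav", sr=16000, mono=True)  # mono float32 @ 16 kHz
times = [0.0, 0.0, 0.0]  # [feature extraction, f0, synthesis]
out = loader.vc.pipeline(
    hubert, loader.net_g, 0, audio, "input.wav", times,
    f0_up_key=0, f0_method="rmvpe", file_index=loader.index_file,
    index_rate=0.75, if_f0=loader.if_f0, filter_radius=3,
    tgt_sr=loader.tgt_sr, resample_sr=0, rms_mix_rate=0.25,
    version=loader.version, protect=0.33,
)
# out is an int16 waveform at loader.tgt_sr (resample_sr below 16000 is ignored)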
weights/.gitkeep ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ venv/
2
+ __pycache__/
3
+ *.mp3
4
+ hubert_base.pt
5
+ rmvpe.pt
weights/char1/metadata.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc8b7d9d55e55fc7c492217983aef600f1b2be4167afc4ea5436326729966f8e
3
+ size 2895
weights/char1/model.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f54918d9f21b6ccf33c4e8e29a6dca0cb34909caab98011481651026b6bbdb9a
3
+ size 270186979
weights/char1/model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3198f1f28943e63f21bbc98878cfbc322e3f30b78555bc39c5e6fcc265d6609d
3
+ size 57589983