Spaces:
Running
on
Zero
Running
on
Zero
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +44 -0
- app.py +253 -0
- bert/bert_models.json +14 -0
- bert/chinese-roberta-wwm-ext-large/.gitattributes +9 -0
- bert/chinese-roberta-wwm-ext-large/README.md +57 -0
- bert/chinese-roberta-wwm-ext-large/added_tokens.json +1 -0
- bert/chinese-roberta-wwm-ext-large/config.json +28 -0
- bert/chinese-roberta-wwm-ext-large/pytorch_model.bin +3 -0
- bert/chinese-roberta-wwm-ext-large/special_tokens_map.json +1 -0
- bert/chinese-roberta-wwm-ext-large/tokenizer.json +0 -0
- bert/chinese-roberta-wwm-ext-large/tokenizer_config.json +1 -0
- bert/chinese-roberta-wwm-ext-large/vocab.txt +0 -0
- bert/deberta-v2-large-japanese-char-wwm/.gitattributes +34 -0
- bert/deberta-v2-large-japanese-char-wwm/README.md +89 -0
- bert/deberta-v2-large-japanese-char-wwm/config.json +37 -0
- bert/deberta-v2-large-japanese-char-wwm/pytorch_model.bin +3 -0
- bert/deberta-v2-large-japanese-char-wwm/special_tokens_map.json +7 -0
- bert/deberta-v2-large-japanese-char-wwm/tokenizer_config.json +19 -0
- bert/deberta-v2-large-japanese-char-wwm/vocab.txt +0 -0
- bert/deberta-v3-large/.gitattributes +27 -0
- bert/deberta-v3-large/README.md +93 -0
- bert/deberta-v3-large/config.json +22 -0
- bert/deberta-v3-large/generator_config.json +22 -0
- bert/deberta-v3-large/pytorch_model.bin +3 -0
- bert/deberta-v3-large/pytorch_model.bin.bin +3 -0
- bert/deberta-v3-large/spm.model +3 -0
- bert/deberta-v3-large/tokenizer_config.json +4 -0
- chupa_examples.txt +0 -0
- model_assets/chupa_1/chupa_1spk_e1000_s194312.safetensors +3 -0
- model_assets/chupa_1/config.json +87 -0
- model_assets/chupa_1/style_vectors.npy +3 -0
- requirements.txt +23 -0
- style_bert_vits2/.editorconfig +15 -0
- style_bert_vits2/__init__.py +0 -0
- style_bert_vits2/constants.py +48 -0
- style_bert_vits2/logging.py +15 -0
- style_bert_vits2/models/__init__.py +0 -0
- style_bert_vits2/models/attentions.py +491 -0
- style_bert_vits2/models/commons.py +223 -0
- style_bert_vits2/models/hyper_parameters.py +129 -0
- style_bert_vits2/models/infer.py +308 -0
- style_bert_vits2/models/models.py +1102 -0
- style_bert_vits2/models/models_jp_extra.py +1157 -0
- style_bert_vits2/models/modules.py +642 -0
- style_bert_vits2/models/monotonic_alignment.py +89 -0
- style_bert_vits2/models/transforms.py +215 -0
- style_bert_vits2/models/utils/__init__.py +264 -0
- style_bert_vits2/models/utils/checkpoints.py +202 -0
- style_bert_vits2/models/utils/safetensors.py +91 -0
- style_bert_vits2/nlp/__init__.py +120 -0
.gitignore
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
venv/
|
3 |
+
.venv/
|
4 |
+
dist/
|
5 |
+
.coverage
|
6 |
+
.ipynb_checkpoints/
|
7 |
+
.ruff_cache/
|
8 |
+
|
9 |
+
/*.yml
|
10 |
+
!/default_config.yml
|
11 |
+
# /bert/*/*.bin
|
12 |
+
# /bert/*/*.h5
|
13 |
+
# /bert/*/*.model
|
14 |
+
# /bert/*/*.safetensors
|
15 |
+
# /bert/*/*.msgpack
|
16 |
+
|
17 |
+
/configs/paths.yml
|
18 |
+
|
19 |
+
/pretrained/*.safetensors
|
20 |
+
/pretrained/*.pth
|
21 |
+
|
22 |
+
/pretrained_jp_extra/*.safetensors
|
23 |
+
/pretrained_jp_extra/*.pth
|
24 |
+
|
25 |
+
/slm/*/*.bin
|
26 |
+
|
27 |
+
/scripts/test/
|
28 |
+
/scripts/lib/
|
29 |
+
/scripts/Style-Bert-VITS2/
|
30 |
+
/scripts/sbv2/
|
31 |
+
*.zip
|
32 |
+
*.csv
|
33 |
+
*.bak
|
34 |
+
/mos_results/
|
35 |
+
|
36 |
+
safetensors.ipynb
|
37 |
+
*.wav
|
38 |
+
/static/
|
39 |
+
|
40 |
+
# pyopenjtalk's dictionary
|
41 |
+
*.dic
|
42 |
+
|
43 |
+
playground.ipynb
|
44 |
+
playgrounds/
|
app.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
from pathlib import Path
|
3 |
+
import gradio as gr
|
4 |
+
import random
|
5 |
+
from style_bert_vits2.constants import (
|
6 |
+
DEFAULT_LENGTH,
|
7 |
+
DEFAULT_LINE_SPLIT,
|
8 |
+
DEFAULT_NOISE,
|
9 |
+
DEFAULT_NOISEW,
|
10 |
+
DEFAULT_SPLIT_INTERVAL,
|
11 |
+
)
|
12 |
+
from style_bert_vits2.logging import logger
|
13 |
+
from style_bert_vits2.models.infer import InvalidToneError
|
14 |
+
from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk
|
15 |
+
from style_bert_vits2.tts_model import TTSModelHolder
|
16 |
+
|
17 |
+
|
18 |
+
pyopenjtalk.initialize_worker()
|
19 |
+
|
20 |
+
example_file = "chupa_examples.txt"
|
21 |
+
|
22 |
+
initial_text = (
|
23 |
+
"ちゅぱ、ちゅるる、ぢゅ、んく、れーれゅれろれろれろ、じゅぽぽぽぽぽ……ちゅううう!"
|
24 |
+
)
|
25 |
+
|
26 |
+
with open(example_file, "r", encoding="utf-8") as f:
|
27 |
+
examples = f.read().splitlines()
|
28 |
+
|
29 |
+
|
30 |
+
def get_random_text() -> str:
|
31 |
+
return random.choice(examples)
|
32 |
+
|
33 |
+
|
34 |
+
initial_md = """
|
35 |
+
# チュパ音合成デモ
|
36 |
+
|
37 |
+
2024-07-07: initial ver
|
38 |
+
"""
|
39 |
+
|
40 |
+
|
41 |
+
def make_interactive():
|
42 |
+
return gr.update(interactive=True, value="音声合成")
|
43 |
+
|
44 |
+
|
45 |
+
def make_non_interactive():
|
46 |
+
return gr.update(interactive=False, value="音声合成(モデルをロードしてください)")
|
47 |
+
|
48 |
+
|
49 |
+
def gr_util(item):
|
50 |
+
if item == "プリセットから選ぶ":
|
51 |
+
return (gr.update(visible=True), gr.Audio(visible=False, value=None))
|
52 |
+
else:
|
53 |
+
return (gr.update(visible=False), gr.update(visible=True))
|
54 |
+
|
55 |
+
|
56 |
+
def create_inference_app(model_holder: TTSModelHolder) -> gr.Blocks:
|
57 |
+
def tts_fn(
|
58 |
+
model_name,
|
59 |
+
model_path,
|
60 |
+
text,
|
61 |
+
language,
|
62 |
+
sdp_ratio,
|
63 |
+
noise_scale,
|
64 |
+
noise_scale_w,
|
65 |
+
length_scale,
|
66 |
+
line_split,
|
67 |
+
split_interval,
|
68 |
+
speaker,
|
69 |
+
):
|
70 |
+
model_holder.get_model(model_name, model_path)
|
71 |
+
assert model_holder.current_model is not None
|
72 |
+
|
73 |
+
speaker_id = model_holder.current_model.spk2id[speaker]
|
74 |
+
|
75 |
+
start_time = datetime.datetime.now()
|
76 |
+
|
77 |
+
try:
|
78 |
+
sr, audio = model_holder.current_model.infer(
|
79 |
+
text=text,
|
80 |
+
language=language,
|
81 |
+
sdp_ratio=sdp_ratio,
|
82 |
+
noise=noise_scale,
|
83 |
+
noise_w=noise_scale_w,
|
84 |
+
length=length_scale,
|
85 |
+
line_split=line_split,
|
86 |
+
split_interval=split_interval,
|
87 |
+
speaker_id=speaker_id,
|
88 |
+
)
|
89 |
+
except InvalidToneError as e:
|
90 |
+
logger.error(f"Tone error: {e}")
|
91 |
+
return f"Error: アクセント指定が不正です:\n{e}", None
|
92 |
+
except ValueError as e:
|
93 |
+
logger.error(f"Value error: {e}")
|
94 |
+
return f"Error: {e}", None
|
95 |
+
|
96 |
+
end_time = datetime.datetime.now()
|
97 |
+
duration = (end_time - start_time).total_seconds()
|
98 |
+
|
99 |
+
message = f"Success, time: {duration} seconds."
|
100 |
+
return message, (sr, audio)
|
101 |
+
|
102 |
+
def get_model_files(model_name: str):
|
103 |
+
return [str(f) for f in model_holder.model_files_dict[model_name]]
|
104 |
+
|
105 |
+
model_names = model_holder.model_names
|
106 |
+
if len(model_names) == 0:
|
107 |
+
logger.error(
|
108 |
+
f"モデルが見つかりませんでした。{model_holder.root_dir}にモデルを置いてください。"
|
109 |
+
)
|
110 |
+
with gr.Blocks() as app:
|
111 |
+
gr.Markdown(
|
112 |
+
f"Error: モデルが見つかりませんでした。{model_holder.root_dir}にモデルを置いてください。"
|
113 |
+
)
|
114 |
+
return app
|
115 |
+
|
116 |
+
initial_pth_files = get_model_files(model_names[0])
|
117 |
+
model = model_holder.get_model(model_names[0], initial_pth_files[0])
|
118 |
+
speakers = list(model.spk2id.keys())
|
119 |
+
|
120 |
+
with gr.Blocks(theme="ParityError/Anime") as app:
|
121 |
+
gr.Markdown(initial_md)
|
122 |
+
with gr.Row():
|
123 |
+
with gr.Column():
|
124 |
+
with gr.Row():
|
125 |
+
with gr.Column(scale=3):
|
126 |
+
model_name = gr.Dropdown(
|
127 |
+
label="モデル一覧",
|
128 |
+
choices=model_names,
|
129 |
+
value=model_names[0],
|
130 |
+
)
|
131 |
+
model_path = gr.Dropdown(
|
132 |
+
label="モデルファイル",
|
133 |
+
choices=initial_pth_files,
|
134 |
+
value=initial_pth_files[0],
|
135 |
+
)
|
136 |
+
refresh_button = gr.Button("更新", scale=1, visible=False)
|
137 |
+
load_button = gr.Button("ロード", scale=1, variant="primary")
|
138 |
+
with gr.Row():
|
139 |
+
text_input = gr.TextArea(
|
140 |
+
label="テキスト", value=initial_text, scale=3
|
141 |
+
)
|
142 |
+
random_button = gr.Button("例から選ぶ 🎲", scale=1)
|
143 |
+
random_button.click(get_random_text, outputs=[text_input])
|
144 |
+
with gr.Row():
|
145 |
+
length_scale = gr.Slider(
|
146 |
+
minimum=0.1,
|
147 |
+
maximum=2,
|
148 |
+
value=DEFAULT_LENGTH,
|
149 |
+
step=0.1,
|
150 |
+
label="生成音声の長さ(Length)",
|
151 |
+
)
|
152 |
+
sdp_ratio = gr.Slider(
|
153 |
+
minimum=0,
|
154 |
+
maximum=1,
|
155 |
+
value=1,
|
156 |
+
step=0.1,
|
157 |
+
label="SDP Ratio",
|
158 |
+
)
|
159 |
+
line_split = gr.Checkbox(
|
160 |
+
label="改行で分けて生成(分けたほうが感情が乗ります)",
|
161 |
+
value=DEFAULT_LINE_SPLIT,
|
162 |
+
visible=False,
|
163 |
+
)
|
164 |
+
split_interval = gr.Slider(
|
165 |
+
minimum=0.0,
|
166 |
+
maximum=2,
|
167 |
+
value=DEFAULT_SPLIT_INTERVAL,
|
168 |
+
step=0.1,
|
169 |
+
label="改行ごとに挟む無音の長さ(秒)",
|
170 |
+
)
|
171 |
+
line_split.change(
|
172 |
+
lambda x: (gr.Slider(visible=x)),
|
173 |
+
inputs=[line_split],
|
174 |
+
outputs=[split_interval],
|
175 |
+
)
|
176 |
+
language = gr.Dropdown(
|
177 |
+
choices=["JP"], value="JP", label="Language", visible=False
|
178 |
+
)
|
179 |
+
speaker = gr.Dropdown(label="話者", choices=speakers, value=speakers[0])
|
180 |
+
with gr.Accordion(label="詳細設定", open=True):
|
181 |
+
noise_scale = gr.Slider(
|
182 |
+
minimum=0.1,
|
183 |
+
maximum=2,
|
184 |
+
value=DEFAULT_NOISE,
|
185 |
+
step=0.1,
|
186 |
+
label="Noise",
|
187 |
+
)
|
188 |
+
noise_scale_w = gr.Slider(
|
189 |
+
minimum=0.1,
|
190 |
+
maximum=2,
|
191 |
+
value=DEFAULT_NOISEW,
|
192 |
+
step=0.1,
|
193 |
+
label="Noise_W",
|
194 |
+
)
|
195 |
+
with gr.Column():
|
196 |
+
tts_button = gr.Button("音声合成", variant="primary")
|
197 |
+
text_output = gr.Textbox(label="情報")
|
198 |
+
audio_output = gr.Audio(label="結果")
|
199 |
+
|
200 |
+
tts_button.click(
|
201 |
+
tts_fn,
|
202 |
+
inputs=[
|
203 |
+
model_name,
|
204 |
+
model_path,
|
205 |
+
text_input,
|
206 |
+
language,
|
207 |
+
sdp_ratio,
|
208 |
+
noise_scale,
|
209 |
+
noise_scale_w,
|
210 |
+
length_scale,
|
211 |
+
line_split,
|
212 |
+
split_interval,
|
213 |
+
speaker,
|
214 |
+
],
|
215 |
+
outputs=[text_output, audio_output],
|
216 |
+
)
|
217 |
+
|
218 |
+
model_name.change(
|
219 |
+
model_holder.update_model_files_for_gradio,
|
220 |
+
inputs=[model_name],
|
221 |
+
outputs=[model_path],
|
222 |
+
)
|
223 |
+
|
224 |
+
model_path.change(make_non_interactive, outputs=[tts_button])
|
225 |
+
|
226 |
+
refresh_button.click(
|
227 |
+
model_holder.update_model_names_for_gradio,
|
228 |
+
outputs=[model_name, model_path, tts_button],
|
229 |
+
)
|
230 |
+
style = gr.Dropdown(label="スタイル", choices=[], visible=False)
|
231 |
+
|
232 |
+
load_button.click(
|
233 |
+
model_holder.get_model_for_gradio,
|
234 |
+
inputs=[model_name, model_path],
|
235 |
+
outputs=[style, tts_button, speaker],
|
236 |
+
)
|
237 |
+
|
238 |
+
return app
|
239 |
+
|
240 |
+
|
241 |
+
if __name__ == "__main__":
|
242 |
+
import torch
|
243 |
+
|
244 |
+
from style_bert_vits2.constants import Languages
|
245 |
+
from style_bert_vits2.nlp import bert_models
|
246 |
+
|
247 |
+
bert_models.load_model(Languages.JP)
|
248 |
+
bert_models.load_tokenizer(Languages.JP)
|
249 |
+
|
250 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
251 |
+
model_holder = TTSModelHolder(Path("model_assets"), device)
|
252 |
+
app = create_inference_app(model_holder)
|
253 |
+
app.launch(inbrowser=True)
|
bert/bert_models.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"deberta-v2-large-japanese-char-wwm": {
|
3 |
+
"repo_id": "ku-nlp/deberta-v2-large-japanese-char-wwm",
|
4 |
+
"files": ["pytorch_model.bin"]
|
5 |
+
},
|
6 |
+
"chinese-roberta-wwm-ext-large": {
|
7 |
+
"repo_id": "hfl/chinese-roberta-wwm-ext-large",
|
8 |
+
"files": ["pytorch_model.bin"]
|
9 |
+
},
|
10 |
+
"deberta-v3-large": {
|
11 |
+
"repo_id": "microsoft/deberta-v3-large",
|
12 |
+
"files": ["spm.model", "pytorch_model.bin"]
|
13 |
+
}
|
14 |
+
}
|
bert/chinese-roberta-wwm-ext-large/.gitattributes
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
bert/chinese-roberta-wwm-ext-large/README.md
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- zh
|
4 |
+
tags:
|
5 |
+
- bert
|
6 |
+
license: "apache-2.0"
|
7 |
+
---
|
8 |
+
|
9 |
+
# Please use 'Bert' related functions to load this model!
|
10 |
+
|
11 |
+
## Chinese BERT with Whole Word Masking
|
12 |
+
For further accelerating Chinese natural language processing, we provide **Chinese pre-trained BERT with Whole Word Masking**.
|
13 |
+
|
14 |
+
**[Pre-Training with Whole Word Masking for Chinese BERT](https://arxiv.org/abs/1906.08101)**
|
15 |
+
Yiming Cui, Wanxiang Che, Ting Liu, Bing Qin, Ziqing Yang, Shijin Wang, Guoping Hu
|
16 |
+
|
17 |
+
This repository is developed based on:https://github.com/google-research/bert
|
18 |
+
|
19 |
+
You may also interested in,
|
20 |
+
- Chinese BERT series: https://github.com/ymcui/Chinese-BERT-wwm
|
21 |
+
- Chinese MacBERT: https://github.com/ymcui/MacBERT
|
22 |
+
- Chinese ELECTRA: https://github.com/ymcui/Chinese-ELECTRA
|
23 |
+
- Chinese XLNet: https://github.com/ymcui/Chinese-XLNet
|
24 |
+
- Knowledge Distillation Toolkit - TextBrewer: https://github.com/airaria/TextBrewer
|
25 |
+
|
26 |
+
More resources by HFL: https://github.com/ymcui/HFL-Anthology
|
27 |
+
|
28 |
+
## Citation
|
29 |
+
If you find the technical report or resource is useful, please cite the following technical report in your paper.
|
30 |
+
- Primary: https://arxiv.org/abs/2004.13922
|
31 |
+
```
|
32 |
+
@inproceedings{cui-etal-2020-revisiting,
|
33 |
+
title = "Revisiting Pre-Trained Models for {C}hinese Natural Language Processing",
|
34 |
+
author = "Cui, Yiming and
|
35 |
+
Che, Wanxiang and
|
36 |
+
Liu, Ting and
|
37 |
+
Qin, Bing and
|
38 |
+
Wang, Shijin and
|
39 |
+
Hu, Guoping",
|
40 |
+
booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings",
|
41 |
+
month = nov,
|
42 |
+
year = "2020",
|
43 |
+
address = "Online",
|
44 |
+
publisher = "Association for Computational Linguistics",
|
45 |
+
url = "https://www.aclweb.org/anthology/2020.findings-emnlp.58",
|
46 |
+
pages = "657--668",
|
47 |
+
}
|
48 |
+
```
|
49 |
+
- Secondary: https://arxiv.org/abs/1906.08101
|
50 |
+
```
|
51 |
+
@article{chinese-bert-wwm,
|
52 |
+
title={Pre-Training with Whole Word Masking for Chinese BERT},
|
53 |
+
author={Cui, Yiming and Che, Wanxiang and Liu, Ting and Qin, Bing and Yang, Ziqing and Wang, Shijin and Hu, Guoping},
|
54 |
+
journal={arXiv preprint arXiv:1906.08101},
|
55 |
+
year={2019}
|
56 |
+
}
|
57 |
+
```
|
bert/chinese-roberta-wwm-ext-large/added_tokens.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{}
|
bert/chinese-roberta-wwm-ext-large/config.json
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"BertForMaskedLM"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.1,
|
6 |
+
"bos_token_id": 0,
|
7 |
+
"directionality": "bidi",
|
8 |
+
"eos_token_id": 2,
|
9 |
+
"hidden_act": "gelu",
|
10 |
+
"hidden_dropout_prob": 0.1,
|
11 |
+
"hidden_size": 1024,
|
12 |
+
"initializer_range": 0.02,
|
13 |
+
"intermediate_size": 4096,
|
14 |
+
"layer_norm_eps": 1e-12,
|
15 |
+
"max_position_embeddings": 512,
|
16 |
+
"model_type": "bert",
|
17 |
+
"num_attention_heads": 16,
|
18 |
+
"num_hidden_layers": 24,
|
19 |
+
"output_past": true,
|
20 |
+
"pad_token_id": 0,
|
21 |
+
"pooler_fc_size": 768,
|
22 |
+
"pooler_num_attention_heads": 12,
|
23 |
+
"pooler_num_fc_layers": 3,
|
24 |
+
"pooler_size_per_head": 128,
|
25 |
+
"pooler_type": "first_token_transform",
|
26 |
+
"type_vocab_size": 2,
|
27 |
+
"vocab_size": 21128
|
28 |
+
}
|
bert/chinese-roberta-wwm-ext-large/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4ac62d49144d770c5ca9a5d1d3039c4995665a080febe63198189857c6bd11cd
|
3 |
+
size 1306484351
|
bert/chinese-roberta-wwm-ext-large/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
bert/chinese-roberta-wwm-ext-large/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
bert/chinese-roberta-wwm-ext-large/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"init_inputs": []}
|
bert/chinese-roberta-wwm-ext-large/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
bert/deberta-v2-large-japanese-char-wwm/.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
bert/deberta-v2-large-japanese-char-wwm/README.md
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language: ja
|
3 |
+
license: cc-by-sa-4.0
|
4 |
+
library_name: transformers
|
5 |
+
tags:
|
6 |
+
- deberta
|
7 |
+
- deberta-v2
|
8 |
+
- fill-mask
|
9 |
+
- character
|
10 |
+
- wwm
|
11 |
+
datasets:
|
12 |
+
- wikipedia
|
13 |
+
- cc100
|
14 |
+
- oscar
|
15 |
+
metrics:
|
16 |
+
- accuracy
|
17 |
+
mask_token: "[MASK]"
|
18 |
+
widget:
|
19 |
+
- text: "京都大学で自然言語処理を[MASK][MASK]する。"
|
20 |
+
---
|
21 |
+
|
22 |
+
# Model Card for Japanese character-level DeBERTa V2 large
|
23 |
+
|
24 |
+
## Model description
|
25 |
+
|
26 |
+
This is a Japanese DeBERTa V2 large model pre-trained on Japanese Wikipedia, the Japanese portion of CC-100, and the Japanese portion of OSCAR.
|
27 |
+
This model is trained with character-level tokenization and whole word masking.
|
28 |
+
|
29 |
+
## How to use
|
30 |
+
|
31 |
+
You can use this model for masked language modeling as follows:
|
32 |
+
|
33 |
+
```python
|
34 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
35 |
+
tokenizer = AutoTokenizer.from_pretrained('ku-nlp/deberta-v2-large-japanese-char-wwm')
|
36 |
+
model = AutoModelForMaskedLM.from_pretrained('ku-nlp/deberta-v2-large-japanese-char-wwm')
|
37 |
+
|
38 |
+
sentence = '京都大学で自然言語処理を[MASK][MASK]する。'
|
39 |
+
encoding = tokenizer(sentence, return_tensors='pt')
|
40 |
+
...
|
41 |
+
```
|
42 |
+
|
43 |
+
You can also fine-tune this model on downstream tasks.
|
44 |
+
|
45 |
+
## Tokenization
|
46 |
+
|
47 |
+
There is no need to tokenize texts in advance, and you can give raw texts to the tokenizer.
|
48 |
+
The texts are tokenized into character-level tokens by [sentencepiece](https://github.com/google/sentencepiece).
|
49 |
+
|
50 |
+
## Training data
|
51 |
+
|
52 |
+
We used the following corpora for pre-training:
|
53 |
+
|
54 |
+
- Japanese Wikipedia (as of 20221020, 3.2GB, 27M sentences, 1.3M documents)
|
55 |
+
- Japanese portion of CC-100 (85GB, 619M sentences, 66M documents)
|
56 |
+
- Japanese portion of OSCAR (54GB, 326M sentences, 25M documents)
|
57 |
+
|
58 |
+
Note that we filtered out documents annotated with "header", "footer", or "noisy" tags in OSCAR.
|
59 |
+
Also note that Japanese Wikipedia was duplicated 10 times to make the total size of the corpus comparable to that of CC-100 and OSCAR. As a result, the total size of the training data is 171GB.
|
60 |
+
|
61 |
+
## Training procedure
|
62 |
+
|
63 |
+
We first segmented texts in the corpora into words using [Juman++ 2.0.0-rc3](https://github.com/ku-nlp/jumanpp/releases/tag/v2.0.0-rc3) for whole word masking.
|
64 |
+
Then, we built a sentencepiece model with 22,012 tokens including all characters that appear in the training corpus.
|
65 |
+
|
66 |
+
We tokenized raw corpora into character-level subwords using the sentencepiece model and trained the Japanese DeBERTa model using [transformers](https://github.com/huggingface/transformers) library.
|
67 |
+
The training took 26 days using 16 NVIDIA A100-SXM4-40GB GPUs.
|
68 |
+
|
69 |
+
The following hyperparameters were used during pre-training:
|
70 |
+
|
71 |
+
- learning_rate: 1e-4
|
72 |
+
- per_device_train_batch_size: 26
|
73 |
+
- distributed_type: multi-GPU
|
74 |
+
- num_devices: 16
|
75 |
+
- gradient_accumulation_steps: 8
|
76 |
+
- total_train_batch_size: 3,328
|
77 |
+
- max_seq_length: 512
|
78 |
+
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-06
|
79 |
+
- lr_scheduler_type: linear schedule with warmup (lr = 0 at 300k steps)
|
80 |
+
- training_steps: 260,000
|
81 |
+
- warmup_steps: 10,000
|
82 |
+
|
83 |
+
The accuracy of the trained model on the masked language modeling task was 0.795.
|
84 |
+
The evaluation set consists of 5,000 randomly sampled documents from each of the training corpora.
|
85 |
+
|
86 |
+
## Acknowledgments
|
87 |
+
|
88 |
+
This work was supported by Joint Usage/Research Center for Interdisciplinary Large-scale Information Infrastructures (JHPCN) through General Collaboration Project no. jh221004, "Developing a Platform for Constructing and Sharing of Large-Scale Japanese Language Models".
|
89 |
+
For training models, we used the mdx: a platform for the data-driven future.
|
bert/deberta-v2-large-japanese-char-wwm/config.json
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"DebertaV2ForMaskedLM"
|
4 |
+
],
|
5 |
+
"attention_head_size": 64,
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"conv_act": "gelu",
|
8 |
+
"conv_kernel_size": 3,
|
9 |
+
"hidden_act": "gelu",
|
10 |
+
"hidden_dropout_prob": 0.1,
|
11 |
+
"hidden_size": 1024,
|
12 |
+
"initializer_range": 0.02,
|
13 |
+
"intermediate_size": 4096,
|
14 |
+
"layer_norm_eps": 1e-07,
|
15 |
+
"max_position_embeddings": 512,
|
16 |
+
"max_relative_positions": -1,
|
17 |
+
"model_type": "deberta-v2",
|
18 |
+
"norm_rel_ebd": "layer_norm",
|
19 |
+
"num_attention_heads": 16,
|
20 |
+
"num_hidden_layers": 24,
|
21 |
+
"pad_token_id": 0,
|
22 |
+
"pooler_dropout": 0,
|
23 |
+
"pooler_hidden_act": "gelu",
|
24 |
+
"pooler_hidden_size": 1024,
|
25 |
+
"pos_att_type": [
|
26 |
+
"p2c",
|
27 |
+
"c2p"
|
28 |
+
],
|
29 |
+
"position_biased_input": false,
|
30 |
+
"position_buckets": 256,
|
31 |
+
"relative_attention": true,
|
32 |
+
"share_att_key": true,
|
33 |
+
"torch_dtype": "float16",
|
34 |
+
"transformers_version": "4.25.1",
|
35 |
+
"type_vocab_size": 0,
|
36 |
+
"vocab_size": 22012
|
37 |
+
}
|
bert/deberta-v2-large-japanese-char-wwm/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bf0dab8ad87bd7c22e85ec71e04f2240804fda6d33196157d6b5923af6ea1201
|
3 |
+
size 1318456639
|
bert/deberta-v2-large-japanese-char-wwm/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"mask_token": "[MASK]",
|
4 |
+
"pad_token": "[PAD]",
|
5 |
+
"sep_token": "[SEP]",
|
6 |
+
"unk_token": "[UNK]"
|
7 |
+
}
|
bert/deberta-v2-large-japanese-char-wwm/tokenizer_config.json
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"do_lower_case": false,
|
4 |
+
"do_subword_tokenize": true,
|
5 |
+
"do_word_tokenize": true,
|
6 |
+
"jumanpp_kwargs": null,
|
7 |
+
"mask_token": "[MASK]",
|
8 |
+
"mecab_kwargs": null,
|
9 |
+
"model_max_length": 1000000000000000019884624838656,
|
10 |
+
"never_split": null,
|
11 |
+
"pad_token": "[PAD]",
|
12 |
+
"sep_token": "[SEP]",
|
13 |
+
"special_tokens_map_file": null,
|
14 |
+
"subword_tokenizer_type": "character",
|
15 |
+
"sudachi_kwargs": null,
|
16 |
+
"tokenizer_class": "BertJapaneseTokenizer",
|
17 |
+
"unk_token": "[UNK]",
|
18 |
+
"word_tokenizer_type": "basic"
|
19 |
+
}
|
bert/deberta-v2-large-japanese-char-wwm/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
bert/deberta-v3-large/.gitattributes
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
bert/deberta-v3-large/README.md
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language: en
|
3 |
+
tags:
|
4 |
+
- deberta
|
5 |
+
- deberta-v3
|
6 |
+
- fill-mask
|
7 |
+
thumbnail: https://huggingface.co/front/thumbnails/microsoft.png
|
8 |
+
license: mit
|
9 |
+
---
|
10 |
+
|
11 |
+
## DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing
|
12 |
+
|
13 |
+
[DeBERTa](https://arxiv.org/abs/2006.03654) improves the BERT and RoBERTa models using disentangled attention and enhanced mask decoder. With those two improvements, DeBERTa out perform RoBERTa on a majority of NLU tasks with 80GB training data.
|
14 |
+
|
15 |
+
In [DeBERTa V3](https://arxiv.org/abs/2111.09543), we further improved the efficiency of DeBERTa using ELECTRA-Style pre-training with Gradient Disentangled Embedding Sharing. Compared to DeBERTa, our V3 version significantly improves the model performance on downstream tasks. You can find more technique details about the new model from our [paper](https://arxiv.org/abs/2111.09543).
|
16 |
+
|
17 |
+
Please check the [official repository](https://github.com/microsoft/DeBERTa) for more implementation details and updates.
|
18 |
+
|
19 |
+
The DeBERTa V3 large model comes with 24 layers and a hidden size of 1024. It has 304M backbone parameters with a vocabulary containing 128K tokens which introduces 131M parameters in the Embedding layer. This model was trained using the 160GB data as DeBERTa V2.
|
20 |
+
|
21 |
+
|
22 |
+
#### Fine-tuning on NLU tasks
|
23 |
+
|
24 |
+
We present the dev results on SQuAD 2.0 and MNLI tasks.
|
25 |
+
|
26 |
+
| Model |Vocabulary(K)|Backbone #Params(M)| SQuAD 2.0(F1/EM) | MNLI-m/mm(ACC)|
|
27 |
+
|-------------------|----------|-------------------|-----------|----------|
|
28 |
+
| RoBERTa-large |50 |304 | 89.4/86.5 | 90.2 |
|
29 |
+
| XLNet-large |32 |- | 90.6/87.9 | 90.8 |
|
30 |
+
| DeBERTa-large |50 |- | 90.7/88.0 | 91.3 |
|
31 |
+
| **DeBERTa-v3-large**|128|304 | **91.5/89.0**| **91.8/91.9**|
|
32 |
+
|
33 |
+
|
34 |
+
#### Fine-tuning with HF transformers
|
35 |
+
|
36 |
+
```bash
|
37 |
+
#!/bin/bash
|
38 |
+
|
39 |
+
cd transformers/examples/pytorch/text-classification/
|
40 |
+
|
41 |
+
pip install datasets
|
42 |
+
export TASK_NAME=mnli
|
43 |
+
|
44 |
+
output_dir="ds_results"
|
45 |
+
|
46 |
+
num_gpus=8
|
47 |
+
|
48 |
+
batch_size=8
|
49 |
+
|
50 |
+
python -m torch.distributed.launch --nproc_per_node=${num_gpus} \
|
51 |
+
run_glue.py \
|
52 |
+
--model_name_or_path microsoft/deberta-v3-large \
|
53 |
+
--task_name $TASK_NAME \
|
54 |
+
--do_train \
|
55 |
+
--do_eval \
|
56 |
+
--evaluation_strategy steps \
|
57 |
+
--max_seq_length 256 \
|
58 |
+
--warmup_steps 50 \
|
59 |
+
--per_device_train_batch_size ${batch_size} \
|
60 |
+
--learning_rate 6e-6 \
|
61 |
+
--num_train_epochs 2 \
|
62 |
+
--output_dir $output_dir \
|
63 |
+
--overwrite_output_dir \
|
64 |
+
--logging_steps 1000 \
|
65 |
+
--logging_dir $output_dir
|
66 |
+
|
67 |
+
```
|
68 |
+
|
69 |
+
### Citation
|
70 |
+
|
71 |
+
If you find DeBERTa useful for your work, please cite the following papers:
|
72 |
+
|
73 |
+
``` latex
|
74 |
+
@misc{he2021debertav3,
|
75 |
+
title={DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing},
|
76 |
+
author={Pengcheng He and Jianfeng Gao and Weizhu Chen},
|
77 |
+
year={2021},
|
78 |
+
eprint={2111.09543},
|
79 |
+
archivePrefix={arXiv},
|
80 |
+
primaryClass={cs.CL}
|
81 |
+
}
|
82 |
+
```
|
83 |
+
|
84 |
+
``` latex
|
85 |
+
@inproceedings{
|
86 |
+
he2021deberta,
|
87 |
+
title={DEBERTA: DECODING-ENHANCED BERT WITH DISENTANGLED ATTENTION},
|
88 |
+
author={Pengcheng He and Xiaodong Liu and Jianfeng Gao and Weizhu Chen},
|
89 |
+
booktitle={International Conference on Learning Representations},
|
90 |
+
year={2021},
|
91 |
+
url={https://openreview.net/forum?id=XPZIaotutsD}
|
92 |
+
}
|
93 |
+
```
|
bert/deberta-v3-large/config.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_type": "deberta-v2",
|
3 |
+
"attention_probs_dropout_prob": 0.1,
|
4 |
+
"hidden_act": "gelu",
|
5 |
+
"hidden_dropout_prob": 0.1,
|
6 |
+
"hidden_size": 1024,
|
7 |
+
"initializer_range": 0.02,
|
8 |
+
"intermediate_size": 4096,
|
9 |
+
"max_position_embeddings": 512,
|
10 |
+
"relative_attention": true,
|
11 |
+
"position_buckets": 256,
|
12 |
+
"norm_rel_ebd": "layer_norm",
|
13 |
+
"share_att_key": true,
|
14 |
+
"pos_att_type": "p2c|c2p",
|
15 |
+
"layer_norm_eps": 1e-7,
|
16 |
+
"max_relative_positions": -1,
|
17 |
+
"position_biased_input": false,
|
18 |
+
"num_attention_heads": 16,
|
19 |
+
"num_hidden_layers": 24,
|
20 |
+
"type_vocab_size": 0,
|
21 |
+
"vocab_size": 128100
|
22 |
+
}
|
bert/deberta-v3-large/generator_config.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_type": "deberta-v2",
|
3 |
+
"attention_probs_dropout_prob": 0.1,
|
4 |
+
"hidden_act": "gelu",
|
5 |
+
"hidden_dropout_prob": 0.1,
|
6 |
+
"hidden_size": 1024,
|
7 |
+
"initializer_range": 0.02,
|
8 |
+
"intermediate_size": 4096,
|
9 |
+
"max_position_embeddings": 512,
|
10 |
+
"relative_attention": true,
|
11 |
+
"position_buckets": 256,
|
12 |
+
"norm_rel_ebd": "layer_norm",
|
13 |
+
"share_att_key": true,
|
14 |
+
"pos_att_type": "p2c|c2p",
|
15 |
+
"layer_norm_eps": 1e-7,
|
16 |
+
"max_relative_positions": -1,
|
17 |
+
"position_biased_input": false,
|
18 |
+
"num_attention_heads": 16,
|
19 |
+
"num_hidden_layers": 12,
|
20 |
+
"type_vocab_size": 0,
|
21 |
+
"vocab_size": 128100
|
22 |
+
}
|
bert/deberta-v3-large/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dd5b5d93e2db101aaf281df0ea1216c07ad73620ff59c5b42dccac4bf2eef5b5
|
3 |
+
size 873673253
|
bert/deberta-v3-large/pytorch_model.bin.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dd5b5d93e2db101aaf281df0ea1216c07ad73620ff59c5b42dccac4bf2eef5b5
|
3 |
+
size 873673253
|
bert/deberta-v3-large/spm.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd
|
3 |
+
size 2464616
|
bert/deberta-v3-large/tokenizer_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"do_lower_case": false,
|
3 |
+
"vocab_type": "spm"
|
4 |
+
}
|
chupa_examples.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model_assets/chupa_1/chupa_1spk_e1000_s194312.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8af08fae399f64bbc506a4accaf6c56b0d294def6435235dbe60755728784d8c
|
3 |
+
size 251150980
|
model_assets/chupa_1/config.json
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_name": "chupa_1spk",
|
3 |
+
"train": {
|
4 |
+
"log_interval": 50,
|
5 |
+
"eval_interval": 1000,
|
6 |
+
"seed": 42,
|
7 |
+
"epochs": 1000,
|
8 |
+
"learning_rate": 0.0001,
|
9 |
+
"betas": [0.8, 0.99],
|
10 |
+
"eps": 1e-9,
|
11 |
+
"batch_size": 2,
|
12 |
+
"bf16_run": false,
|
13 |
+
"fp16_run": false,
|
14 |
+
"lr_decay": 0.99996,
|
15 |
+
"segment_size": 16384,
|
16 |
+
"init_lr_ratio": 1,
|
17 |
+
"warmup_epochs": 0,
|
18 |
+
"c_mel": 45,
|
19 |
+
"c_kl": 1.0,
|
20 |
+
"c_commit": 100,
|
21 |
+
"skip_optimizer": false,
|
22 |
+
"freeze_ZH_bert": false,
|
23 |
+
"freeze_JP_bert": false,
|
24 |
+
"freeze_EN_bert": false,
|
25 |
+
"freeze_emo": false,
|
26 |
+
"freeze_style": false,
|
27 |
+
"freeze_decoder": false
|
28 |
+
},
|
29 |
+
"data": {
|
30 |
+
"use_jp_extra": true,
|
31 |
+
"training_files": "Data/chupa_1/train.list",
|
32 |
+
"validation_files": "Data/chupa_1/val.list",
|
33 |
+
"max_wav_value": 32768.0,
|
34 |
+
"sampling_rate": 44100,
|
35 |
+
"filter_length": 2048,
|
36 |
+
"hop_length": 512,
|
37 |
+
"win_length": 2048,
|
38 |
+
"n_mel_channels": 128,
|
39 |
+
"mel_fmin": 0.0,
|
40 |
+
"mel_fmax": null,
|
41 |
+
"add_blank": true,
|
42 |
+
"n_speakers": 1,
|
43 |
+
"spk2id": {
|
44 |
+
"1": 0
|
45 |
+
},
|
46 |
+
"cleaned_text": true,
|
47 |
+
"num_styles": 1,
|
48 |
+
"style2id": {
|
49 |
+
"Neutral": 0
|
50 |
+
}
|
51 |
+
},
|
52 |
+
"model": {
|
53 |
+
"use_spk_conditioned_encoder": true,
|
54 |
+
"use_noise_scaled_mas": true,
|
55 |
+
"use_mel_posterior_encoder": false,
|
56 |
+
"use_duration_discriminator": false,
|
57 |
+
"use_wavlm_discriminator": true,
|
58 |
+
"inter_channels": 192,
|
59 |
+
"hidden_channels": 192,
|
60 |
+
"filter_channels": 768,
|
61 |
+
"n_heads": 2,
|
62 |
+
"n_layers": 6,
|
63 |
+
"kernel_size": 3,
|
64 |
+
"p_dropout": 0.1,
|
65 |
+
"resblock": "1",
|
66 |
+
"resblock_kernel_sizes": [3, 7, 11],
|
67 |
+
"resblock_dilation_sizes": [
|
68 |
+
[1, 3, 5],
|
69 |
+
[1, 3, 5],
|
70 |
+
[1, 3, 5]
|
71 |
+
],
|
72 |
+
"upsample_rates": [8, 8, 2, 2, 2],
|
73 |
+
"upsample_initial_channel": 512,
|
74 |
+
"upsample_kernel_sizes": [16, 16, 8, 2, 2],
|
75 |
+
"n_layers_q": 3,
|
76 |
+
"use_spectral_norm": false,
|
77 |
+
"gin_channels": 512,
|
78 |
+
"slm": {
|
79 |
+
"model": "./slm/wavlm-base-plus",
|
80 |
+
"sr": 16000,
|
81 |
+
"hidden": 768,
|
82 |
+
"nlayers": 13,
|
83 |
+
"initial_channel": 64
|
84 |
+
}
|
85 |
+
},
|
86 |
+
"version": "2.6.0-JP-Extra"
|
87 |
+
}
|
model_assets/chupa_1/style_vectors.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9fd42ba186b887c87b57fa66f5781f3fdf4382504d971d5338288d50b8b40461
|
3 |
+
size 1152
|
requirements.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cmudict
|
2 |
+
cn2an
|
3 |
+
# faster-whisper==0.10.1
|
4 |
+
g2p_en
|
5 |
+
GPUtil
|
6 |
+
gradio
|
7 |
+
jieba
|
8 |
+
# librosa==0.9.2
|
9 |
+
loguru
|
10 |
+
num2words
|
11 |
+
numpy<2
|
12 |
+
# protobuf==4.25
|
13 |
+
psutil
|
14 |
+
# punctuators
|
15 |
+
pyannote.audio>=3.1.0
|
16 |
+
# pyloudnorm
|
17 |
+
pyopenjtalk-dict
|
18 |
+
pypinyin
|
19 |
+
pyworld-prebuilt
|
20 |
+
# stable_ts
|
21 |
+
# tensorboard
|
22 |
+
torch
|
23 |
+
transformers
|
style_bert_vits2/.editorconfig
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
root = true
|
2 |
+
|
3 |
+
[*]
|
4 |
+
charset = utf-8
|
5 |
+
end_of_line = lf
|
6 |
+
insert_final_newline = true
|
7 |
+
indent_size = 4
|
8 |
+
indent_style = space
|
9 |
+
trim_trailing_whitespace = true
|
10 |
+
|
11 |
+
[*.md]
|
12 |
+
trim_trailing_whitespace = false
|
13 |
+
|
14 |
+
[*.yml]
|
15 |
+
indent_size = 2
|
style_bert_vits2/__init__.py
ADDED
File without changes
|
style_bert_vits2/constants.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
from style_bert_vits2.utils.strenum import StrEnum
|
4 |
+
|
5 |
+
|
6 |
+
# Style-Bert-VITS2 のバージョン
|
7 |
+
VERSION = "2.6.0"
|
8 |
+
|
9 |
+
# Style-Bert-VITS2 のベースディレクトリ
|
10 |
+
BASE_DIR = Path(__file__).parent.parent
|
11 |
+
|
12 |
+
|
13 |
+
# 利用可能な言語
|
14 |
+
## JP-Extra モデル利用時は JP 以外の言語の音声合成はできない
|
15 |
+
class Languages(StrEnum):
|
16 |
+
JP = "JP"
|
17 |
+
EN = "EN"
|
18 |
+
ZH = "ZH"
|
19 |
+
|
20 |
+
|
21 |
+
# 言語ごとのデフォルトの BERT トークナイザーのパス
|
22 |
+
DEFAULT_BERT_TOKENIZER_PATHS = {
|
23 |
+
Languages.JP: BASE_DIR / "bert" / "deberta-v2-large-japanese-char-wwm",
|
24 |
+
Languages.EN: BASE_DIR / "bert" / "deberta-v3-large",
|
25 |
+
Languages.ZH: BASE_DIR / "bert" / "chinese-roberta-wwm-ext-large",
|
26 |
+
}
|
27 |
+
|
28 |
+
# デフォルトのユーザー辞書ディレクトリ
|
29 |
+
## style_bert_vits2.nlp.japanese.user_dict モジュールのデフォルト値として利用される
|
30 |
+
## ライブラリとしての利用などで外部のユーザー辞書を指定したい場合は、user_dict 以下の各関数の実行時、引数に辞書データファイルのパスを指定する
|
31 |
+
DEFAULT_USER_DICT_DIR = BASE_DIR / "dict_data"
|
32 |
+
|
33 |
+
# デフォルトの推論パラメータ
|
34 |
+
DEFAULT_STYLE = "Neutral"
|
35 |
+
DEFAULT_STYLE_WEIGHT = 1.0
|
36 |
+
DEFAULT_SDP_RATIO = 0.2
|
37 |
+
DEFAULT_NOISE = 0.6
|
38 |
+
DEFAULT_NOISEW = 0.8
|
39 |
+
DEFAULT_LENGTH = 1.0
|
40 |
+
DEFAULT_LINE_SPLIT = True
|
41 |
+
DEFAULT_SPLIT_INTERVAL = 0.5
|
42 |
+
DEFAULT_ASSIST_TEXT_WEIGHT = 0.7
|
43 |
+
DEFAULT_ASSIST_TEXT_WEIGHT = 1.0
|
44 |
+
|
45 |
+
# Gradio のテーマ
|
46 |
+
## Built-in theme: "default", "base", "monochrome", "soft", "glass"
|
47 |
+
## See https://huggingface.co/spaces/gradio/theme-gallery for more themes
|
48 |
+
GRADIO_THEME = "NoCrypt/miku"
|
style_bert_vits2/logging.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from loguru import logger
|
2 |
+
|
3 |
+
from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT
|
4 |
+
|
5 |
+
|
6 |
+
# Remove all default handlers
|
7 |
+
logger.remove()
|
8 |
+
|
9 |
+
# Add a new handler
|
10 |
+
logger.add(
|
11 |
+
SAFE_STDOUT,
|
12 |
+
format="<g>{time:MM-DD HH:mm:ss}</g> |<lvl>{level:^8}</lvl>| {file}:{line} | {message}",
|
13 |
+
backtrace=True,
|
14 |
+
diagnose=True,
|
15 |
+
)
|
style_bert_vits2/models/__init__.py
ADDED
File without changes
|
style_bert_vits2/models/attentions.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Any, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
from style_bert_vits2.models import commons
|
9 |
+
|
10 |
+
|
11 |
+
class LayerNorm(nn.Module):
|
12 |
+
def __init__(self, channels: int, eps: float = 1e-5) -> None:
|
13 |
+
super().__init__()
|
14 |
+
self.channels = channels
|
15 |
+
self.eps = eps
|
16 |
+
|
17 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
18 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
19 |
+
|
20 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
21 |
+
x = x.transpose(1, -1)
|
22 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
23 |
+
return x.transpose(1, -1)
|
24 |
+
|
25 |
+
|
26 |
+
@torch.jit.script # type: ignore
|
27 |
+
def fused_add_tanh_sigmoid_multiply(
|
28 |
+
input_a: torch.Tensor, input_b: torch.Tensor, n_channels: list[int]
|
29 |
+
) -> torch.Tensor:
|
30 |
+
n_channels_int = n_channels[0]
|
31 |
+
in_act = input_a + input_b
|
32 |
+
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
33 |
+
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
34 |
+
acts = t_act * s_act
|
35 |
+
return acts
|
36 |
+
|
37 |
+
|
38 |
+
class Encoder(nn.Module):
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
hidden_channels: int,
|
42 |
+
filter_channels: int,
|
43 |
+
n_heads: int,
|
44 |
+
n_layers: int,
|
45 |
+
kernel_size: int = 1,
|
46 |
+
p_dropout: float = 0.0,
|
47 |
+
window_size: int = 4,
|
48 |
+
isflow: bool = True,
|
49 |
+
**kwargs: Any,
|
50 |
+
) -> None:
|
51 |
+
super().__init__()
|
52 |
+
self.hidden_channels = hidden_channels
|
53 |
+
self.filter_channels = filter_channels
|
54 |
+
self.n_heads = n_heads
|
55 |
+
self.n_layers = n_layers
|
56 |
+
self.kernel_size = kernel_size
|
57 |
+
self.p_dropout = p_dropout
|
58 |
+
self.window_size = window_size
|
59 |
+
# if isflow:
|
60 |
+
# cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
|
61 |
+
# self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
|
62 |
+
# self.cond_layer = weight_norm(cond_layer, name='weight')
|
63 |
+
# self.gin_channels = 256
|
64 |
+
self.cond_layer_idx = self.n_layers
|
65 |
+
if "gin_channels" in kwargs:
|
66 |
+
self.gin_channels = kwargs["gin_channels"]
|
67 |
+
if self.gin_channels != 0:
|
68 |
+
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
|
69 |
+
# vits2 says 3rd block, so idx is 2 by default
|
70 |
+
self.cond_layer_idx = (
|
71 |
+
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
|
72 |
+
)
|
73 |
+
# logger.debug(self.gin_channels, self.cond_layer_idx)
|
74 |
+
assert (
|
75 |
+
self.cond_layer_idx < self.n_layers
|
76 |
+
), "cond_layer_idx should be less than n_layers"
|
77 |
+
self.drop = nn.Dropout(p_dropout)
|
78 |
+
self.attn_layers = nn.ModuleList()
|
79 |
+
self.norm_layers_1 = nn.ModuleList()
|
80 |
+
self.ffn_layers = nn.ModuleList()
|
81 |
+
self.norm_layers_2 = nn.ModuleList()
|
82 |
+
for i in range(self.n_layers):
|
83 |
+
self.attn_layers.append(
|
84 |
+
MultiHeadAttention(
|
85 |
+
hidden_channels,
|
86 |
+
hidden_channels,
|
87 |
+
n_heads,
|
88 |
+
p_dropout=p_dropout,
|
89 |
+
window_size=window_size,
|
90 |
+
)
|
91 |
+
)
|
92 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
93 |
+
self.ffn_layers.append(
|
94 |
+
FFN(
|
95 |
+
hidden_channels,
|
96 |
+
hidden_channels,
|
97 |
+
filter_channels,
|
98 |
+
kernel_size,
|
99 |
+
p_dropout=p_dropout,
|
100 |
+
)
|
101 |
+
)
|
102 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
103 |
+
|
104 |
+
def forward(
|
105 |
+
self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
|
106 |
+
) -> torch.Tensor:
|
107 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
108 |
+
x = x * x_mask
|
109 |
+
for i in range(self.n_layers):
|
110 |
+
if i == self.cond_layer_idx and g is not None:
|
111 |
+
g = self.spk_emb_linear(g.transpose(1, 2))
|
112 |
+
assert g is not None
|
113 |
+
g = g.transpose(1, 2)
|
114 |
+
x = x + g
|
115 |
+
x = x * x_mask
|
116 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
117 |
+
y = self.drop(y)
|
118 |
+
x = self.norm_layers_1[i](x + y)
|
119 |
+
|
120 |
+
y = self.ffn_layers[i](x, x_mask)
|
121 |
+
y = self.drop(y)
|
122 |
+
x = self.norm_layers_2[i](x + y)
|
123 |
+
x = x * x_mask
|
124 |
+
return x
|
125 |
+
|
126 |
+
|
127 |
+
class Decoder(nn.Module):
|
128 |
+
def __init__(
|
129 |
+
self,
|
130 |
+
hidden_channels: int,
|
131 |
+
filter_channels: int,
|
132 |
+
n_heads: int,
|
133 |
+
n_layers: int,
|
134 |
+
kernel_size: int = 1,
|
135 |
+
p_dropout: float = 0.0,
|
136 |
+
proximal_bias: bool = False,
|
137 |
+
proximal_init: bool = True,
|
138 |
+
**kwargs: Any,
|
139 |
+
) -> None:
|
140 |
+
super().__init__()
|
141 |
+
self.hidden_channels = hidden_channels
|
142 |
+
self.filter_channels = filter_channels
|
143 |
+
self.n_heads = n_heads
|
144 |
+
self.n_layers = n_layers
|
145 |
+
self.kernel_size = kernel_size
|
146 |
+
self.p_dropout = p_dropout
|
147 |
+
self.proximal_bias = proximal_bias
|
148 |
+
self.proximal_init = proximal_init
|
149 |
+
|
150 |
+
self.drop = nn.Dropout(p_dropout)
|
151 |
+
self.self_attn_layers = nn.ModuleList()
|
152 |
+
self.norm_layers_0 = nn.ModuleList()
|
153 |
+
self.encdec_attn_layers = nn.ModuleList()
|
154 |
+
self.norm_layers_1 = nn.ModuleList()
|
155 |
+
self.ffn_layers = nn.ModuleList()
|
156 |
+
self.norm_layers_2 = nn.ModuleList()
|
157 |
+
for i in range(self.n_layers):
|
158 |
+
self.self_attn_layers.append(
|
159 |
+
MultiHeadAttention(
|
160 |
+
hidden_channels,
|
161 |
+
hidden_channels,
|
162 |
+
n_heads,
|
163 |
+
p_dropout=p_dropout,
|
164 |
+
proximal_bias=proximal_bias,
|
165 |
+
proximal_init=proximal_init,
|
166 |
+
)
|
167 |
+
)
|
168 |
+
self.norm_layers_0.append(LayerNorm(hidden_channels))
|
169 |
+
self.encdec_attn_layers.append(
|
170 |
+
MultiHeadAttention(
|
171 |
+
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
|
172 |
+
)
|
173 |
+
)
|
174 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
175 |
+
self.ffn_layers.append(
|
176 |
+
FFN(
|
177 |
+
hidden_channels,
|
178 |
+
hidden_channels,
|
179 |
+
filter_channels,
|
180 |
+
kernel_size,
|
181 |
+
p_dropout=p_dropout,
|
182 |
+
causal=True,
|
183 |
+
)
|
184 |
+
)
|
185 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
186 |
+
|
187 |
+
def forward(
|
188 |
+
self,
|
189 |
+
x: torch.Tensor,
|
190 |
+
x_mask: torch.Tensor,
|
191 |
+
h: torch.Tensor,
|
192 |
+
h_mask: torch.Tensor,
|
193 |
+
) -> torch.Tensor:
|
194 |
+
"""
|
195 |
+
x: decoder input
|
196 |
+
h: encoder output
|
197 |
+
"""
|
198 |
+
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
|
199 |
+
device=x.device, dtype=x.dtype
|
200 |
+
)
|
201 |
+
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
202 |
+
x = x * x_mask
|
203 |
+
for i in range(self.n_layers):
|
204 |
+
y = self.self_attn_layers[i](x, x, self_attn_mask)
|
205 |
+
y = self.drop(y)
|
206 |
+
x = self.norm_layers_0[i](x + y)
|
207 |
+
|
208 |
+
y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
|
209 |
+
y = self.drop(y)
|
210 |
+
x = self.norm_layers_1[i](x + y)
|
211 |
+
|
212 |
+
y = self.ffn_layers[i](x, x_mask)
|
213 |
+
y = self.drop(y)
|
214 |
+
x = self.norm_layers_2[i](x + y)
|
215 |
+
x = x * x_mask
|
216 |
+
return x
|
217 |
+
|
218 |
+
|
219 |
+
class MultiHeadAttention(nn.Module):
|
220 |
+
def __init__(
|
221 |
+
self,
|
222 |
+
channels: int,
|
223 |
+
out_channels: int,
|
224 |
+
n_heads: int,
|
225 |
+
p_dropout: float = 0.0,
|
226 |
+
window_size: Optional[int] = None,
|
227 |
+
heads_share: bool = True,
|
228 |
+
block_length: Optional[int] = None,
|
229 |
+
proximal_bias: bool = False,
|
230 |
+
proximal_init: bool = False,
|
231 |
+
) -> None:
|
232 |
+
super().__init__()
|
233 |
+
assert channels % n_heads == 0
|
234 |
+
|
235 |
+
self.channels = channels
|
236 |
+
self.out_channels = out_channels
|
237 |
+
self.n_heads = n_heads
|
238 |
+
self.p_dropout = p_dropout
|
239 |
+
self.window_size = window_size
|
240 |
+
self.heads_share = heads_share
|
241 |
+
self.block_length = block_length
|
242 |
+
self.proximal_bias = proximal_bias
|
243 |
+
self.proximal_init = proximal_init
|
244 |
+
self.attn = None
|
245 |
+
|
246 |
+
self.k_channels = channels // n_heads
|
247 |
+
self.conv_q = nn.Conv1d(channels, channels, 1)
|
248 |
+
self.conv_k = nn.Conv1d(channels, channels, 1)
|
249 |
+
self.conv_v = nn.Conv1d(channels, channels, 1)
|
250 |
+
self.conv_o = nn.Conv1d(channels, out_channels, 1)
|
251 |
+
self.drop = nn.Dropout(p_dropout)
|
252 |
+
|
253 |
+
if window_size is not None:
|
254 |
+
n_heads_rel = 1 if heads_share else n_heads
|
255 |
+
rel_stddev = self.k_channels**-0.5
|
256 |
+
self.emb_rel_k = nn.Parameter(
|
257 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
258 |
+
* rel_stddev
|
259 |
+
)
|
260 |
+
self.emb_rel_v = nn.Parameter(
|
261 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
262 |
+
* rel_stddev
|
263 |
+
)
|
264 |
+
|
265 |
+
nn.init.xavier_uniform_(self.conv_q.weight)
|
266 |
+
nn.init.xavier_uniform_(self.conv_k.weight)
|
267 |
+
nn.init.xavier_uniform_(self.conv_v.weight)
|
268 |
+
if proximal_init:
|
269 |
+
with torch.no_grad():
|
270 |
+
self.conv_k.weight.copy_(self.conv_q.weight)
|
271 |
+
assert self.conv_k.bias is not None
|
272 |
+
assert self.conv_q.bias is not None
|
273 |
+
self.conv_k.bias.copy_(self.conv_q.bias)
|
274 |
+
|
275 |
+
def forward(
|
276 |
+
self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
|
277 |
+
) -> torch.Tensor:
|
278 |
+
q = self.conv_q(x)
|
279 |
+
k = self.conv_k(c)
|
280 |
+
v = self.conv_v(c)
|
281 |
+
|
282 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
283 |
+
|
284 |
+
x = self.conv_o(x)
|
285 |
+
return x
|
286 |
+
|
287 |
+
def attention(
|
288 |
+
self,
|
289 |
+
query: torch.Tensor,
|
290 |
+
key: torch.Tensor,
|
291 |
+
value: torch.Tensor,
|
292 |
+
mask: Optional[torch.Tensor] = None,
|
293 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
294 |
+
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
295 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
296 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
297 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
298 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
299 |
+
|
300 |
+
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
301 |
+
if self.window_size is not None:
|
302 |
+
assert (
|
303 |
+
t_s == t_t
|
304 |
+
), "Relative attention is only available for self-attention."
|
305 |
+
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
306 |
+
rel_logits = self._matmul_with_relative_keys(
|
307 |
+
query / math.sqrt(self.k_channels), key_relative_embeddings
|
308 |
+
)
|
309 |
+
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
310 |
+
scores = scores + scores_local
|
311 |
+
if self.proximal_bias:
|
312 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
313 |
+
scores = scores + self._attention_bias_proximal(t_s).to(
|
314 |
+
device=scores.device, dtype=scores.dtype
|
315 |
+
)
|
316 |
+
if mask is not None:
|
317 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
318 |
+
if self.block_length is not None:
|
319 |
+
assert (
|
320 |
+
t_s == t_t
|
321 |
+
), "Local attention is only available for self-attention."
|
322 |
+
block_mask = (
|
323 |
+
torch.ones_like(scores)
|
324 |
+
.triu(-self.block_length)
|
325 |
+
.tril(self.block_length)
|
326 |
+
)
|
327 |
+
scores = scores.masked_fill(block_mask == 0, -1e4)
|
328 |
+
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
329 |
+
p_attn = self.drop(p_attn)
|
330 |
+
output = torch.matmul(p_attn, value)
|
331 |
+
if self.window_size is not None:
|
332 |
+
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
333 |
+
value_relative_embeddings = self._get_relative_embeddings(
|
334 |
+
self.emb_rel_v, t_s
|
335 |
+
)
|
336 |
+
output = output + self._matmul_with_relative_values(
|
337 |
+
relative_weights, value_relative_embeddings
|
338 |
+
)
|
339 |
+
output = (
|
340 |
+
output.transpose(2, 3).contiguous().view(b, d, t_t)
|
341 |
+
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
342 |
+
return output, p_attn
|
343 |
+
|
344 |
+
def _matmul_with_relative_values(
|
345 |
+
self, x: torch.Tensor, y: torch.Tensor
|
346 |
+
) -> torch.Tensor:
|
347 |
+
"""
|
348 |
+
x: [b, h, l, m]
|
349 |
+
y: [h or 1, m, d]
|
350 |
+
ret: [b, h, l, d]
|
351 |
+
"""
|
352 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
353 |
+
return ret
|
354 |
+
|
355 |
+
def _matmul_with_relative_keys(
|
356 |
+
self, x: torch.Tensor, y: torch.Tensor
|
357 |
+
) -> torch.Tensor:
|
358 |
+
"""
|
359 |
+
x: [b, h, l, d]
|
360 |
+
y: [h or 1, m, d]
|
361 |
+
ret: [b, h, l, m]
|
362 |
+
"""
|
363 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
364 |
+
return ret
|
365 |
+
|
366 |
+
def _get_relative_embeddings(
|
367 |
+
self, relative_embeddings: torch.Tensor, length: int
|
368 |
+
) -> torch.Tensor:
|
369 |
+
assert self.window_size is not None
|
370 |
+
2 * self.window_size + 1 # type: ignore
|
371 |
+
# Pad first before slice to avoid using cond ops.
|
372 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
373 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
374 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
375 |
+
if pad_length > 0:
|
376 |
+
padded_relative_embeddings = F.pad(
|
377 |
+
relative_embeddings,
|
378 |
+
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
379 |
+
)
|
380 |
+
else:
|
381 |
+
padded_relative_embeddings = relative_embeddings
|
382 |
+
used_relative_embeddings = padded_relative_embeddings[
|
383 |
+
:, slice_start_position:slice_end_position
|
384 |
+
]
|
385 |
+
return used_relative_embeddings
|
386 |
+
|
387 |
+
def _relative_position_to_absolute_position(self, x: torch.Tensor) -> torch.Tensor:
|
388 |
+
"""
|
389 |
+
x: [b, h, l, 2*l-1]
|
390 |
+
ret: [b, h, l, l]
|
391 |
+
"""
|
392 |
+
batch, heads, length, _ = x.size()
|
393 |
+
# Concat columns of pad to shift from relative to absolute indexing.
|
394 |
+
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
|
395 |
+
|
396 |
+
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
397 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
398 |
+
x_flat = F.pad(
|
399 |
+
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
400 |
+
)
|
401 |
+
|
402 |
+
# Reshape and slice out the padded elements.
|
403 |
+
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
404 |
+
:, :, :length, length - 1 :
|
405 |
+
]
|
406 |
+
return x_final
|
407 |
+
|
408 |
+
def _absolute_position_to_relative_position(self, x: torch.Tensor) -> torch.Tensor:
|
409 |
+
"""
|
410 |
+
x: [b, h, l, l]
|
411 |
+
ret: [b, h, l, 2*l-1]
|
412 |
+
"""
|
413 |
+
batch, heads, length, _ = x.size()
|
414 |
+
# pad along column
|
415 |
+
x = F.pad(
|
416 |
+
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
417 |
+
)
|
418 |
+
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
419 |
+
# add 0's in the beginning that will skew the elements after reshape
|
420 |
+
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
421 |
+
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
422 |
+
return x_final
|
423 |
+
|
424 |
+
def _attention_bias_proximal(self, length: int) -> torch.Tensor:
|
425 |
+
"""Bias for self-attention to encourage attention to close positions.
|
426 |
+
Args:
|
427 |
+
length: an integer scalar.
|
428 |
+
Returns:
|
429 |
+
a Tensor with shape [1, 1, length, length]
|
430 |
+
"""
|
431 |
+
r = torch.arange(length, dtype=torch.float32)
|
432 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
433 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
434 |
+
|
435 |
+
|
436 |
+
class FFN(nn.Module):
|
437 |
+
def __init__(
|
438 |
+
self,
|
439 |
+
in_channels: int,
|
440 |
+
out_channels: int,
|
441 |
+
filter_channels: int,
|
442 |
+
kernel_size: int,
|
443 |
+
p_dropout: float = 0.0,
|
444 |
+
activation: Optional[str] = None,
|
445 |
+
causal: bool = False,
|
446 |
+
) -> None:
|
447 |
+
super().__init__()
|
448 |
+
self.in_channels = in_channels
|
449 |
+
self.out_channels = out_channels
|
450 |
+
self.filter_channels = filter_channels
|
451 |
+
self.kernel_size = kernel_size
|
452 |
+
self.p_dropout = p_dropout
|
453 |
+
self.activation = activation
|
454 |
+
self.causal = causal
|
455 |
+
|
456 |
+
if causal:
|
457 |
+
self.padding = self._causal_padding
|
458 |
+
else:
|
459 |
+
self.padding = self._same_padding
|
460 |
+
|
461 |
+
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
|
462 |
+
self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
|
463 |
+
self.drop = nn.Dropout(p_dropout)
|
464 |
+
|
465 |
+
def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
|
466 |
+
x = self.conv_1(self.padding(x * x_mask))
|
467 |
+
if self.activation == "gelu":
|
468 |
+
x = x * torch.sigmoid(1.702 * x)
|
469 |
+
else:
|
470 |
+
x = torch.relu(x)
|
471 |
+
x = self.drop(x)
|
472 |
+
x = self.conv_2(self.padding(x * x_mask))
|
473 |
+
return x * x_mask
|
474 |
+
|
475 |
+
def _causal_padding(self, x: torch.Tensor) -> torch.Tensor:
|
476 |
+
if self.kernel_size == 1:
|
477 |
+
return x
|
478 |
+
pad_l = self.kernel_size - 1
|
479 |
+
pad_r = 0
|
480 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
481 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
482 |
+
return x
|
483 |
+
|
484 |
+
def _same_padding(self, x: torch.Tensor) -> torch.Tensor:
|
485 |
+
if self.kernel_size == 1:
|
486 |
+
return x
|
487 |
+
pad_l = (self.kernel_size - 1) // 2
|
488 |
+
pad_r = self.kernel_size // 2
|
489 |
+
padding = [[0, 0], [0, 0], [pad_l, pad_r]]
|
490 |
+
x = F.pad(x, commons.convert_pad_shape(padding))
|
491 |
+
return x
|
style_bert_vits2/models/commons.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
以下に記述されている関数のコメントはリファクタリング時に GPT-4 に生成させたもので、
|
3 |
+
コードと完全に一致している保証はない。あくまで参考程度とすること。
|
4 |
+
"""
|
5 |
+
|
6 |
+
from typing import Any, Optional, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch.nn import functional as F
|
10 |
+
|
11 |
+
|
12 |
+
def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
|
13 |
+
"""
|
14 |
+
モジュールの重みを初期化する
|
15 |
+
|
16 |
+
Args:
|
17 |
+
m (torch.nn.Module): 重みを初期化する対象のモジュール
|
18 |
+
mean (float): 正規分布の平均
|
19 |
+
std (float): 正規分布の標準偏差
|
20 |
+
"""
|
21 |
+
classname = m.__class__.__name__
|
22 |
+
if classname.find("Conv") != -1:
|
23 |
+
m.weight.data.normal_(mean, std)
|
24 |
+
|
25 |
+
|
26 |
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
|
27 |
+
"""
|
28 |
+
カーネルサイズと膨張率からパディングの大きさを計算する
|
29 |
+
|
30 |
+
Args:
|
31 |
+
kernel_size (int): カーネルのサイズ
|
32 |
+
dilation (int): 膨張率
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
int: 計算されたパディングの大きさ
|
36 |
+
"""
|
37 |
+
return int((kernel_size * dilation - dilation) / 2)
|
38 |
+
|
39 |
+
|
40 |
+
def convert_pad_shape(pad_shape: list[list[Any]]) -> list[Any]:
|
41 |
+
"""
|
42 |
+
パディングの形状を変換する
|
43 |
+
|
44 |
+
Args:
|
45 |
+
pad_shape (list[list[Any]]): 変換前のパディングの形状
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
list[Any]: 変換後のパディングの形状
|
49 |
+
"""
|
50 |
+
layer = pad_shape[::-1]
|
51 |
+
new_pad_shape = [item for sublist in layer for item in sublist]
|
52 |
+
return new_pad_shape
|
53 |
+
|
54 |
+
|
55 |
+
def intersperse(lst: list[Any], item: Any) -> list[Any]:
|
56 |
+
"""
|
57 |
+
リストの要素の間に特定のアイテムを挿入する
|
58 |
+
|
59 |
+
Args:
|
60 |
+
lst (list[Any]): 元のリスト
|
61 |
+
item (Any): 挿入するアイテム
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
list[Any]: 新しいリスト
|
65 |
+
"""
|
66 |
+
result = [item] * (len(lst) * 2 + 1)
|
67 |
+
result[1::2] = lst
|
68 |
+
return result
|
69 |
+
|
70 |
+
|
71 |
+
def slice_segments(
|
72 |
+
x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4
|
73 |
+
) -> torch.Tensor:
|
74 |
+
"""
|
75 |
+
テンソルからセグメントをスライスする
|
76 |
+
|
77 |
+
Args:
|
78 |
+
x (torch.Tensor): 入力テンソル
|
79 |
+
ids_str (torch.Tensor): スライスを開始するインデックス
|
80 |
+
segment_size (int, optional): スライスのサイズ (デフォルト: 4)
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
torch.Tensor: スライスされたセグメント
|
84 |
+
"""
|
85 |
+
gather_indices = ids_str.view(x.size(0), 1, 1).repeat(
|
86 |
+
1, x.size(1), 1
|
87 |
+
) + torch.arange(segment_size, device=x.device)
|
88 |
+
return torch.gather(x, 2, gather_indices)
|
89 |
+
|
90 |
+
|
91 |
+
def rand_slice_segments(
|
92 |
+
x: torch.Tensor, x_lengths: Optional[torch.Tensor] = None, segment_size: int = 4
|
93 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
94 |
+
"""
|
95 |
+
ランダムなセグメントをスライスする
|
96 |
+
|
97 |
+
Args:
|
98 |
+
x (torch.Tensor): 入力テンソル
|
99 |
+
x_lengths (Optional[torch.Tensor], optional): 各バッチの長さ (デフォルト: None)
|
100 |
+
segment_size (int, optional): スライスのサイズ (デフォルト: 4)
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
tuple[torch.Tensor, torch.Tensor]: スライスされたセグメントと開始インデックス
|
104 |
+
"""
|
105 |
+
b, d, t = x.size()
|
106 |
+
if x_lengths is None:
|
107 |
+
x_lengths = t # type: ignore
|
108 |
+
ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) # type: ignore
|
109 |
+
ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
|
110 |
+
ret = slice_segments(x, ids_str, segment_size)
|
111 |
+
return ret, ids_str
|
112 |
+
|
113 |
+
|
114 |
+
def subsequent_mask(length: int) -> torch.Tensor:
|
115 |
+
"""
|
116 |
+
後続のマスクを生成する
|
117 |
+
|
118 |
+
Args:
|
119 |
+
length (int): マスクのサイズ
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
torch.Tensor: 生成されたマスク
|
123 |
+
"""
|
124 |
+
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
|
125 |
+
return mask
|
126 |
+
|
127 |
+
|
128 |
+
@torch.jit.script # type: ignore
|
129 |
+
def fused_add_tanh_sigmoid_multiply(
|
130 |
+
input_a: torch.Tensor, input_b: torch.Tensor, n_channels: torch.Tensor
|
131 |
+
) -> torch.Tensor:
|
132 |
+
"""
|
133 |
+
加算、tanh、sigmoid の活性化関数を組み合わせた演算を行う
|
134 |
+
|
135 |
+
Args:
|
136 |
+
input_a (torch.Tensor): 入力テンソル A
|
137 |
+
input_b (torch.Tensor): 入力テンソル B
|
138 |
+
n_channels (torch.Tensor): チャネル数
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
torch.Tensor: 演算結果
|
142 |
+
"""
|
143 |
+
n_channels_int = n_channels[0]
|
144 |
+
in_act = input_a + input_b
|
145 |
+
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
146 |
+
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
147 |
+
acts = t_act * s_act
|
148 |
+
return acts
|
149 |
+
|
150 |
+
|
151 |
+
def sequence_mask(
|
152 |
+
length: torch.Tensor, max_length: Optional[int] = None
|
153 |
+
) -> torch.Tensor:
|
154 |
+
"""
|
155 |
+
シーケンスマスクを生成する
|
156 |
+
|
157 |
+
Args:
|
158 |
+
length (torch.Tensor): 各シーケンスの長さ
|
159 |
+
max_length (Optional[int]): 最大のシーケンス長さ。指定されていない場合は length の最大値を使用
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
torch.Tensor: 生成されたシーケンスマスク
|
163 |
+
"""
|
164 |
+
if max_length is None:
|
165 |
+
max_length = length.max() # type: ignore
|
166 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device) # type: ignore
|
167 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
168 |
+
|
169 |
+
|
170 |
+
def generate_path(duration: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
171 |
+
"""
|
172 |
+
パスを生成する
|
173 |
+
|
174 |
+
Args:
|
175 |
+
duration (torch.Tensor): 各時間ステップの持続時間
|
176 |
+
mask (torch.Tensor): マスクテンソル
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
torch.Tensor: 生成されたパス
|
180 |
+
"""
|
181 |
+
b, _, t_y, t_x = mask.shape
|
182 |
+
cum_duration = torch.cumsum(duration, -1)
|
183 |
+
|
184 |
+
cum_duration_flat = cum_duration.view(b * t_x)
|
185 |
+
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
186 |
+
path = path.view(b, t_x, t_y)
|
187 |
+
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
188 |
+
path = path.unsqueeze(1).transpose(2, 3) * mask
|
189 |
+
return path
|
190 |
+
|
191 |
+
|
192 |
+
def clip_grad_value_(
|
193 |
+
parameters: Union[torch.Tensor, list[torch.Tensor]],
|
194 |
+
clip_value: Optional[float],
|
195 |
+
norm_type: float = 2.0,
|
196 |
+
) -> float:
|
197 |
+
"""
|
198 |
+
勾配の値をクリップする
|
199 |
+
|
200 |
+
Args:
|
201 |
+
parameters (Union[torch.Tensor, list[torch.Tensor]]): クリップするパラメータ
|
202 |
+
clip_value (Optional[float]): クリップする値。None の場合はクリップしない
|
203 |
+
norm_type (float): ノルムの種類
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
float: 総ノルム
|
207 |
+
"""
|
208 |
+
if isinstance(parameters, torch.Tensor):
|
209 |
+
parameters = [parameters]
|
210 |
+
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
211 |
+
norm_type = float(norm_type)
|
212 |
+
if clip_value is not None:
|
213 |
+
clip_value = float(clip_value)
|
214 |
+
|
215 |
+
total_norm = 0.0
|
216 |
+
for p in parameters:
|
217 |
+
assert p.grad is not None
|
218 |
+
param_norm = p.grad.data.norm(norm_type)
|
219 |
+
total_norm += param_norm.item() ** norm_type
|
220 |
+
if clip_value is not None:
|
221 |
+
p.grad.data.clamp_(min=-clip_value, max=clip_value)
|
222 |
+
total_norm = total_norm ** (1.0 / norm_type)
|
223 |
+
return total_norm
|
style_bert_vits2/models/hyper_parameters.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Style-Bert-VITS2 モデルのハイパーパラメータを表す Pydantic モデル。
|
3 |
+
デフォルト値は configs/config_jp_extra.json 内の定義と概ね同一で、
|
4 |
+
万が一ロードした config.json に存在しないキーがあった際のフェイルセーフとして適用される。
|
5 |
+
"""
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Optional, Union
|
9 |
+
|
10 |
+
from pydantic import BaseModel, ConfigDict
|
11 |
+
|
12 |
+
|
13 |
+
class HyperParametersTrain(BaseModel):
|
14 |
+
log_interval: int = 200
|
15 |
+
eval_interval: int = 1000
|
16 |
+
seed: int = 42
|
17 |
+
epochs: int = 1000
|
18 |
+
learning_rate: float = 0.0001
|
19 |
+
betas: tuple[float, float] = (0.8, 0.99)
|
20 |
+
eps: float = 1e-9
|
21 |
+
batch_size: int = 2
|
22 |
+
bf16_run: bool = False
|
23 |
+
fp16_run: bool = False
|
24 |
+
lr_decay: float = 0.99996
|
25 |
+
segment_size: int = 16384
|
26 |
+
init_lr_ratio: int = 1
|
27 |
+
warmup_epochs: int = 0
|
28 |
+
c_mel: int = 45
|
29 |
+
c_kl: float = 1.0
|
30 |
+
c_commit: int = 100
|
31 |
+
skip_optimizer: bool = False
|
32 |
+
freeze_ZH_bert: bool = False
|
33 |
+
freeze_JP_bert: bool = False
|
34 |
+
freeze_EN_bert: bool = False
|
35 |
+
freeze_emo: bool = False
|
36 |
+
freeze_style: bool = False
|
37 |
+
freeze_decoder: bool = False
|
38 |
+
|
39 |
+
|
40 |
+
class HyperParametersData(BaseModel):
|
41 |
+
use_jp_extra: bool = True
|
42 |
+
training_files: str = "Data/Dummy/train.list"
|
43 |
+
validation_files: str = "Data/Dummy/val.list"
|
44 |
+
max_wav_value: float = 32768.0
|
45 |
+
sampling_rate: int = 44100
|
46 |
+
filter_length: int = 2048
|
47 |
+
hop_length: int = 512
|
48 |
+
win_length: int = 2048
|
49 |
+
n_mel_channels: int = 128
|
50 |
+
mel_fmin: float = 0.0
|
51 |
+
mel_fmax: Optional[float] = None
|
52 |
+
add_blank: bool = True
|
53 |
+
n_speakers: int = 1
|
54 |
+
cleaned_text: bool = True
|
55 |
+
spk2id: dict[str, int] = {
|
56 |
+
"Dummy": 0,
|
57 |
+
}
|
58 |
+
num_styles: int = 1
|
59 |
+
style2id: dict[str, int] = {
|
60 |
+
"Neutral": 0,
|
61 |
+
}
|
62 |
+
|
63 |
+
|
64 |
+
class HyperParametersModelSLM(BaseModel):
|
65 |
+
model: str = "./slm/wavlm-base-plus"
|
66 |
+
sr: int = 16000
|
67 |
+
hidden: int = 768
|
68 |
+
nlayers: int = 13
|
69 |
+
initial_channel: int = 64
|
70 |
+
|
71 |
+
|
72 |
+
class HyperParametersModel(BaseModel):
|
73 |
+
use_spk_conditioned_encoder: bool = True
|
74 |
+
use_noise_scaled_mas: bool = True
|
75 |
+
use_mel_posterior_encoder: bool = False
|
76 |
+
use_duration_discriminator: bool = False
|
77 |
+
use_wavlm_discriminator: bool = True
|
78 |
+
inter_channels: int = 192
|
79 |
+
hidden_channels: int = 192
|
80 |
+
filter_channels: int = 768
|
81 |
+
n_heads: int = 2
|
82 |
+
n_layers: int = 6
|
83 |
+
kernel_size: int = 3
|
84 |
+
p_dropout: float = 0.1
|
85 |
+
resblock: str = "1"
|
86 |
+
resblock_kernel_sizes: list[int] = [3, 7, 11]
|
87 |
+
resblock_dilation_sizes: list[list[int]] = [
|
88 |
+
[1, 3, 5],
|
89 |
+
[1, 3, 5],
|
90 |
+
[1, 3, 5],
|
91 |
+
]
|
92 |
+
upsample_rates: list[int] = [8, 8, 2, 2, 2]
|
93 |
+
upsample_initial_channel: int = 512
|
94 |
+
upsample_kernel_sizes: list[int] = [16, 16, 8, 2, 2]
|
95 |
+
n_layers_q: int = 3
|
96 |
+
use_spectral_norm: bool = False
|
97 |
+
gin_channels: int = 512
|
98 |
+
slm: HyperParametersModelSLM = HyperParametersModelSLM()
|
99 |
+
|
100 |
+
|
101 |
+
class HyperParameters(BaseModel):
|
102 |
+
model_name: str = "Dummy"
|
103 |
+
version: str = "2.0-JP-Extra"
|
104 |
+
train: HyperParametersTrain = HyperParametersTrain()
|
105 |
+
data: HyperParametersData = HyperParametersData()
|
106 |
+
model: HyperParametersModel = HyperParametersModel()
|
107 |
+
|
108 |
+
# 以下は学習時にのみ動的に設定されるパラメータ (通常 config.json には存在しない)
|
109 |
+
model_dir: Optional[str] = None
|
110 |
+
speedup: bool = False
|
111 |
+
repo_id: Optional[str] = None
|
112 |
+
|
113 |
+
# model_ 以下を Pydantic の保護対象から除外する
|
114 |
+
model_config = ConfigDict(protected_namespaces=())
|
115 |
+
|
116 |
+
@staticmethod
|
117 |
+
def load_from_json(json_path: Union[str, Path]) -> "HyperParameters":
|
118 |
+
"""
|
119 |
+
与えられた JSON ファイルからハイパーパラメータを読み込む。
|
120 |
+
|
121 |
+
Args:
|
122 |
+
json_path (Union[str, Path]): JSON ファイルのパス
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
HyperParameters: ハイパーパラメータ
|
126 |
+
"""
|
127 |
+
|
128 |
+
with open(json_path, encoding="utf-8") as f:
|
129 |
+
return HyperParameters.model_validate_json(f.read())
|
style_bert_vits2/models/infer.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Optional, Union, cast
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from numpy.typing import NDArray
|
5 |
+
|
6 |
+
from style_bert_vits2.constants import Languages
|
7 |
+
from style_bert_vits2.logging import logger
|
8 |
+
from style_bert_vits2.models import commons, utils
|
9 |
+
from style_bert_vits2.models.hyper_parameters import HyperParameters
|
10 |
+
from style_bert_vits2.models.models import SynthesizerTrn
|
11 |
+
from style_bert_vits2.models.models_jp_extra import (
|
12 |
+
SynthesizerTrn as SynthesizerTrnJPExtra,
|
13 |
+
)
|
14 |
+
from style_bert_vits2.nlp import (
|
15 |
+
clean_text,
|
16 |
+
cleaned_text_to_sequence,
|
17 |
+
extract_bert_feature,
|
18 |
+
)
|
19 |
+
from style_bert_vits2.nlp.symbols import SYMBOLS
|
20 |
+
|
21 |
+
|
22 |
+
def get_net_g(model_path: str, version: str, device: str, hps: HyperParameters):
|
23 |
+
if version.endswith("JP-Extra"):
|
24 |
+
logger.info("Using JP-Extra model")
|
25 |
+
net_g = SynthesizerTrnJPExtra(
|
26 |
+
n_vocab=len(SYMBOLS),
|
27 |
+
spec_channels=hps.data.filter_length // 2 + 1,
|
28 |
+
segment_size=hps.train.segment_size // hps.data.hop_length,
|
29 |
+
n_speakers=hps.data.n_speakers,
|
30 |
+
# hps.model 以下のすべての値を引数に渡す
|
31 |
+
use_spk_conditioned_encoder=hps.model.use_spk_conditioned_encoder,
|
32 |
+
use_noise_scaled_mas=hps.model.use_noise_scaled_mas,
|
33 |
+
use_mel_posterior_encoder=hps.model.use_mel_posterior_encoder,
|
34 |
+
use_duration_discriminator=hps.model.use_duration_discriminator,
|
35 |
+
use_wavlm_discriminator=hps.model.use_wavlm_discriminator,
|
36 |
+
inter_channels=hps.model.inter_channels,
|
37 |
+
hidden_channels=hps.model.hidden_channels,
|
38 |
+
filter_channels=hps.model.filter_channels,
|
39 |
+
n_heads=hps.model.n_heads,
|
40 |
+
n_layers=hps.model.n_layers,
|
41 |
+
kernel_size=hps.model.kernel_size,
|
42 |
+
p_dropout=hps.model.p_dropout,
|
43 |
+
resblock=hps.model.resblock,
|
44 |
+
resblock_kernel_sizes=hps.model.resblock_kernel_sizes,
|
45 |
+
resblock_dilation_sizes=hps.model.resblock_dilation_sizes,
|
46 |
+
upsample_rates=hps.model.upsample_rates,
|
47 |
+
upsample_initial_channel=hps.model.upsample_initial_channel,
|
48 |
+
upsample_kernel_sizes=hps.model.upsample_kernel_sizes,
|
49 |
+
n_layers_q=hps.model.n_layers_q,
|
50 |
+
use_spectral_norm=hps.model.use_spectral_norm,
|
51 |
+
gin_channels=hps.model.gin_channels,
|
52 |
+
slm=hps.model.slm,
|
53 |
+
).to(device)
|
54 |
+
else:
|
55 |
+
logger.info("Using normal model")
|
56 |
+
net_g = SynthesizerTrn(
|
57 |
+
n_vocab=len(SYMBOLS),
|
58 |
+
spec_channels=hps.data.filter_length // 2 + 1,
|
59 |
+
segment_size=hps.train.segment_size // hps.data.hop_length,
|
60 |
+
n_speakers=hps.data.n_speakers,
|
61 |
+
# hps.model 以下のすべての値を引数に渡す
|
62 |
+
use_spk_conditioned_encoder=hps.model.use_spk_conditioned_encoder,
|
63 |
+
use_noise_scaled_mas=hps.model.use_noise_scaled_mas,
|
64 |
+
use_mel_posterior_encoder=hps.model.use_mel_posterior_encoder,
|
65 |
+
use_duration_discriminator=hps.model.use_duration_discriminator,
|
66 |
+
use_wavlm_discriminator=hps.model.use_wavlm_discriminator,
|
67 |
+
inter_channels=hps.model.inter_channels,
|
68 |
+
hidden_channels=hps.model.hidden_channels,
|
69 |
+
filter_channels=hps.model.filter_channels,
|
70 |
+
n_heads=hps.model.n_heads,
|
71 |
+
n_layers=hps.model.n_layers,
|
72 |
+
kernel_size=hps.model.kernel_size,
|
73 |
+
p_dropout=hps.model.p_dropout,
|
74 |
+
resblock=hps.model.resblock,
|
75 |
+
resblock_kernel_sizes=hps.model.resblock_kernel_sizes,
|
76 |
+
resblock_dilation_sizes=hps.model.resblock_dilation_sizes,
|
77 |
+
upsample_rates=hps.model.upsample_rates,
|
78 |
+
upsample_initial_channel=hps.model.upsample_initial_channel,
|
79 |
+
upsample_kernel_sizes=hps.model.upsample_kernel_sizes,
|
80 |
+
n_layers_q=hps.model.n_layers_q,
|
81 |
+
use_spectral_norm=hps.model.use_spectral_norm,
|
82 |
+
gin_channels=hps.model.gin_channels,
|
83 |
+
slm=hps.model.slm,
|
84 |
+
).to(device)
|
85 |
+
net_g.state_dict()
|
86 |
+
_ = net_g.eval()
|
87 |
+
if model_path.endswith(".pth") or model_path.endswith(".pt"):
|
88 |
+
_ = utils.checkpoints.load_checkpoint(
|
89 |
+
model_path, net_g, None, skip_optimizer=True
|
90 |
+
)
|
91 |
+
elif model_path.endswith(".safetensors"):
|
92 |
+
_ = utils.safetensors.load_safetensors(model_path, net_g, True)
|
93 |
+
else:
|
94 |
+
raise ValueError(f"Unknown model format: {model_path}")
|
95 |
+
return net_g
|
96 |
+
|
97 |
+
|
98 |
+
def get_text(
|
99 |
+
text: str,
|
100 |
+
language_str: Languages,
|
101 |
+
hps: HyperParameters,
|
102 |
+
device: str,
|
103 |
+
assist_text: Optional[str] = None,
|
104 |
+
assist_text_weight: float = 0.7,
|
105 |
+
given_phone: Optional[list[str]] = None,
|
106 |
+
given_tone: Optional[list[int]] = None,
|
107 |
+
):
|
108 |
+
use_jp_extra = hps.version.endswith("JP-Extra")
|
109 |
+
# 推論時のみ呼び出されるので、raise_yomi_error は False に設定
|
110 |
+
norm_text, phone, tone, word2ph = clean_text(
|
111 |
+
text,
|
112 |
+
language_str,
|
113 |
+
use_jp_extra=use_jp_extra,
|
114 |
+
raise_yomi_error=False,
|
115 |
+
)
|
116 |
+
# phone と tone の両方が与えられた場合はそれを使う
|
117 |
+
if given_phone is not None and given_tone is not None:
|
118 |
+
# 指定された phone と指定された tone 両方の長さが一致していなければならない
|
119 |
+
if len(given_phone) != len(given_tone):
|
120 |
+
raise InvalidPhoneError(
|
121 |
+
f"Length of given_phone ({len(given_phone)}) != length of given_tone ({len(given_tone)})"
|
122 |
+
)
|
123 |
+
# 与えられた音素数と pyopenjtalk で生成した読みの音素数が一致しない
|
124 |
+
if len(given_phone) != sum(word2ph):
|
125 |
+
# 日本語の場合、len(given_phone) と sum(word2ph) が一致するように word2ph を適切に調整する
|
126 |
+
# 他の言語は word2ph の調整方法が思いつかないのでエラー
|
127 |
+
if language_str == Languages.JP:
|
128 |
+
from style_bert_vits2.nlp.japanese.g2p import adjust_word2ph
|
129 |
+
|
130 |
+
word2ph = adjust_word2ph(word2ph, phone, given_phone)
|
131 |
+
# 上記処理により word2ph の合計が given_phone の長さと一致するはず
|
132 |
+
# それでも一致しない場合、大半は読み上げテキストと given_phone が著しく乖離していて調整し切れなかったことを意味する
|
133 |
+
if len(given_phone) != sum(word2ph):
|
134 |
+
raise InvalidPhoneError(
|
135 |
+
f"Length of given_phone ({len(given_phone)}) != sum of word2ph ({sum(word2ph)})"
|
136 |
+
)
|
137 |
+
else:
|
138 |
+
raise InvalidPhoneError(
|
139 |
+
f"Length of given_phone ({len(given_phone)}) != sum of word2ph ({sum(word2ph)})"
|
140 |
+
)
|
141 |
+
phone = given_phone
|
142 |
+
# 生成あるいは指定された phone と指定された tone 両方の長さが一致していなければならない
|
143 |
+
if len(phone) != len(given_tone):
|
144 |
+
raise InvalidToneError(
|
145 |
+
f"Length of phone ({len(phone)}) != length of given_tone ({len(given_tone)})"
|
146 |
+
)
|
147 |
+
tone = given_tone
|
148 |
+
# tone だけが与えられた場合は clean_text() で生成した phone と合わせて使う
|
149 |
+
elif given_tone is not None:
|
150 |
+
# 生成した phone と指定された tone 両方の長さが一致していなければならない
|
151 |
+
if len(phone) != len(given_tone):
|
152 |
+
raise InvalidToneError(
|
153 |
+
f"Length of phone ({len(phone)}) != length of given_tone ({len(given_tone)})"
|
154 |
+
)
|
155 |
+
tone = given_tone
|
156 |
+
phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
|
157 |
+
|
158 |
+
if hps.data.add_blank:
|
159 |
+
phone = commons.intersperse(phone, 0)
|
160 |
+
tone = commons.intersperse(tone, 0)
|
161 |
+
language = commons.intersperse(language, 0)
|
162 |
+
for i in range(len(word2ph)):
|
163 |
+
word2ph[i] = word2ph[i] * 2
|
164 |
+
word2ph[0] += 1
|
165 |
+
bert_ori = extract_bert_feature(
|
166 |
+
norm_text,
|
167 |
+
word2ph,
|
168 |
+
language_str,
|
169 |
+
device,
|
170 |
+
assist_text,
|
171 |
+
assist_text_weight,
|
172 |
+
)
|
173 |
+
del word2ph
|
174 |
+
assert bert_ori.shape[-1] == len(phone), phone
|
175 |
+
|
176 |
+
if language_str == Languages.ZH:
|
177 |
+
bert = bert_ori
|
178 |
+
ja_bert = torch.zeros(1024, len(phone))
|
179 |
+
en_bert = torch.zeros(1024, len(phone))
|
180 |
+
elif language_str == Languages.JP:
|
181 |
+
bert = torch.zeros(1024, len(phone))
|
182 |
+
ja_bert = bert_ori
|
183 |
+
en_bert = torch.zeros(1024, len(phone))
|
184 |
+
elif language_str == Languages.EN:
|
185 |
+
bert = torch.zeros(1024, len(phone))
|
186 |
+
ja_bert = torch.zeros(1024, len(phone))
|
187 |
+
en_bert = bert_ori
|
188 |
+
else:
|
189 |
+
raise ValueError("language_str should be ZH, JP or EN")
|
190 |
+
|
191 |
+
assert bert.shape[-1] == len(
|
192 |
+
phone
|
193 |
+
), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
|
194 |
+
|
195 |
+
phone = torch.LongTensor(phone)
|
196 |
+
tone = torch.LongTensor(tone)
|
197 |
+
language = torch.LongTensor(language)
|
198 |
+
return bert, ja_bert, en_bert, phone, tone, language
|
199 |
+
|
200 |
+
|
201 |
+
def infer(
|
202 |
+
text: str,
|
203 |
+
style_vec: NDArray[Any],
|
204 |
+
sdp_ratio: float,
|
205 |
+
noise_scale: float,
|
206 |
+
noise_scale_w: float,
|
207 |
+
length_scale: float,
|
208 |
+
sid: int, # In the original Bert-VITS2, its speaker_name: str, but here it's id
|
209 |
+
language: Languages,
|
210 |
+
hps: HyperParameters,
|
211 |
+
net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra],
|
212 |
+
device: str,
|
213 |
+
skip_start: bool = False,
|
214 |
+
skip_end: bool = False,
|
215 |
+
assist_text: Optional[str] = None,
|
216 |
+
assist_text_weight: float = 0.7,
|
217 |
+
given_phone: Optional[list[str]] = None,
|
218 |
+
given_tone: Optional[list[int]] = None,
|
219 |
+
):
|
220 |
+
is_jp_extra = hps.version.endswith("JP-Extra")
|
221 |
+
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
222 |
+
text,
|
223 |
+
language,
|
224 |
+
hps,
|
225 |
+
device,
|
226 |
+
assist_text=assist_text,
|
227 |
+
assist_text_weight=assist_text_weight,
|
228 |
+
given_phone=given_phone,
|
229 |
+
given_tone=given_tone,
|
230 |
+
)
|
231 |
+
if skip_start:
|
232 |
+
phones = phones[3:]
|
233 |
+
tones = tones[3:]
|
234 |
+
lang_ids = lang_ids[3:]
|
235 |
+
bert = bert[:, 3:]
|
236 |
+
ja_bert = ja_bert[:, 3:]
|
237 |
+
en_bert = en_bert[:, 3:]
|
238 |
+
if skip_end:
|
239 |
+
phones = phones[:-2]
|
240 |
+
tones = tones[:-2]
|
241 |
+
lang_ids = lang_ids[:-2]
|
242 |
+
bert = bert[:, :-2]
|
243 |
+
ja_bert = ja_bert[:, :-2]
|
244 |
+
en_bert = en_bert[:, :-2]
|
245 |
+
with torch.no_grad():
|
246 |
+
x_tst = phones.to(device).unsqueeze(0)
|
247 |
+
tones = tones.to(device).unsqueeze(0)
|
248 |
+
lang_ids = lang_ids.to(device).unsqueeze(0)
|
249 |
+
bert = bert.to(device).unsqueeze(0)
|
250 |
+
ja_bert = ja_bert.to(device).unsqueeze(0)
|
251 |
+
en_bert = en_bert.to(device).unsqueeze(0)
|
252 |
+
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
253 |
+
style_vec_tensor = torch.from_numpy(style_vec).to(device).unsqueeze(0)
|
254 |
+
del phones
|
255 |
+
sid_tensor = torch.LongTensor([sid]).to(device)
|
256 |
+
if is_jp_extra:
|
257 |
+
output = cast(SynthesizerTrnJPExtra, net_g).infer(
|
258 |
+
x_tst,
|
259 |
+
x_tst_lengths,
|
260 |
+
sid_tensor,
|
261 |
+
tones,
|
262 |
+
lang_ids,
|
263 |
+
ja_bert,
|
264 |
+
style_vec=style_vec_tensor,
|
265 |
+
sdp_ratio=sdp_ratio,
|
266 |
+
noise_scale=noise_scale,
|
267 |
+
noise_scale_w=noise_scale_w,
|
268 |
+
length_scale=length_scale,
|
269 |
+
)
|
270 |
+
else:
|
271 |
+
output = cast(SynthesizerTrn, net_g).infer(
|
272 |
+
x_tst,
|
273 |
+
x_tst_lengths,
|
274 |
+
sid_tensor,
|
275 |
+
tones,
|
276 |
+
lang_ids,
|
277 |
+
bert,
|
278 |
+
ja_bert,
|
279 |
+
en_bert,
|
280 |
+
style_vec=style_vec_tensor,
|
281 |
+
sdp_ratio=sdp_ratio,
|
282 |
+
noise_scale=noise_scale,
|
283 |
+
noise_scale_w=noise_scale_w,
|
284 |
+
length_scale=length_scale,
|
285 |
+
)
|
286 |
+
audio = output[0][0, 0].data.cpu().float().numpy()
|
287 |
+
del (
|
288 |
+
x_tst,
|
289 |
+
tones,
|
290 |
+
lang_ids,
|
291 |
+
bert,
|
292 |
+
x_tst_lengths,
|
293 |
+
sid_tensor,
|
294 |
+
ja_bert,
|
295 |
+
en_bert,
|
296 |
+
style_vec,
|
297 |
+
) # , emo
|
298 |
+
if torch.cuda.is_available():
|
299 |
+
torch.cuda.empty_cache()
|
300 |
+
return audio
|
301 |
+
|
302 |
+
|
303 |
+
class InvalidPhoneError(ValueError):
|
304 |
+
pass
|
305 |
+
|
306 |
+
|
307 |
+
class InvalidToneError(ValueError):
|
308 |
+
pass
|
style_bert_vits2/models/models.py
ADDED
@@ -0,0 +1,1102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Any, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
7 |
+
from torch.nn import functional as F
|
8 |
+
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
9 |
+
|
10 |
+
from style_bert_vits2.models import attentions, commons, modules, monotonic_alignment
|
11 |
+
from style_bert_vits2.nlp.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS
|
12 |
+
|
13 |
+
|
14 |
+
class DurationDiscriminator(nn.Module): # vits2
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
in_channels: int,
|
18 |
+
filter_channels: int,
|
19 |
+
kernel_size: int,
|
20 |
+
p_dropout: float,
|
21 |
+
gin_channels: int = 0,
|
22 |
+
) -> None:
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
self.in_channels = in_channels
|
26 |
+
self.filter_channels = filter_channels
|
27 |
+
self.kernel_size = kernel_size
|
28 |
+
self.p_dropout = p_dropout
|
29 |
+
self.gin_channels = gin_channels
|
30 |
+
|
31 |
+
self.drop = nn.Dropout(p_dropout)
|
32 |
+
self.conv_1 = nn.Conv1d(
|
33 |
+
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
34 |
+
)
|
35 |
+
self.norm_1 = modules.LayerNorm(filter_channels)
|
36 |
+
self.conv_2 = nn.Conv1d(
|
37 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
38 |
+
)
|
39 |
+
self.norm_2 = modules.LayerNorm(filter_channels)
|
40 |
+
self.dur_proj = nn.Conv1d(1, filter_channels, 1)
|
41 |
+
|
42 |
+
self.pre_out_conv_1 = nn.Conv1d(
|
43 |
+
2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
44 |
+
)
|
45 |
+
self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
|
46 |
+
self.pre_out_conv_2 = nn.Conv1d(
|
47 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
48 |
+
)
|
49 |
+
self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
|
50 |
+
|
51 |
+
if gin_channels != 0:
|
52 |
+
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
53 |
+
|
54 |
+
self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
|
55 |
+
|
56 |
+
def forward_probability(
|
57 |
+
self,
|
58 |
+
x: torch.Tensor,
|
59 |
+
x_mask: torch.Tensor,
|
60 |
+
dur: torch.Tensor,
|
61 |
+
g: Optional[torch.Tensor] = None,
|
62 |
+
) -> torch.Tensor:
|
63 |
+
dur = self.dur_proj(dur)
|
64 |
+
x = torch.cat([x, dur], dim=1)
|
65 |
+
x = self.pre_out_conv_1(x * x_mask)
|
66 |
+
x = torch.relu(x)
|
67 |
+
x = self.pre_out_norm_1(x)
|
68 |
+
x = self.drop(x)
|
69 |
+
x = self.pre_out_conv_2(x * x_mask)
|
70 |
+
x = torch.relu(x)
|
71 |
+
x = self.pre_out_norm_2(x)
|
72 |
+
x = self.drop(x)
|
73 |
+
x = x * x_mask
|
74 |
+
x = x.transpose(1, 2)
|
75 |
+
output_prob = self.output_layer(x)
|
76 |
+
return output_prob
|
77 |
+
|
78 |
+
def forward(
|
79 |
+
self,
|
80 |
+
x: torch.Tensor,
|
81 |
+
x_mask: torch.Tensor,
|
82 |
+
dur_r: torch.Tensor,
|
83 |
+
dur_hat: torch.Tensor,
|
84 |
+
g: Optional[torch.Tensor] = None,
|
85 |
+
) -> list[torch.Tensor]:
|
86 |
+
x = torch.detach(x)
|
87 |
+
if g is not None:
|
88 |
+
g = torch.detach(g)
|
89 |
+
x = x + self.cond(g)
|
90 |
+
x = self.conv_1(x * x_mask)
|
91 |
+
x = torch.relu(x)
|
92 |
+
x = self.norm_1(x)
|
93 |
+
x = self.drop(x)
|
94 |
+
x = self.conv_2(x * x_mask)
|
95 |
+
x = torch.relu(x)
|
96 |
+
x = self.norm_2(x)
|
97 |
+
x = self.drop(x)
|
98 |
+
|
99 |
+
output_probs = []
|
100 |
+
for dur in [dur_r, dur_hat]:
|
101 |
+
output_prob = self.forward_probability(x, x_mask, dur, g)
|
102 |
+
output_probs.append(output_prob)
|
103 |
+
|
104 |
+
return output_probs
|
105 |
+
|
106 |
+
|
107 |
+
class TransformerCouplingBlock(nn.Module):
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
channels: int,
|
111 |
+
hidden_channels: int,
|
112 |
+
filter_channels: int,
|
113 |
+
n_heads: int,
|
114 |
+
n_layers: int,
|
115 |
+
kernel_size: int,
|
116 |
+
p_dropout: float,
|
117 |
+
n_flows: int = 4,
|
118 |
+
gin_channels: int = 0,
|
119 |
+
share_parameter: bool = False,
|
120 |
+
) -> None:
|
121 |
+
super().__init__()
|
122 |
+
self.channels = channels
|
123 |
+
self.hidden_channels = hidden_channels
|
124 |
+
self.kernel_size = kernel_size
|
125 |
+
self.n_layers = n_layers
|
126 |
+
self.n_flows = n_flows
|
127 |
+
self.gin_channels = gin_channels
|
128 |
+
|
129 |
+
self.flows = nn.ModuleList()
|
130 |
+
|
131 |
+
self.wn = (
|
132 |
+
# attentions.FFT(
|
133 |
+
# hidden_channels,
|
134 |
+
# filter_channels,
|
135 |
+
# n_heads,
|
136 |
+
# n_layers,
|
137 |
+
# kernel_size,
|
138 |
+
# p_dropout,
|
139 |
+
# isflow=True,
|
140 |
+
# gin_channels=self.gin_channels,
|
141 |
+
# )
|
142 |
+
None
|
143 |
+
if share_parameter
|
144 |
+
else None
|
145 |
+
)
|
146 |
+
|
147 |
+
for i in range(n_flows):
|
148 |
+
self.flows.append(
|
149 |
+
modules.TransformerCouplingLayer(
|
150 |
+
channels,
|
151 |
+
hidden_channels,
|
152 |
+
kernel_size,
|
153 |
+
n_layers,
|
154 |
+
n_heads,
|
155 |
+
p_dropout,
|
156 |
+
filter_channels,
|
157 |
+
mean_only=True,
|
158 |
+
wn_sharing_parameter=self.wn,
|
159 |
+
gin_channels=self.gin_channels,
|
160 |
+
)
|
161 |
+
)
|
162 |
+
self.flows.append(modules.Flip())
|
163 |
+
|
164 |
+
def forward(
|
165 |
+
self,
|
166 |
+
x: torch.Tensor,
|
167 |
+
x_mask: torch.Tensor,
|
168 |
+
g: Optional[torch.Tensor] = None,
|
169 |
+
reverse: bool = False,
|
170 |
+
) -> torch.Tensor:
|
171 |
+
if not reverse:
|
172 |
+
for flow in self.flows:
|
173 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
174 |
+
else:
|
175 |
+
for flow in reversed(self.flows):
|
176 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
177 |
+
return x
|
178 |
+
|
179 |
+
|
180 |
+
class StochasticDurationPredictor(nn.Module):
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
in_channels: int,
|
184 |
+
filter_channels: int,
|
185 |
+
kernel_size: int,
|
186 |
+
p_dropout: float,
|
187 |
+
n_flows: int = 4,
|
188 |
+
gin_channels: int = 0,
|
189 |
+
) -> None:
|
190 |
+
super().__init__()
|
191 |
+
filter_channels = in_channels # it needs to be removed from future version.
|
192 |
+
self.in_channels = in_channels
|
193 |
+
self.filter_channels = filter_channels
|
194 |
+
self.kernel_size = kernel_size
|
195 |
+
self.p_dropout = p_dropout
|
196 |
+
self.n_flows = n_flows
|
197 |
+
self.gin_channels = gin_channels
|
198 |
+
|
199 |
+
self.log_flow = modules.Log()
|
200 |
+
self.flows = nn.ModuleList()
|
201 |
+
self.flows.append(modules.ElementwiseAffine(2))
|
202 |
+
for i in range(n_flows):
|
203 |
+
self.flows.append(
|
204 |
+
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
205 |
+
)
|
206 |
+
self.flows.append(modules.Flip())
|
207 |
+
|
208 |
+
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
209 |
+
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
210 |
+
self.post_convs = modules.DDSConv(
|
211 |
+
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
212 |
+
)
|
213 |
+
self.post_flows = nn.ModuleList()
|
214 |
+
self.post_flows.append(modules.ElementwiseAffine(2))
|
215 |
+
for i in range(4):
|
216 |
+
self.post_flows.append(
|
217 |
+
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
218 |
+
)
|
219 |
+
self.post_flows.append(modules.Flip())
|
220 |
+
|
221 |
+
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
222 |
+
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
223 |
+
self.convs = modules.DDSConv(
|
224 |
+
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
225 |
+
)
|
226 |
+
if gin_channels != 0:
|
227 |
+
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
228 |
+
|
229 |
+
def forward(
|
230 |
+
self,
|
231 |
+
x: torch.Tensor,
|
232 |
+
x_mask: torch.Tensor,
|
233 |
+
w: Optional[torch.Tensor] = None,
|
234 |
+
g: Optional[torch.Tensor] = None,
|
235 |
+
reverse: bool = False,
|
236 |
+
noise_scale: float = 1.0,
|
237 |
+
) -> torch.Tensor:
|
238 |
+
x = torch.detach(x)
|
239 |
+
x = self.pre(x)
|
240 |
+
if g is not None:
|
241 |
+
g = torch.detach(g)
|
242 |
+
x = x + self.cond(g)
|
243 |
+
x = self.convs(x, x_mask)
|
244 |
+
x = self.proj(x) * x_mask
|
245 |
+
|
246 |
+
if not reverse:
|
247 |
+
flows = self.flows
|
248 |
+
assert w is not None
|
249 |
+
|
250 |
+
logdet_tot_q = 0
|
251 |
+
h_w = self.post_pre(w)
|
252 |
+
h_w = self.post_convs(h_w, x_mask)
|
253 |
+
h_w = self.post_proj(h_w) * x_mask
|
254 |
+
e_q = (
|
255 |
+
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
|
256 |
+
* x_mask
|
257 |
+
)
|
258 |
+
z_q = e_q
|
259 |
+
for flow in self.post_flows:
|
260 |
+
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
261 |
+
logdet_tot_q += logdet_q
|
262 |
+
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
263 |
+
u = torch.sigmoid(z_u) * x_mask
|
264 |
+
z0 = (w - u) * x_mask
|
265 |
+
logdet_tot_q += torch.sum(
|
266 |
+
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
|
267 |
+
)
|
268 |
+
logq = (
|
269 |
+
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
|
270 |
+
- logdet_tot_q
|
271 |
+
)
|
272 |
+
|
273 |
+
logdet_tot = 0
|
274 |
+
z0, logdet = self.log_flow(z0, x_mask)
|
275 |
+
logdet_tot += logdet
|
276 |
+
z = torch.cat([z0, z1], 1)
|
277 |
+
for flow in flows:
|
278 |
+
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
279 |
+
logdet_tot = logdet_tot + logdet
|
280 |
+
nll = (
|
281 |
+
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
|
282 |
+
- logdet_tot
|
283 |
+
)
|
284 |
+
return nll + logq # [b]
|
285 |
+
else:
|
286 |
+
flows = list(reversed(self.flows))
|
287 |
+
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
288 |
+
z = (
|
289 |
+
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
|
290 |
+
* noise_scale
|
291 |
+
)
|
292 |
+
for flow in flows:
|
293 |
+
z = flow(z, x_mask, g=x, reverse=reverse)
|
294 |
+
z0, z1 = torch.split(z, [1, 1], 1)
|
295 |
+
logw = z0
|
296 |
+
return logw
|
297 |
+
|
298 |
+
|
299 |
+
class DurationPredictor(nn.Module):
|
300 |
+
def __init__(
|
301 |
+
self,
|
302 |
+
in_channels: int,
|
303 |
+
filter_channels: int,
|
304 |
+
kernel_size: int,
|
305 |
+
p_dropout: float,
|
306 |
+
gin_channels: int = 0,
|
307 |
+
) -> None:
|
308 |
+
super().__init__()
|
309 |
+
|
310 |
+
self.in_channels = in_channels
|
311 |
+
self.filter_channels = filter_channels
|
312 |
+
self.kernel_size = kernel_size
|
313 |
+
self.p_dropout = p_dropout
|
314 |
+
self.gin_channels = gin_channels
|
315 |
+
|
316 |
+
self.drop = nn.Dropout(p_dropout)
|
317 |
+
self.conv_1 = nn.Conv1d(
|
318 |
+
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
319 |
+
)
|
320 |
+
self.norm_1 = modules.LayerNorm(filter_channels)
|
321 |
+
self.conv_2 = nn.Conv1d(
|
322 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
323 |
+
)
|
324 |
+
self.norm_2 = modules.LayerNorm(filter_channels)
|
325 |
+
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
326 |
+
|
327 |
+
if gin_channels != 0:
|
328 |
+
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
329 |
+
|
330 |
+
def forward(
|
331 |
+
self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
|
332 |
+
) -> torch.Tensor:
|
333 |
+
x = torch.detach(x)
|
334 |
+
if g is not None:
|
335 |
+
g = torch.detach(g)
|
336 |
+
x = x + self.cond(g)
|
337 |
+
x = self.conv_1(x * x_mask)
|
338 |
+
x = torch.relu(x)
|
339 |
+
x = self.norm_1(x)
|
340 |
+
x = self.drop(x)
|
341 |
+
x = self.conv_2(x * x_mask)
|
342 |
+
x = torch.relu(x)
|
343 |
+
x = self.norm_2(x)
|
344 |
+
x = self.drop(x)
|
345 |
+
x = self.proj(x * x_mask)
|
346 |
+
return x * x_mask
|
347 |
+
|
348 |
+
|
349 |
+
class TextEncoder(nn.Module):
|
350 |
+
def __init__(
|
351 |
+
self,
|
352 |
+
n_vocab: int,
|
353 |
+
out_channels: int,
|
354 |
+
hidden_channels: int,
|
355 |
+
filter_channels: int,
|
356 |
+
n_heads: int,
|
357 |
+
n_layers: int,
|
358 |
+
kernel_size: int,
|
359 |
+
p_dropout: float,
|
360 |
+
n_speakers: int,
|
361 |
+
gin_channels: int = 0,
|
362 |
+
) -> None:
|
363 |
+
super().__init__()
|
364 |
+
self.n_vocab = n_vocab
|
365 |
+
self.out_channels = out_channels
|
366 |
+
self.hidden_channels = hidden_channels
|
367 |
+
self.filter_channels = filter_channels
|
368 |
+
self.n_heads = n_heads
|
369 |
+
self.n_layers = n_layers
|
370 |
+
self.kernel_size = kernel_size
|
371 |
+
self.p_dropout = p_dropout
|
372 |
+
self.gin_channels = gin_channels
|
373 |
+
self.emb = nn.Embedding(len(SYMBOLS), hidden_channels)
|
374 |
+
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
375 |
+
self.tone_emb = nn.Embedding(NUM_TONES, hidden_channels)
|
376 |
+
nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
|
377 |
+
self.language_emb = nn.Embedding(NUM_LANGUAGES, hidden_channels)
|
378 |
+
nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
|
379 |
+
self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
380 |
+
self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
381 |
+
self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
382 |
+
self.style_proj = nn.Linear(256, hidden_channels)
|
383 |
+
|
384 |
+
self.encoder = attentions.Encoder(
|
385 |
+
hidden_channels,
|
386 |
+
filter_channels,
|
387 |
+
n_heads,
|
388 |
+
n_layers,
|
389 |
+
kernel_size,
|
390 |
+
p_dropout,
|
391 |
+
gin_channels=self.gin_channels,
|
392 |
+
)
|
393 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
394 |
+
|
395 |
+
def forward(
|
396 |
+
self,
|
397 |
+
x: torch.Tensor,
|
398 |
+
x_lengths: torch.Tensor,
|
399 |
+
tone: torch.Tensor,
|
400 |
+
language: torch.Tensor,
|
401 |
+
bert: torch.Tensor,
|
402 |
+
ja_bert: torch.Tensor,
|
403 |
+
en_bert: torch.Tensor,
|
404 |
+
style_vec: torch.Tensor,
|
405 |
+
sid: torch.Tensor,
|
406 |
+
g: Optional[torch.Tensor] = None,
|
407 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
408 |
+
bert_emb = self.bert_proj(bert).transpose(1, 2)
|
409 |
+
ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
|
410 |
+
en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
|
411 |
+
style_emb = self.style_proj(style_vec.unsqueeze(1))
|
412 |
+
|
413 |
+
x = (
|
414 |
+
self.emb(x)
|
415 |
+
+ self.tone_emb(tone)
|
416 |
+
+ self.language_emb(language)
|
417 |
+
+ bert_emb
|
418 |
+
+ ja_bert_emb
|
419 |
+
+ en_bert_emb
|
420 |
+
+ style_emb
|
421 |
+
) * math.sqrt(
|
422 |
+
self.hidden_channels
|
423 |
+
) # [b, t, h]
|
424 |
+
x = torch.transpose(x, 1, -1) # [b, h, t]
|
425 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
426 |
+
x.dtype
|
427 |
+
)
|
428 |
+
|
429 |
+
x = self.encoder(x * x_mask, x_mask, g=g)
|
430 |
+
stats = self.proj(x) * x_mask
|
431 |
+
|
432 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
433 |
+
return x, m, logs, x_mask
|
434 |
+
|
435 |
+
|
436 |
+
class ResidualCouplingBlock(nn.Module):
|
437 |
+
def __init__(
|
438 |
+
self,
|
439 |
+
channels: int,
|
440 |
+
hidden_channels: int,
|
441 |
+
kernel_size: int,
|
442 |
+
dilation_rate: int,
|
443 |
+
n_layers: int,
|
444 |
+
n_flows: int = 4,
|
445 |
+
gin_channels: int = 0,
|
446 |
+
) -> None:
|
447 |
+
super().__init__()
|
448 |
+
self.channels = channels
|
449 |
+
self.hidden_channels = hidden_channels
|
450 |
+
self.kernel_size = kernel_size
|
451 |
+
self.dilation_rate = dilation_rate
|
452 |
+
self.n_layers = n_layers
|
453 |
+
self.n_flows = n_flows
|
454 |
+
self.gin_channels = gin_channels
|
455 |
+
|
456 |
+
self.flows = nn.ModuleList()
|
457 |
+
for i in range(n_flows):
|
458 |
+
self.flows.append(
|
459 |
+
modules.ResidualCouplingLayer(
|
460 |
+
channels,
|
461 |
+
hidden_channels,
|
462 |
+
kernel_size,
|
463 |
+
dilation_rate,
|
464 |
+
n_layers,
|
465 |
+
gin_channels=gin_channels,
|
466 |
+
mean_only=True,
|
467 |
+
)
|
468 |
+
)
|
469 |
+
self.flows.append(modules.Flip())
|
470 |
+
|
471 |
+
def forward(
|
472 |
+
self,
|
473 |
+
x: torch.Tensor,
|
474 |
+
x_mask: torch.Tensor,
|
475 |
+
g: Optional[torch.Tensor] = None,
|
476 |
+
reverse: bool = False,
|
477 |
+
) -> torch.Tensor:
|
478 |
+
if not reverse:
|
479 |
+
for flow in self.flows:
|
480 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
481 |
+
else:
|
482 |
+
for flow in reversed(self.flows):
|
483 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
484 |
+
return x
|
485 |
+
|
486 |
+
|
487 |
+
class PosteriorEncoder(nn.Module):
|
488 |
+
def __init__(
|
489 |
+
self,
|
490 |
+
in_channels: int,
|
491 |
+
out_channels: int,
|
492 |
+
hidden_channels: int,
|
493 |
+
kernel_size: int,
|
494 |
+
dilation_rate: int,
|
495 |
+
n_layers: int,
|
496 |
+
gin_channels: int = 0,
|
497 |
+
) -> None:
|
498 |
+
super().__init__()
|
499 |
+
self.in_channels = in_channels
|
500 |
+
self.out_channels = out_channels
|
501 |
+
self.hidden_channels = hidden_channels
|
502 |
+
self.kernel_size = kernel_size
|
503 |
+
self.dilation_rate = dilation_rate
|
504 |
+
self.n_layers = n_layers
|
505 |
+
self.gin_channels = gin_channels
|
506 |
+
|
507 |
+
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
508 |
+
self.enc = modules.WN(
|
509 |
+
hidden_channels,
|
510 |
+
kernel_size,
|
511 |
+
dilation_rate,
|
512 |
+
n_layers,
|
513 |
+
gin_channels=gin_channels,
|
514 |
+
)
|
515 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
516 |
+
|
517 |
+
def forward(
|
518 |
+
self,
|
519 |
+
x: torch.Tensor,
|
520 |
+
x_lengths: torch.Tensor,
|
521 |
+
g: Optional[torch.Tensor] = None,
|
522 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
523 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
524 |
+
x.dtype
|
525 |
+
)
|
526 |
+
x = self.pre(x) * x_mask
|
527 |
+
x = self.enc(x, x_mask, g=g)
|
528 |
+
stats = self.proj(x) * x_mask
|
529 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
530 |
+
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
531 |
+
return z, m, logs, x_mask
|
532 |
+
|
533 |
+
|
534 |
+
class Generator(torch.nn.Module):
|
535 |
+
def __init__(
|
536 |
+
self,
|
537 |
+
initial_channel: int,
|
538 |
+
resblock_str: str,
|
539 |
+
resblock_kernel_sizes: list[int],
|
540 |
+
resblock_dilation_sizes: list[list[int]],
|
541 |
+
upsample_rates: list[int],
|
542 |
+
upsample_initial_channel: int,
|
543 |
+
upsample_kernel_sizes: list[int],
|
544 |
+
gin_channels: int = 0,
|
545 |
+
) -> None:
|
546 |
+
super(Generator, self).__init__()
|
547 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
548 |
+
self.num_upsamples = len(upsample_rates)
|
549 |
+
self.conv_pre = Conv1d(
|
550 |
+
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
551 |
+
)
|
552 |
+
resblock = modules.ResBlock1 if resblock_str == "1" else modules.ResBlock2
|
553 |
+
|
554 |
+
self.ups = nn.ModuleList()
|
555 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
556 |
+
self.ups.append(
|
557 |
+
weight_norm(
|
558 |
+
ConvTranspose1d(
|
559 |
+
upsample_initial_channel // (2**i),
|
560 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
561 |
+
k,
|
562 |
+
u,
|
563 |
+
padding=(k - u) // 2,
|
564 |
+
)
|
565 |
+
)
|
566 |
+
)
|
567 |
+
|
568 |
+
self.resblocks = nn.ModuleList()
|
569 |
+
ch = None
|
570 |
+
for i in range(len(self.ups)):
|
571 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
572 |
+
for j, (k, d) in enumerate(
|
573 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
574 |
+
):
|
575 |
+
self.resblocks.append(resblock(ch, k, d)) # type: ignore
|
576 |
+
|
577 |
+
assert ch is not None
|
578 |
+
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
579 |
+
self.ups.apply(commons.init_weights)
|
580 |
+
|
581 |
+
if gin_channels != 0:
|
582 |
+
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
583 |
+
|
584 |
+
def forward(
|
585 |
+
self, x: torch.Tensor, g: Optional[torch.Tensor] = None
|
586 |
+
) -> torch.Tensor:
|
587 |
+
x = self.conv_pre(x)
|
588 |
+
if g is not None:
|
589 |
+
x = x + self.cond(g)
|
590 |
+
|
591 |
+
for i in range(self.num_upsamples):
|
592 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
593 |
+
x = self.ups[i](x)
|
594 |
+
xs = None
|
595 |
+
for j in range(self.num_kernels):
|
596 |
+
if xs is None:
|
597 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
598 |
+
else:
|
599 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
600 |
+
assert xs is not None
|
601 |
+
x = xs / self.num_kernels
|
602 |
+
x = F.leaky_relu(x)
|
603 |
+
x = self.conv_post(x)
|
604 |
+
x = torch.tanh(x)
|
605 |
+
|
606 |
+
return x
|
607 |
+
|
608 |
+
def remove_weight_norm(self) -> None:
|
609 |
+
print("Removing weight norm...")
|
610 |
+
for layer in self.ups:
|
611 |
+
remove_weight_norm(layer)
|
612 |
+
for layer in self.resblocks:
|
613 |
+
layer.remove_weight_norm()
|
614 |
+
|
615 |
+
|
616 |
+
class DiscriminatorP(torch.nn.Module):
|
617 |
+
def __init__(
|
618 |
+
self,
|
619 |
+
period: int,
|
620 |
+
kernel_size: int = 5,
|
621 |
+
stride: int = 3,
|
622 |
+
use_spectral_norm: bool = False,
|
623 |
+
) -> None:
|
624 |
+
super(DiscriminatorP, self).__init__()
|
625 |
+
self.period = period
|
626 |
+
self.use_spectral_norm = use_spectral_norm
|
627 |
+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
628 |
+
self.convs = nn.ModuleList(
|
629 |
+
[
|
630 |
+
norm_f(
|
631 |
+
Conv2d(
|
632 |
+
1,
|
633 |
+
32,
|
634 |
+
(kernel_size, 1),
|
635 |
+
(stride, 1),
|
636 |
+
padding=(commons.get_padding(kernel_size, 1), 0),
|
637 |
+
)
|
638 |
+
),
|
639 |
+
norm_f(
|
640 |
+
Conv2d(
|
641 |
+
32,
|
642 |
+
128,
|
643 |
+
(kernel_size, 1),
|
644 |
+
(stride, 1),
|
645 |
+
padding=(commons.get_padding(kernel_size, 1), 0),
|
646 |
+
)
|
647 |
+
),
|
648 |
+
norm_f(
|
649 |
+
Conv2d(
|
650 |
+
128,
|
651 |
+
512,
|
652 |
+
(kernel_size, 1),
|
653 |
+
(stride, 1),
|
654 |
+
padding=(commons.get_padding(kernel_size, 1), 0),
|
655 |
+
)
|
656 |
+
),
|
657 |
+
norm_f(
|
658 |
+
Conv2d(
|
659 |
+
512,
|
660 |
+
1024,
|
661 |
+
(kernel_size, 1),
|
662 |
+
(stride, 1),
|
663 |
+
padding=(commons.get_padding(kernel_size, 1), 0),
|
664 |
+
)
|
665 |
+
),
|
666 |
+
norm_f(
|
667 |
+
Conv2d(
|
668 |
+
1024,
|
669 |
+
1024,
|
670 |
+
(kernel_size, 1),
|
671 |
+
1,
|
672 |
+
padding=(commons.get_padding(kernel_size, 1), 0),
|
673 |
+
)
|
674 |
+
),
|
675 |
+
]
|
676 |
+
)
|
677 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
678 |
+
|
679 |
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
680 |
+
fmap = []
|
681 |
+
|
682 |
+
# 1d to 2d
|
683 |
+
b, c, t = x.shape
|
684 |
+
if t % self.period != 0: # pad first
|
685 |
+
n_pad = self.period - (t % self.period)
|
686 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
687 |
+
t = t + n_pad
|
688 |
+
x = x.view(b, c, t // self.period, self.period)
|
689 |
+
|
690 |
+
for layer in self.convs:
|
691 |
+
x = layer(x)
|
692 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
693 |
+
fmap.append(x)
|
694 |
+
x = self.conv_post(x)
|
695 |
+
fmap.append(x)
|
696 |
+
x = torch.flatten(x, 1, -1)
|
697 |
+
|
698 |
+
return x, fmap
|
699 |
+
|
700 |
+
|
701 |
+
class DiscriminatorS(torch.nn.Module):
|
702 |
+
def __init__(self, use_spectral_norm: bool = False) -> None:
|
703 |
+
super(DiscriminatorS, self).__init__()
|
704 |
+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
705 |
+
self.convs = nn.ModuleList(
|
706 |
+
[
|
707 |
+
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
708 |
+
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
709 |
+
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
710 |
+
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
711 |
+
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
712 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
713 |
+
]
|
714 |
+
)
|
715 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
716 |
+
|
717 |
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
718 |
+
fmap = []
|
719 |
+
|
720 |
+
for layer in self.convs:
|
721 |
+
x = layer(x)
|
722 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
723 |
+
fmap.append(x)
|
724 |
+
x = self.conv_post(x)
|
725 |
+
fmap.append(x)
|
726 |
+
x = torch.flatten(x, 1, -1)
|
727 |
+
|
728 |
+
return x, fmap
|
729 |
+
|
730 |
+
|
731 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
732 |
+
def __init__(self, use_spectral_norm: bool = False) -> None:
|
733 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
734 |
+
periods = [2, 3, 5, 7, 11]
|
735 |
+
|
736 |
+
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
737 |
+
discs = discs + [
|
738 |
+
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
739 |
+
]
|
740 |
+
self.discriminators = nn.ModuleList(discs)
|
741 |
+
|
742 |
+
def forward(
|
743 |
+
self,
|
744 |
+
y: torch.Tensor,
|
745 |
+
y_hat: torch.Tensor,
|
746 |
+
) -> tuple[
|
747 |
+
list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]
|
748 |
+
]:
|
749 |
+
y_d_rs = []
|
750 |
+
y_d_gs = []
|
751 |
+
fmap_rs = []
|
752 |
+
fmap_gs = []
|
753 |
+
for i, d in enumerate(self.discriminators):
|
754 |
+
y_d_r, fmap_r = d(y)
|
755 |
+
y_d_g, fmap_g = d(y_hat)
|
756 |
+
y_d_rs.append(y_d_r)
|
757 |
+
y_d_gs.append(y_d_g)
|
758 |
+
fmap_rs.append(fmap_r)
|
759 |
+
fmap_gs.append(fmap_g)
|
760 |
+
|
761 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
762 |
+
|
763 |
+
|
764 |
+
class ReferenceEncoder(nn.Module):
|
765 |
+
"""
|
766 |
+
inputs --- [N, Ty/r, n_mels*r] mels
|
767 |
+
outputs --- [N, ref_enc_gru_size]
|
768 |
+
"""
|
769 |
+
|
770 |
+
def __init__(self, spec_channels: int, gin_channels: int = 0) -> None:
|
771 |
+
super().__init__()
|
772 |
+
self.spec_channels = spec_channels
|
773 |
+
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
774 |
+
K = len(ref_enc_filters)
|
775 |
+
filters = [1] + ref_enc_filters
|
776 |
+
convs = [
|
777 |
+
weight_norm(
|
778 |
+
nn.Conv2d(
|
779 |
+
in_channels=filters[i],
|
780 |
+
out_channels=filters[i + 1],
|
781 |
+
kernel_size=(3, 3),
|
782 |
+
stride=(2, 2),
|
783 |
+
padding=(1, 1),
|
784 |
+
)
|
785 |
+
)
|
786 |
+
for i in range(K)
|
787 |
+
]
|
788 |
+
self.convs = nn.ModuleList(convs)
|
789 |
+
# self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
|
790 |
+
|
791 |
+
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
|
792 |
+
self.gru = nn.GRU(
|
793 |
+
input_size=ref_enc_filters[-1] * out_channels,
|
794 |
+
hidden_size=256 // 2,
|
795 |
+
batch_first=True,
|
796 |
+
)
|
797 |
+
self.proj = nn.Linear(128, gin_channels)
|
798 |
+
|
799 |
+
def forward(
|
800 |
+
self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
|
801 |
+
) -> torch.Tensor:
|
802 |
+
N = inputs.size(0)
|
803 |
+
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
|
804 |
+
for conv in self.convs:
|
805 |
+
out = conv(out)
|
806 |
+
# out = wn(out)
|
807 |
+
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
|
808 |
+
|
809 |
+
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
|
810 |
+
T = out.size(1)
|
811 |
+
N = out.size(0)
|
812 |
+
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
|
813 |
+
|
814 |
+
self.gru.flatten_parameters()
|
815 |
+
memory, out = self.gru(out) # out --- [1, N, 128]
|
816 |
+
|
817 |
+
return self.proj(out.squeeze(0))
|
818 |
+
|
819 |
+
def calculate_channels(
|
820 |
+
self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int
|
821 |
+
) -> int:
|
822 |
+
for i in range(n_convs):
|
823 |
+
L = (L - kernel_size + 2 * pad) // stride + 1
|
824 |
+
return L
|
825 |
+
|
826 |
+
|
827 |
+
class SynthesizerTrn(nn.Module):
|
828 |
+
"""
|
829 |
+
Synthesizer for Training
|
830 |
+
"""
|
831 |
+
|
832 |
+
def __init__(
|
833 |
+
self,
|
834 |
+
n_vocab: int,
|
835 |
+
spec_channels: int,
|
836 |
+
segment_size: int,
|
837 |
+
inter_channels: int,
|
838 |
+
hidden_channels: int,
|
839 |
+
filter_channels: int,
|
840 |
+
n_heads: int,
|
841 |
+
n_layers: int,
|
842 |
+
kernel_size: int,
|
843 |
+
p_dropout: float,
|
844 |
+
resblock: str,
|
845 |
+
resblock_kernel_sizes: list[int],
|
846 |
+
resblock_dilation_sizes: list[list[int]],
|
847 |
+
upsample_rates: list[int],
|
848 |
+
upsample_initial_channel: int,
|
849 |
+
upsample_kernel_sizes: list[int],
|
850 |
+
n_speakers: int = 256,
|
851 |
+
gin_channels: int = 256,
|
852 |
+
use_sdp: bool = True,
|
853 |
+
n_flow_layer: int = 4,
|
854 |
+
n_layers_trans_flow: int = 4,
|
855 |
+
flow_share_parameter: bool = False,
|
856 |
+
use_transformer_flow: bool = True,
|
857 |
+
**kwargs: Any,
|
858 |
+
) -> None:
|
859 |
+
super().__init__()
|
860 |
+
self.n_vocab = n_vocab
|
861 |
+
self.spec_channels = spec_channels
|
862 |
+
self.inter_channels = inter_channels
|
863 |
+
self.hidden_channels = hidden_channels
|
864 |
+
self.filter_channels = filter_channels
|
865 |
+
self.n_heads = n_heads
|
866 |
+
self.n_layers = n_layers
|
867 |
+
self.kernel_size = kernel_size
|
868 |
+
self.p_dropout = p_dropout
|
869 |
+
self.resblock = resblock
|
870 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
871 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
872 |
+
self.upsample_rates = upsample_rates
|
873 |
+
self.upsample_initial_channel = upsample_initial_channel
|
874 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
875 |
+
self.segment_size = segment_size
|
876 |
+
self.n_speakers = n_speakers
|
877 |
+
self.gin_channels = gin_channels
|
878 |
+
self.n_layers_trans_flow = n_layers_trans_flow
|
879 |
+
self.use_spk_conditioned_encoder = kwargs.get(
|
880 |
+
"use_spk_conditioned_encoder", True
|
881 |
+
)
|
882 |
+
self.use_sdp = use_sdp
|
883 |
+
self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
|
884 |
+
self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
|
885 |
+
self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
|
886 |
+
self.current_mas_noise_scale = self.mas_noise_scale_initial
|
887 |
+
if self.use_spk_conditioned_encoder and gin_channels > 0:
|
888 |
+
self.enc_gin_channels = gin_channels
|
889 |
+
self.enc_p = TextEncoder(
|
890 |
+
n_vocab,
|
891 |
+
inter_channels,
|
892 |
+
hidden_channels,
|
893 |
+
filter_channels,
|
894 |
+
n_heads,
|
895 |
+
n_layers,
|
896 |
+
kernel_size,
|
897 |
+
p_dropout,
|
898 |
+
self.n_speakers,
|
899 |
+
gin_channels=self.enc_gin_channels,
|
900 |
+
)
|
901 |
+
self.dec = Generator(
|
902 |
+
inter_channels,
|
903 |
+
resblock,
|
904 |
+
resblock_kernel_sizes,
|
905 |
+
resblock_dilation_sizes,
|
906 |
+
upsample_rates,
|
907 |
+
upsample_initial_channel,
|
908 |
+
upsample_kernel_sizes,
|
909 |
+
gin_channels=gin_channels,
|
910 |
+
)
|
911 |
+
self.enc_q = PosteriorEncoder(
|
912 |
+
spec_channels,
|
913 |
+
inter_channels,
|
914 |
+
hidden_channels,
|
915 |
+
5,
|
916 |
+
1,
|
917 |
+
16,
|
918 |
+
gin_channels=gin_channels,
|
919 |
+
)
|
920 |
+
if use_transformer_flow:
|
921 |
+
self.flow = TransformerCouplingBlock(
|
922 |
+
inter_channels,
|
923 |
+
hidden_channels,
|
924 |
+
filter_channels,
|
925 |
+
n_heads,
|
926 |
+
n_layers_trans_flow,
|
927 |
+
5,
|
928 |
+
p_dropout,
|
929 |
+
n_flow_layer,
|
930 |
+
gin_channels=gin_channels,
|
931 |
+
share_parameter=flow_share_parameter,
|
932 |
+
)
|
933 |
+
else:
|
934 |
+
self.flow = ResidualCouplingBlock(
|
935 |
+
inter_channels,
|
936 |
+
hidden_channels,
|
937 |
+
5,
|
938 |
+
1,
|
939 |
+
n_flow_layer,
|
940 |
+
gin_channels=gin_channels,
|
941 |
+
)
|
942 |
+
self.sdp = StochasticDurationPredictor(
|
943 |
+
hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
|
944 |
+
)
|
945 |
+
self.dp = DurationPredictor(
|
946 |
+
hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
|
947 |
+
)
|
948 |
+
|
949 |
+
if n_speakers >= 1:
|
950 |
+
self.emb_g = nn.Embedding(n_speakers, gin_channels)
|
951 |
+
else:
|
952 |
+
self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
|
953 |
+
|
954 |
+
def forward(
|
955 |
+
self,
|
956 |
+
x: torch.Tensor,
|
957 |
+
x_lengths: torch.Tensor,
|
958 |
+
y: torch.Tensor,
|
959 |
+
y_lengths: torch.Tensor,
|
960 |
+
sid: torch.Tensor,
|
961 |
+
tone: torch.Tensor,
|
962 |
+
language: torch.Tensor,
|
963 |
+
bert: torch.Tensor,
|
964 |
+
ja_bert: torch.Tensor,
|
965 |
+
en_bert: torch.Tensor,
|
966 |
+
style_vec: torch.Tensor,
|
967 |
+
) -> tuple[
|
968 |
+
torch.Tensor,
|
969 |
+
torch.Tensor,
|
970 |
+
torch.Tensor,
|
971 |
+
torch.Tensor,
|
972 |
+
torch.Tensor,
|
973 |
+
torch.Tensor,
|
974 |
+
tuple[torch.Tensor, ...],
|
975 |
+
tuple[torch.Tensor, ...],
|
976 |
+
]:
|
977 |
+
if self.n_speakers > 0:
|
978 |
+
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
979 |
+
else:
|
980 |
+
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
981 |
+
x, m_p, logs_p, x_mask = self.enc_p(
|
982 |
+
x, x_lengths, tone, language, bert, ja_bert, en_bert, style_vec, sid, g=g
|
983 |
+
)
|
984 |
+
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
985 |
+
z_p = self.flow(z, y_mask, g=g)
|
986 |
+
|
987 |
+
with torch.no_grad():
|
988 |
+
# negative cross-entropy
|
989 |
+
s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
|
990 |
+
neg_cent1 = torch.sum(
|
991 |
+
-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
|
992 |
+
) # [b, 1, t_s]
|
993 |
+
neg_cent2 = torch.matmul(
|
994 |
+
-0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
|
995 |
+
) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
|
996 |
+
neg_cent3 = torch.matmul(
|
997 |
+
z_p.transpose(1, 2), (m_p * s_p_sq_r)
|
998 |
+
) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
|
999 |
+
neg_cent4 = torch.sum(
|
1000 |
+
-0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
|
1001 |
+
) # [b, 1, t_s]
|
1002 |
+
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
|
1003 |
+
if self.use_noise_scaled_mas:
|
1004 |
+
epsilon = (
|
1005 |
+
torch.std(neg_cent)
|
1006 |
+
* torch.randn_like(neg_cent)
|
1007 |
+
* self.current_mas_noise_scale
|
1008 |
+
)
|
1009 |
+
neg_cent = neg_cent + epsilon
|
1010 |
+
|
1011 |
+
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
1012 |
+
attn = (
|
1013 |
+
monotonic_alignment.maximum_path(neg_cent, attn_mask.squeeze(1))
|
1014 |
+
.unsqueeze(1)
|
1015 |
+
.detach()
|
1016 |
+
)
|
1017 |
+
|
1018 |
+
w = attn.sum(2)
|
1019 |
+
|
1020 |
+
l_length_sdp = self.sdp(x, x_mask, w, g=g)
|
1021 |
+
l_length_sdp = l_length_sdp / torch.sum(x_mask)
|
1022 |
+
|
1023 |
+
logw_ = torch.log(w + 1e-6) * x_mask
|
1024 |
+
logw = self.dp(x, x_mask, g=g)
|
1025 |
+
# logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
|
1026 |
+
l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
|
1027 |
+
x_mask
|
1028 |
+
) # for averaging
|
1029 |
+
# l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
|
1030 |
+
|
1031 |
+
l_length = l_length_dp + l_length_sdp
|
1032 |
+
|
1033 |
+
# expand prior
|
1034 |
+
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
|
1035 |
+
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
|
1036 |
+
|
1037 |
+
z_slice, ids_slice = commons.rand_slice_segments(
|
1038 |
+
z, y_lengths, self.segment_size
|
1039 |
+
)
|
1040 |
+
o = self.dec(z_slice, g=g)
|
1041 |
+
return (
|
1042 |
+
o,
|
1043 |
+
l_length,
|
1044 |
+
attn,
|
1045 |
+
ids_slice,
|
1046 |
+
x_mask,
|
1047 |
+
y_mask,
|
1048 |
+
(z, z_p, m_p, logs_p, m_q, logs_q),
|
1049 |
+
(x, logw, logw_),
|
1050 |
+
)
|
1051 |
+
|
1052 |
+
def infer(
|
1053 |
+
self,
|
1054 |
+
x: torch.Tensor,
|
1055 |
+
x_lengths: torch.Tensor,
|
1056 |
+
sid: torch.Tensor,
|
1057 |
+
tone: torch.Tensor,
|
1058 |
+
language: torch.Tensor,
|
1059 |
+
bert: torch.Tensor,
|
1060 |
+
ja_bert: torch.Tensor,
|
1061 |
+
en_bert: torch.Tensor,
|
1062 |
+
style_vec: torch.Tensor,
|
1063 |
+
noise_scale: float = 0.667,
|
1064 |
+
length_scale: float = 1.0,
|
1065 |
+
noise_scale_w: float = 0.8,
|
1066 |
+
max_len: Optional[int] = None,
|
1067 |
+
sdp_ratio: float = 0.0,
|
1068 |
+
y: Optional[torch.Tensor] = None,
|
1069 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, ...]]:
|
1070 |
+
# x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
|
1071 |
+
# g = self.gst(y)
|
1072 |
+
if self.n_speakers > 0:
|
1073 |
+
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
1074 |
+
else:
|
1075 |
+
assert y is not None
|
1076 |
+
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
1077 |
+
x, m_p, logs_p, x_mask = self.enc_p(
|
1078 |
+
x, x_lengths, tone, language, bert, ja_bert, en_bert, style_vec, sid, g=g
|
1079 |
+
)
|
1080 |
+
logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
|
1081 |
+
sdp_ratio
|
1082 |
+
) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
|
1083 |
+
w = torch.exp(logw) * x_mask * length_scale
|
1084 |
+
w_ceil = torch.ceil(w)
|
1085 |
+
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
1086 |
+
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
|
1087 |
+
x_mask.dtype
|
1088 |
+
)
|
1089 |
+
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
1090 |
+
attn = commons.generate_path(w_ceil, attn_mask)
|
1091 |
+
|
1092 |
+
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
|
1093 |
+
1, 2
|
1094 |
+
) # [b, t', t], [b, t, d] -> [b, d, t']
|
1095 |
+
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
|
1096 |
+
1, 2
|
1097 |
+
) # [b, t', t], [b, t, d] -> [b, d, t']
|
1098 |
+
|
1099 |
+
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
1100 |
+
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
1101 |
+
o = self.dec((z * y_mask)[:, :, :max_len], g=g)
|
1102 |
+
return o, attn, y_mask, (z, z_p, m_p, logs_p)
|
style_bert_vits2/models/models_jp_extra.py
ADDED
@@ -0,0 +1,1157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Any, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
7 |
+
from torch.nn import functional as F
|
8 |
+
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
9 |
+
|
10 |
+
from style_bert_vits2.models import attentions, commons, modules, monotonic_alignment
|
11 |
+
from style_bert_vits2.nlp.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS
|
12 |
+
|
13 |
+
|
14 |
+
class DurationDiscriminator(nn.Module): # vits2
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
in_channels: int,
|
18 |
+
filter_channels: int,
|
19 |
+
kernel_size: int,
|
20 |
+
p_dropout: float,
|
21 |
+
gin_channels: int = 0,
|
22 |
+
) -> None:
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
self.in_channels = in_channels
|
26 |
+
self.filter_channels = filter_channels
|
27 |
+
self.kernel_size = kernel_size
|
28 |
+
self.p_dropout = p_dropout
|
29 |
+
self.gin_channels = gin_channels
|
30 |
+
|
31 |
+
self.drop = nn.Dropout(p_dropout)
|
32 |
+
self.conv_1 = nn.Conv1d(
|
33 |
+
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
34 |
+
)
|
35 |
+
self.norm_1 = modules.LayerNorm(filter_channels)
|
36 |
+
self.conv_2 = nn.Conv1d(
|
37 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
38 |
+
)
|
39 |
+
self.norm_2 = modules.LayerNorm(filter_channels)
|
40 |
+
self.dur_proj = nn.Conv1d(1, filter_channels, 1)
|
41 |
+
|
42 |
+
self.LSTM = nn.LSTM(
|
43 |
+
2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
|
44 |
+
)
|
45 |
+
|
46 |
+
if gin_channels != 0:
|
47 |
+
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
48 |
+
|
49 |
+
self.output_layer = nn.Sequential(
|
50 |
+
nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
|
51 |
+
)
|
52 |
+
|
53 |
+
def forward_probability(self, x: torch.Tensor, dur: torch.Tensor) -> torch.Tensor:
|
54 |
+
dur = self.dur_proj(dur)
|
55 |
+
x = torch.cat([x, dur], dim=1)
|
56 |
+
x = x.transpose(1, 2)
|
57 |
+
x, _ = self.LSTM(x)
|
58 |
+
output_prob = self.output_layer(x)
|
59 |
+
return output_prob
|
60 |
+
|
61 |
+
def forward(
|
62 |
+
self,
|
63 |
+
x: torch.Tensor,
|
64 |
+
x_mask: torch.Tensor,
|
65 |
+
dur_r: torch.Tensor,
|
66 |
+
dur_hat: torch.Tensor,
|
67 |
+
g: Optional[torch.Tensor] = None,
|
68 |
+
) -> list[torch.Tensor]:
|
69 |
+
x = torch.detach(x)
|
70 |
+
if g is not None:
|
71 |
+
g = torch.detach(g)
|
72 |
+
x = x + self.cond(g)
|
73 |
+
x = self.conv_1(x * x_mask)
|
74 |
+
x = torch.relu(x)
|
75 |
+
x = self.norm_1(x)
|
76 |
+
x = self.drop(x)
|
77 |
+
x = self.conv_2(x * x_mask)
|
78 |
+
x = torch.relu(x)
|
79 |
+
x = self.norm_2(x)
|
80 |
+
x = self.drop(x)
|
81 |
+
|
82 |
+
output_probs = []
|
83 |
+
for dur in [dur_r, dur_hat]:
|
84 |
+
output_prob = self.forward_probability(x, dur)
|
85 |
+
output_probs.append(output_prob)
|
86 |
+
|
87 |
+
return output_probs
|
88 |
+
|
89 |
+
|
90 |
+
class TransformerCouplingBlock(nn.Module):
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
channels: int,
|
94 |
+
hidden_channels: int,
|
95 |
+
filter_channels: int,
|
96 |
+
n_heads: int,
|
97 |
+
n_layers: int,
|
98 |
+
kernel_size: int,
|
99 |
+
p_dropout: float,
|
100 |
+
n_flows: int = 4,
|
101 |
+
gin_channels: int = 0,
|
102 |
+
share_parameter: bool = False,
|
103 |
+
) -> None:
|
104 |
+
super().__init__()
|
105 |
+
self.channels = channels
|
106 |
+
self.hidden_channels = hidden_channels
|
107 |
+
self.kernel_size = kernel_size
|
108 |
+
self.n_layers = n_layers
|
109 |
+
self.n_flows = n_flows
|
110 |
+
self.gin_channels = gin_channels
|
111 |
+
|
112 |
+
self.flows = nn.ModuleList()
|
113 |
+
|
114 |
+
self.wn = (
|
115 |
+
# attentions.FFT(
|
116 |
+
# hidden_channels,
|
117 |
+
# filter_channels,
|
118 |
+
# n_heads,
|
119 |
+
# n_layers,
|
120 |
+
# kernel_size,
|
121 |
+
# p_dropout,
|
122 |
+
# isflow=True,
|
123 |
+
# gin_channels=self.gin_channels,
|
124 |
+
# )
|
125 |
+
None
|
126 |
+
if share_parameter
|
127 |
+
else None
|
128 |
+
)
|
129 |
+
|
130 |
+
for i in range(n_flows):
|
131 |
+
self.flows.append(
|
132 |
+
modules.TransformerCouplingLayer(
|
133 |
+
channels,
|
134 |
+
hidden_channels,
|
135 |
+
kernel_size,
|
136 |
+
n_layers,
|
137 |
+
n_heads,
|
138 |
+
p_dropout,
|
139 |
+
filter_channels,
|
140 |
+
mean_only=True,
|
141 |
+
wn_sharing_parameter=self.wn,
|
142 |
+
gin_channels=self.gin_channels,
|
143 |
+
)
|
144 |
+
)
|
145 |
+
self.flows.append(modules.Flip())
|
146 |
+
|
147 |
+
def forward(
|
148 |
+
self,
|
149 |
+
x: torch.Tensor,
|
150 |
+
x_mask: torch.Tensor,
|
151 |
+
g: Optional[torch.Tensor] = None,
|
152 |
+
reverse: bool = False,
|
153 |
+
) -> torch.Tensor:
|
154 |
+
if not reverse:
|
155 |
+
for flow in self.flows:
|
156 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
157 |
+
else:
|
158 |
+
for flow in reversed(self.flows):
|
159 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
160 |
+
return x
|
161 |
+
|
162 |
+
|
163 |
+
class StochasticDurationPredictor(nn.Module):
|
164 |
+
def __init__(
|
165 |
+
self,
|
166 |
+
in_channels: int,
|
167 |
+
filter_channels: int,
|
168 |
+
kernel_size: int,
|
169 |
+
p_dropout: float,
|
170 |
+
n_flows: int = 4,
|
171 |
+
gin_channels: int = 0,
|
172 |
+
) -> None:
|
173 |
+
super().__init__()
|
174 |
+
filter_channels = in_channels # it needs to be removed from future version.
|
175 |
+
self.in_channels = in_channels
|
176 |
+
self.filter_channels = filter_channels
|
177 |
+
self.kernel_size = kernel_size
|
178 |
+
self.p_dropout = p_dropout
|
179 |
+
self.n_flows = n_flows
|
180 |
+
self.gin_channels = gin_channels
|
181 |
+
|
182 |
+
self.log_flow = modules.Log()
|
183 |
+
self.flows = nn.ModuleList()
|
184 |
+
self.flows.append(modules.ElementwiseAffine(2))
|
185 |
+
for i in range(n_flows):
|
186 |
+
self.flows.append(
|
187 |
+
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
188 |
+
)
|
189 |
+
self.flows.append(modules.Flip())
|
190 |
+
|
191 |
+
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
192 |
+
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
193 |
+
self.post_convs = modules.DDSConv(
|
194 |
+
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
195 |
+
)
|
196 |
+
self.post_flows = nn.ModuleList()
|
197 |
+
self.post_flows.append(modules.ElementwiseAffine(2))
|
198 |
+
for i in range(4):
|
199 |
+
self.post_flows.append(
|
200 |
+
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
201 |
+
)
|
202 |
+
self.post_flows.append(modules.Flip())
|
203 |
+
|
204 |
+
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
205 |
+
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
206 |
+
self.convs = modules.DDSConv(
|
207 |
+
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
208 |
+
)
|
209 |
+
if gin_channels != 0:
|
210 |
+
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
211 |
+
|
212 |
+
def forward(
|
213 |
+
self,
|
214 |
+
x: torch.Tensor,
|
215 |
+
x_mask: torch.Tensor,
|
216 |
+
w: Optional[torch.Tensor] = None,
|
217 |
+
g: Optional[torch.Tensor] = None,
|
218 |
+
reverse: bool = False,
|
219 |
+
noise_scale: float = 1.0,
|
220 |
+
) -> torch.Tensor:
|
221 |
+
x = torch.detach(x)
|
222 |
+
x = self.pre(x)
|
223 |
+
if g is not None:
|
224 |
+
g = torch.detach(g)
|
225 |
+
x = x + self.cond(g)
|
226 |
+
x = self.convs(x, x_mask)
|
227 |
+
x = self.proj(x) * x_mask
|
228 |
+
|
229 |
+
if not reverse:
|
230 |
+
flows = self.flows
|
231 |
+
assert w is not None
|
232 |
+
|
233 |
+
logdet_tot_q = 0
|
234 |
+
h_w = self.post_pre(w)
|
235 |
+
h_w = self.post_convs(h_w, x_mask)
|
236 |
+
h_w = self.post_proj(h_w) * x_mask
|
237 |
+
e_q = (
|
238 |
+
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
|
239 |
+
* x_mask
|
240 |
+
)
|
241 |
+
z_q = e_q
|
242 |
+
for flow in self.post_flows:
|
243 |
+
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
244 |
+
logdet_tot_q += logdet_q
|
245 |
+
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
246 |
+
u = torch.sigmoid(z_u) * x_mask
|
247 |
+
z0 = (w - u) * x_mask
|
248 |
+
logdet_tot_q += torch.sum(
|
249 |
+
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
|
250 |
+
)
|
251 |
+
logq = (
|
252 |
+
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
|
253 |
+
- logdet_tot_q
|
254 |
+
)
|
255 |
+
|
256 |
+
logdet_tot = 0
|
257 |
+
z0, logdet = self.log_flow(z0, x_mask)
|
258 |
+
logdet_tot += logdet
|
259 |
+
z = torch.cat([z0, z1], 1)
|
260 |
+
for flow in flows:
|
261 |
+
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
262 |
+
logdet_tot = logdet_tot + logdet
|
263 |
+
nll = (
|
264 |
+
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
|
265 |
+
- logdet_tot
|
266 |
+
)
|
267 |
+
return nll + logq # [b]
|
268 |
+
else:
|
269 |
+
flows = list(reversed(self.flows))
|
270 |
+
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
271 |
+
z = (
|
272 |
+
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
|
273 |
+
* noise_scale
|
274 |
+
)
|
275 |
+
for flow in flows:
|
276 |
+
z = flow(z, x_mask, g=x, reverse=reverse)
|
277 |
+
z0, z1 = torch.split(z, [1, 1], 1)
|
278 |
+
logw = z0
|
279 |
+
return logw
|
280 |
+
|
281 |
+
|
282 |
+
class DurationPredictor(nn.Module):
|
283 |
+
def __init__(
|
284 |
+
self,
|
285 |
+
in_channels: int,
|
286 |
+
filter_channels: int,
|
287 |
+
kernel_size: int,
|
288 |
+
p_dropout: float,
|
289 |
+
gin_channels: int = 0,
|
290 |
+
) -> None:
|
291 |
+
super().__init__()
|
292 |
+
|
293 |
+
self.in_channels = in_channels
|
294 |
+
self.filter_channels = filter_channels
|
295 |
+
self.kernel_size = kernel_size
|
296 |
+
self.p_dropout = p_dropout
|
297 |
+
self.gin_channels = gin_channels
|
298 |
+
|
299 |
+
self.drop = nn.Dropout(p_dropout)
|
300 |
+
self.conv_1 = nn.Conv1d(
|
301 |
+
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
302 |
+
)
|
303 |
+
self.norm_1 = modules.LayerNorm(filter_channels)
|
304 |
+
self.conv_2 = nn.Conv1d(
|
305 |
+
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
306 |
+
)
|
307 |
+
self.norm_2 = modules.LayerNorm(filter_channels)
|
308 |
+
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
309 |
+
|
310 |
+
if gin_channels != 0:
|
311 |
+
self.cond = nn.Conv1d(gin_channels, in_channels, 1)
|
312 |
+
|
313 |
+
def forward(
|
314 |
+
self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
|
315 |
+
) -> torch.Tensor:
|
316 |
+
x = torch.detach(x)
|
317 |
+
if g is not None:
|
318 |
+
g = torch.detach(g)
|
319 |
+
x = x + self.cond(g)
|
320 |
+
x = self.conv_1(x * x_mask)
|
321 |
+
x = torch.relu(x)
|
322 |
+
x = self.norm_1(x)
|
323 |
+
x = self.drop(x)
|
324 |
+
x = self.conv_2(x * x_mask)
|
325 |
+
x = torch.relu(x)
|
326 |
+
x = self.norm_2(x)
|
327 |
+
x = self.drop(x)
|
328 |
+
x = self.proj(x * x_mask)
|
329 |
+
return x * x_mask
|
330 |
+
|
331 |
+
|
332 |
+
class Bottleneck(nn.Sequential):
|
333 |
+
def __init__(self, in_dim: int, hidden_dim: int) -> None:
|
334 |
+
c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
|
335 |
+
c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
|
336 |
+
super().__init__(c_fc1, c_fc2)
|
337 |
+
|
338 |
+
|
339 |
+
class Block(nn.Module):
|
340 |
+
def __init__(self, in_dim: int, hidden_dim: int) -> None:
|
341 |
+
super().__init__()
|
342 |
+
self.norm = nn.LayerNorm(in_dim)
|
343 |
+
self.mlp = MLP(in_dim, hidden_dim)
|
344 |
+
|
345 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
346 |
+
x = x + self.mlp(self.norm(x))
|
347 |
+
return x
|
348 |
+
|
349 |
+
|
350 |
+
class MLP(nn.Module):
|
351 |
+
def __init__(self, in_dim: int, hidden_dim: int) -> None:
|
352 |
+
super().__init__()
|
353 |
+
self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
|
354 |
+
self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
|
355 |
+
self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
|
356 |
+
|
357 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
358 |
+
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
|
359 |
+
x = self.c_proj(x)
|
360 |
+
return x
|
361 |
+
|
362 |
+
|
363 |
+
class TextEncoder(nn.Module):
|
364 |
+
def __init__(
|
365 |
+
self,
|
366 |
+
n_vocab: int,
|
367 |
+
out_channels: int,
|
368 |
+
hidden_channels: int,
|
369 |
+
filter_channels: int,
|
370 |
+
n_heads: int,
|
371 |
+
n_layers: int,
|
372 |
+
kernel_size: int,
|
373 |
+
p_dropout: float,
|
374 |
+
gin_channels: int = 0,
|
375 |
+
) -> None:
|
376 |
+
super().__init__()
|
377 |
+
self.n_vocab = n_vocab
|
378 |
+
self.out_channels = out_channels
|
379 |
+
self.hidden_channels = hidden_channels
|
380 |
+
self.filter_channels = filter_channels
|
381 |
+
self.n_heads = n_heads
|
382 |
+
self.n_layers = n_layers
|
383 |
+
self.kernel_size = kernel_size
|
384 |
+
self.p_dropout = p_dropout
|
385 |
+
self.gin_channels = gin_channels
|
386 |
+
self.emb = nn.Embedding(len(SYMBOLS), hidden_channels)
|
387 |
+
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
|
388 |
+
self.tone_emb = nn.Embedding(NUM_TONES, hidden_channels)
|
389 |
+
nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
|
390 |
+
self.language_emb = nn.Embedding(NUM_LANGUAGES, hidden_channels)
|
391 |
+
nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
|
392 |
+
self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
|
393 |
+
|
394 |
+
# Remove emo_vq since it's not working well.
|
395 |
+
self.style_proj = nn.Linear(256, hidden_channels)
|
396 |
+
|
397 |
+
self.encoder = attentions.Encoder(
|
398 |
+
hidden_channels,
|
399 |
+
filter_channels,
|
400 |
+
n_heads,
|
401 |
+
n_layers,
|
402 |
+
kernel_size,
|
403 |
+
p_dropout,
|
404 |
+
gin_channels=self.gin_channels,
|
405 |
+
)
|
406 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
407 |
+
|
408 |
+
def forward(
|
409 |
+
self,
|
410 |
+
x: torch.Tensor,
|
411 |
+
x_lengths: torch.Tensor,
|
412 |
+
tone: torch.Tensor,
|
413 |
+
language: torch.Tensor,
|
414 |
+
bert: torch.Tensor,
|
415 |
+
style_vec: torch.Tensor,
|
416 |
+
g: Optional[torch.Tensor] = None,
|
417 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
418 |
+
bert_emb = self.bert_proj(bert).transpose(1, 2)
|
419 |
+
style_emb = self.style_proj(style_vec.unsqueeze(1))
|
420 |
+
x = (
|
421 |
+
self.emb(x)
|
422 |
+
+ self.tone_emb(tone)
|
423 |
+
+ self.language_emb(language)
|
424 |
+
+ bert_emb
|
425 |
+
+ style_emb
|
426 |
+
) * math.sqrt(
|
427 |
+
self.hidden_channels
|
428 |
+
) # [b, t, h]
|
429 |
+
x = torch.transpose(x, 1, -1) # [b, h, t]
|
430 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
431 |
+
x.dtype
|
432 |
+
)
|
433 |
+
|
434 |
+
x = self.encoder(x * x_mask, x_mask, g=g)
|
435 |
+
stats = self.proj(x) * x_mask
|
436 |
+
|
437 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
438 |
+
return x, m, logs, x_mask
|
439 |
+
|
440 |
+
|
441 |
+
class ResidualCouplingBlock(nn.Module):
|
442 |
+
def __init__(
|
443 |
+
self,
|
444 |
+
channels: int,
|
445 |
+
hidden_channels: int,
|
446 |
+
kernel_size: int,
|
447 |
+
dilation_rate: int,
|
448 |
+
n_layers: int,
|
449 |
+
n_flows: int = 4,
|
450 |
+
gin_channels: int = 0,
|
451 |
+
) -> None:
|
452 |
+
super().__init__()
|
453 |
+
self.channels = channels
|
454 |
+
self.hidden_channels = hidden_channels
|
455 |
+
self.kernel_size = kernel_size
|
456 |
+
self.dilation_rate = dilation_rate
|
457 |
+
self.n_layers = n_layers
|
458 |
+
self.n_flows = n_flows
|
459 |
+
self.gin_channels = gin_channels
|
460 |
+
|
461 |
+
self.flows = nn.ModuleList()
|
462 |
+
for i in range(n_flows):
|
463 |
+
self.flows.append(
|
464 |
+
modules.ResidualCouplingLayer(
|
465 |
+
channels,
|
466 |
+
hidden_channels,
|
467 |
+
kernel_size,
|
468 |
+
dilation_rate,
|
469 |
+
n_layers,
|
470 |
+
gin_channels=gin_channels,
|
471 |
+
mean_only=True,
|
472 |
+
)
|
473 |
+
)
|
474 |
+
self.flows.append(modules.Flip())
|
475 |
+
|
476 |
+
def forward(
|
477 |
+
self,
|
478 |
+
x: torch.Tensor,
|
479 |
+
x_mask: torch.Tensor,
|
480 |
+
g: Optional[torch.Tensor] = None,
|
481 |
+
reverse: bool = False,
|
482 |
+
) -> torch.Tensor:
|
483 |
+
if not reverse:
|
484 |
+
for flow in self.flows:
|
485 |
+
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
486 |
+
else:
|
487 |
+
for flow in reversed(self.flows):
|
488 |
+
x = flow(x, x_mask, g=g, reverse=reverse)
|
489 |
+
return x
|
490 |
+
|
491 |
+
|
492 |
+
class PosteriorEncoder(nn.Module):
|
493 |
+
def __init__(
|
494 |
+
self,
|
495 |
+
in_channels: int,
|
496 |
+
out_channels: int,
|
497 |
+
hidden_channels: int,
|
498 |
+
kernel_size: int,
|
499 |
+
dilation_rate: int,
|
500 |
+
n_layers: int,
|
501 |
+
gin_channels: int = 0,
|
502 |
+
) -> None:
|
503 |
+
super().__init__()
|
504 |
+
self.in_channels = in_channels
|
505 |
+
self.out_channels = out_channels
|
506 |
+
self.hidden_channels = hidden_channels
|
507 |
+
self.kernel_size = kernel_size
|
508 |
+
self.dilation_rate = dilation_rate
|
509 |
+
self.n_layers = n_layers
|
510 |
+
self.gin_channels = gin_channels
|
511 |
+
|
512 |
+
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
513 |
+
self.enc = modules.WN(
|
514 |
+
hidden_channels,
|
515 |
+
kernel_size,
|
516 |
+
dilation_rate,
|
517 |
+
n_layers,
|
518 |
+
gin_channels=gin_channels,
|
519 |
+
)
|
520 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
521 |
+
|
522 |
+
def forward(
|
523 |
+
self,
|
524 |
+
x: torch.Tensor,
|
525 |
+
x_lengths: torch.Tensor,
|
526 |
+
g: Optional[torch.Tensor] = None,
|
527 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
528 |
+
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
529 |
+
x.dtype
|
530 |
+
)
|
531 |
+
x = self.pre(x) * x_mask
|
532 |
+
x = self.enc(x, x_mask, g=g)
|
533 |
+
stats = self.proj(x) * x_mask
|
534 |
+
m, logs = torch.split(stats, self.out_channels, dim=1)
|
535 |
+
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
536 |
+
return z, m, logs, x_mask
|
537 |
+
|
538 |
+
|
539 |
+
class Generator(torch.nn.Module):
|
540 |
+
def __init__(
|
541 |
+
self,
|
542 |
+
initial_channel: int,
|
543 |
+
resblock_str: str,
|
544 |
+
resblock_kernel_sizes: list[int],
|
545 |
+
resblock_dilation_sizes: list[list[int]],
|
546 |
+
upsample_rates: list[int],
|
547 |
+
upsample_initial_channel: int,
|
548 |
+
upsample_kernel_sizes: list[int],
|
549 |
+
gin_channels: int = 0,
|
550 |
+
) -> None:
|
551 |
+
super(Generator, self).__init__()
|
552 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
553 |
+
self.num_upsamples = len(upsample_rates)
|
554 |
+
self.conv_pre = Conv1d(
|
555 |
+
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
556 |
+
)
|
557 |
+
resblock = modules.ResBlock1 if resblock_str == "1" else modules.ResBlock2
|
558 |
+
|
559 |
+
self.ups = nn.ModuleList()
|
560 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
561 |
+
self.ups.append(
|
562 |
+
weight_norm(
|
563 |
+
ConvTranspose1d(
|
564 |
+
upsample_initial_channel // (2**i),
|
565 |
+
upsample_initial_channel // (2 ** (i + 1)),
|
566 |
+
k,
|
567 |
+
u,
|
568 |
+
padding=(k - u) // 2,
|
569 |
+
)
|
570 |
+
)
|
571 |
+
)
|
572 |
+
|
573 |
+
self.resblocks = nn.ModuleList()
|
574 |
+
ch = None
|
575 |
+
for i in range(len(self.ups)):
|
576 |
+
ch = upsample_initial_channel // (2 ** (i + 1))
|
577 |
+
for j, (k, d) in enumerate(
|
578 |
+
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
579 |
+
):
|
580 |
+
self.resblocks.append(resblock(ch, k, d)) # type: ignore
|
581 |
+
|
582 |
+
assert ch is not None
|
583 |
+
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
584 |
+
self.ups.apply(commons.init_weights)
|
585 |
+
|
586 |
+
if gin_channels != 0:
|
587 |
+
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
588 |
+
|
589 |
+
def forward(
|
590 |
+
self, x: torch.Tensor, g: Optional[torch.Tensor] = None
|
591 |
+
) -> torch.Tensor:
|
592 |
+
x = self.conv_pre(x)
|
593 |
+
if g is not None:
|
594 |
+
x = x + self.cond(g)
|
595 |
+
|
596 |
+
for i in range(self.num_upsamples):
|
597 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
598 |
+
x = self.ups[i](x)
|
599 |
+
xs = None
|
600 |
+
for j in range(self.num_kernels):
|
601 |
+
if xs is None:
|
602 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
603 |
+
else:
|
604 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
605 |
+
assert xs is not None
|
606 |
+
x = xs / self.num_kernels
|
607 |
+
x = F.leaky_relu(x)
|
608 |
+
x = self.conv_post(x)
|
609 |
+
x = torch.tanh(x)
|
610 |
+
|
611 |
+
return x
|
612 |
+
|
613 |
+
def remove_weight_norm(self) -> None:
|
614 |
+
print("Removing weight norm...")
|
615 |
+
for layer in self.ups:
|
616 |
+
remove_weight_norm(layer)
|
617 |
+
for layer in self.resblocks:
|
618 |
+
layer.remove_weight_norm()
|
619 |
+
|
620 |
+
|
621 |
+
class DiscriminatorP(torch.nn.Module):
|
622 |
+
def __init__(
|
623 |
+
self,
|
624 |
+
period: int,
|
625 |
+
kernel_size: int = 5,
|
626 |
+
stride: int = 3,
|
627 |
+
use_spectral_norm: bool = False,
|
628 |
+
) -> None:
|
629 |
+
super(DiscriminatorP, self).__init__()
|
630 |
+
self.period = period
|
631 |
+
self.use_spectral_norm = use_spectral_norm
|
632 |
+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
633 |
+
self.convs = nn.ModuleList(
|
634 |
+
[
|
635 |
+
norm_f(
|
636 |
+
Conv2d(
|
637 |
+
1,
|
638 |
+
32,
|
639 |
+
(kernel_size, 1),
|
640 |
+
(stride, 1),
|
641 |
+
padding=(commons.get_padding(kernel_size, 1), 0),
|
642 |
+
)
|
643 |
+
),
|
644 |
+
norm_f(
|
645 |
+
Conv2d(
|
646 |
+
32,
|
647 |
+
128,
|
648 |
+
(kernel_size, 1),
|
649 |
+
(stride, 1),
|
650 |
+
padding=(commons.get_padding(kernel_size, 1), 0),
|
651 |
+
)
|
652 |
+
),
|
653 |
+
norm_f(
|
654 |
+
Conv2d(
|
655 |
+
128,
|
656 |
+
512,
|
657 |
+
(kernel_size, 1),
|
658 |
+
(stride, 1),
|
659 |
+
padding=(commons.get_padding(kernel_size, 1), 0),
|
660 |
+
)
|
661 |
+
),
|
662 |
+
norm_f(
|
663 |
+
Conv2d(
|
664 |
+
512,
|
665 |
+
1024,
|
666 |
+
(kernel_size, 1),
|
667 |
+
(stride, 1),
|
668 |
+
padding=(commons.get_padding(kernel_size, 1), 0),
|
669 |
+
)
|
670 |
+
),
|
671 |
+
norm_f(
|
672 |
+
Conv2d(
|
673 |
+
1024,
|
674 |
+
1024,
|
675 |
+
(kernel_size, 1),
|
676 |
+
1,
|
677 |
+
padding=(commons.get_padding(kernel_size, 1), 0),
|
678 |
+
)
|
679 |
+
),
|
680 |
+
]
|
681 |
+
)
|
682 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
683 |
+
|
684 |
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
685 |
+
fmap = []
|
686 |
+
|
687 |
+
# 1d to 2d
|
688 |
+
b, c, t = x.shape
|
689 |
+
if t % self.period != 0: # pad first
|
690 |
+
n_pad = self.period - (t % self.period)
|
691 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
692 |
+
t = t + n_pad
|
693 |
+
x = x.view(b, c, t // self.period, self.period)
|
694 |
+
|
695 |
+
for layer in self.convs:
|
696 |
+
x = layer(x)
|
697 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
698 |
+
fmap.append(x)
|
699 |
+
x = self.conv_post(x)
|
700 |
+
fmap.append(x)
|
701 |
+
x = torch.flatten(x, 1, -1)
|
702 |
+
|
703 |
+
return x, fmap
|
704 |
+
|
705 |
+
|
706 |
+
class DiscriminatorS(torch.nn.Module):
|
707 |
+
def __init__(self, use_spectral_norm: bool = False) -> None:
|
708 |
+
super(DiscriminatorS, self).__init__()
|
709 |
+
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
710 |
+
self.convs = nn.ModuleList(
|
711 |
+
[
|
712 |
+
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
713 |
+
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
714 |
+
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
715 |
+
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
716 |
+
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
717 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
718 |
+
]
|
719 |
+
)
|
720 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
721 |
+
|
722 |
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
|
723 |
+
fmap = []
|
724 |
+
|
725 |
+
for layer in self.convs:
|
726 |
+
x = layer(x)
|
727 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
728 |
+
fmap.append(x)
|
729 |
+
x = self.conv_post(x)
|
730 |
+
fmap.append(x)
|
731 |
+
x = torch.flatten(x, 1, -1)
|
732 |
+
|
733 |
+
return x, fmap
|
734 |
+
|
735 |
+
|
736 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
737 |
+
def __init__(self, use_spectral_norm: bool = False) -> None:
|
738 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
739 |
+
periods = [2, 3, 5, 7, 11]
|
740 |
+
|
741 |
+
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
742 |
+
discs = discs + [
|
743 |
+
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
744 |
+
]
|
745 |
+
self.discriminators = nn.ModuleList(discs)
|
746 |
+
|
747 |
+
def forward(
|
748 |
+
self,
|
749 |
+
y: torch.Tensor,
|
750 |
+
y_hat: torch.Tensor,
|
751 |
+
) -> tuple[
|
752 |
+
list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]
|
753 |
+
]:
|
754 |
+
y_d_rs = []
|
755 |
+
y_d_gs = []
|
756 |
+
fmap_rs = []
|
757 |
+
fmap_gs = []
|
758 |
+
for i, d in enumerate(self.discriminators):
|
759 |
+
y_d_r, fmap_r = d(y)
|
760 |
+
y_d_g, fmap_g = d(y_hat)
|
761 |
+
y_d_rs.append(y_d_r)
|
762 |
+
y_d_gs.append(y_d_g)
|
763 |
+
fmap_rs.append(fmap_r)
|
764 |
+
fmap_gs.append(fmap_g)
|
765 |
+
|
766 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
767 |
+
|
768 |
+
|
769 |
+
class WavLMDiscriminator(nn.Module):
|
770 |
+
"""docstring for Discriminator."""
|
771 |
+
|
772 |
+
def __init__(
|
773 |
+
self,
|
774 |
+
slm_hidden: int = 768,
|
775 |
+
slm_layers: int = 13,
|
776 |
+
initial_channel: int = 64,
|
777 |
+
use_spectral_norm: bool = False,
|
778 |
+
) -> None:
|
779 |
+
super(WavLMDiscriminator, self).__init__()
|
780 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
781 |
+
self.pre = norm_f(
|
782 |
+
Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
|
783 |
+
)
|
784 |
+
|
785 |
+
self.convs = nn.ModuleList(
|
786 |
+
[
|
787 |
+
norm_f(
|
788 |
+
nn.Conv1d(
|
789 |
+
initial_channel, initial_channel * 2, kernel_size=5, padding=2
|
790 |
+
)
|
791 |
+
),
|
792 |
+
norm_f(
|
793 |
+
nn.Conv1d(
|
794 |
+
initial_channel * 2,
|
795 |
+
initial_channel * 4,
|
796 |
+
kernel_size=5,
|
797 |
+
padding=2,
|
798 |
+
)
|
799 |
+
),
|
800 |
+
norm_f(
|
801 |
+
nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
|
802 |
+
),
|
803 |
+
]
|
804 |
+
)
|
805 |
+
|
806 |
+
self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
|
807 |
+
|
808 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
809 |
+
x = self.pre(x)
|
810 |
+
|
811 |
+
fmap = []
|
812 |
+
for l in self.convs:
|
813 |
+
x = l(x)
|
814 |
+
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
815 |
+
fmap.append(x)
|
816 |
+
x = self.conv_post(x)
|
817 |
+
x = torch.flatten(x, 1, -1)
|
818 |
+
|
819 |
+
return x
|
820 |
+
|
821 |
+
|
822 |
+
class ReferenceEncoder(nn.Module):
|
823 |
+
"""
|
824 |
+
inputs --- [N, Ty/r, n_mels*r] mels
|
825 |
+
outputs --- [N, ref_enc_gru_size]
|
826 |
+
"""
|
827 |
+
|
828 |
+
def __init__(self, spec_channels: int, gin_channels: int = 0) -> None:
|
829 |
+
super().__init__()
|
830 |
+
self.spec_channels = spec_channels
|
831 |
+
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
832 |
+
K = len(ref_enc_filters)
|
833 |
+
filters = [1] + ref_enc_filters
|
834 |
+
convs = [
|
835 |
+
weight_norm(
|
836 |
+
nn.Conv2d(
|
837 |
+
in_channels=filters[i],
|
838 |
+
out_channels=filters[i + 1],
|
839 |
+
kernel_size=(3, 3),
|
840 |
+
stride=(2, 2),
|
841 |
+
padding=(1, 1),
|
842 |
+
)
|
843 |
+
)
|
844 |
+
for i in range(K)
|
845 |
+
]
|
846 |
+
self.convs = nn.ModuleList(convs)
|
847 |
+
# self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
|
848 |
+
|
849 |
+
out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
|
850 |
+
self.gru = nn.GRU(
|
851 |
+
input_size=ref_enc_filters[-1] * out_channels,
|
852 |
+
hidden_size=256 // 2,
|
853 |
+
batch_first=True,
|
854 |
+
)
|
855 |
+
self.proj = nn.Linear(128, gin_channels)
|
856 |
+
|
857 |
+
def forward(
|
858 |
+
self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
|
859 |
+
) -> torch.Tensor:
|
860 |
+
N = inputs.size(0)
|
861 |
+
out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
|
862 |
+
for conv in self.convs:
|
863 |
+
out = conv(out)
|
864 |
+
# out = wn(out)
|
865 |
+
out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
|
866 |
+
|
867 |
+
out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
|
868 |
+
T = out.size(1)
|
869 |
+
N = out.size(0)
|
870 |
+
out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
|
871 |
+
|
872 |
+
self.gru.flatten_parameters()
|
873 |
+
memory, out = self.gru(out) # out --- [1, N, 128]
|
874 |
+
|
875 |
+
return self.proj(out.squeeze(0))
|
876 |
+
|
877 |
+
def calculate_channels(
|
878 |
+
self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int
|
879 |
+
) -> int:
|
880 |
+
for i in range(n_convs):
|
881 |
+
L = (L - kernel_size + 2 * pad) // stride + 1
|
882 |
+
return L
|
883 |
+
|
884 |
+
|
885 |
+
class SynthesizerTrn(nn.Module):
|
886 |
+
"""
|
887 |
+
Synthesizer for Training
|
888 |
+
"""
|
889 |
+
|
890 |
+
def __init__(
|
891 |
+
self,
|
892 |
+
n_vocab: int,
|
893 |
+
spec_channels: int,
|
894 |
+
segment_size: int,
|
895 |
+
inter_channels: int,
|
896 |
+
hidden_channels: int,
|
897 |
+
filter_channels: int,
|
898 |
+
n_heads: int,
|
899 |
+
n_layers: int,
|
900 |
+
kernel_size: int,
|
901 |
+
p_dropout: float,
|
902 |
+
resblock: str,
|
903 |
+
resblock_kernel_sizes: list[int],
|
904 |
+
resblock_dilation_sizes: list[list[int]],
|
905 |
+
upsample_rates: list[int],
|
906 |
+
upsample_initial_channel: int,
|
907 |
+
upsample_kernel_sizes: list[int],
|
908 |
+
n_speakers: int = 256,
|
909 |
+
gin_channels: int = 256,
|
910 |
+
use_sdp: bool = True,
|
911 |
+
n_flow_layer: int = 4,
|
912 |
+
n_layers_trans_flow: int = 6,
|
913 |
+
flow_share_parameter: bool = False,
|
914 |
+
use_transformer_flow: bool = True,
|
915 |
+
**kwargs: Any,
|
916 |
+
) -> None:
|
917 |
+
super().__init__()
|
918 |
+
self.n_vocab = n_vocab
|
919 |
+
self.spec_channels = spec_channels
|
920 |
+
self.inter_channels = inter_channels
|
921 |
+
self.hidden_channels = hidden_channels
|
922 |
+
self.filter_channels = filter_channels
|
923 |
+
self.n_heads = n_heads
|
924 |
+
self.n_layers = n_layers
|
925 |
+
self.kernel_size = kernel_size
|
926 |
+
self.p_dropout = p_dropout
|
927 |
+
self.resblock = resblock
|
928 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
929 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
930 |
+
self.upsample_rates = upsample_rates
|
931 |
+
self.upsample_initial_channel = upsample_initial_channel
|
932 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
933 |
+
self.segment_size = segment_size
|
934 |
+
self.n_speakers = n_speakers
|
935 |
+
self.gin_channels = gin_channels
|
936 |
+
self.n_layers_trans_flow = n_layers_trans_flow
|
937 |
+
self.use_spk_conditioned_encoder = kwargs.get(
|
938 |
+
"use_spk_conditioned_encoder", True
|
939 |
+
)
|
940 |
+
self.use_sdp = use_sdp
|
941 |
+
self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
|
942 |
+
self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
|
943 |
+
self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
|
944 |
+
self.current_mas_noise_scale = self.mas_noise_scale_initial
|
945 |
+
if self.use_spk_conditioned_encoder and gin_channels > 0:
|
946 |
+
self.enc_gin_channels = gin_channels
|
947 |
+
self.enc_p = TextEncoder(
|
948 |
+
n_vocab,
|
949 |
+
inter_channels,
|
950 |
+
hidden_channels,
|
951 |
+
filter_channels,
|
952 |
+
n_heads,
|
953 |
+
n_layers,
|
954 |
+
kernel_size,
|
955 |
+
p_dropout,
|
956 |
+
gin_channels=self.enc_gin_channels,
|
957 |
+
)
|
958 |
+
self.dec = Generator(
|
959 |
+
inter_channels,
|
960 |
+
resblock,
|
961 |
+
resblock_kernel_sizes,
|
962 |
+
resblock_dilation_sizes,
|
963 |
+
upsample_rates,
|
964 |
+
upsample_initial_channel,
|
965 |
+
upsample_kernel_sizes,
|
966 |
+
gin_channels=gin_channels,
|
967 |
+
)
|
968 |
+
self.enc_q = PosteriorEncoder(
|
969 |
+
spec_channels,
|
970 |
+
inter_channels,
|
971 |
+
hidden_channels,
|
972 |
+
5,
|
973 |
+
1,
|
974 |
+
16,
|
975 |
+
gin_channels=gin_channels,
|
976 |
+
)
|
977 |
+
if use_transformer_flow:
|
978 |
+
self.flow = TransformerCouplingBlock(
|
979 |
+
inter_channels,
|
980 |
+
hidden_channels,
|
981 |
+
filter_channels,
|
982 |
+
n_heads,
|
983 |
+
n_layers_trans_flow,
|
984 |
+
5,
|
985 |
+
p_dropout,
|
986 |
+
n_flow_layer,
|
987 |
+
gin_channels=gin_channels,
|
988 |
+
share_parameter=flow_share_parameter,
|
989 |
+
)
|
990 |
+
else:
|
991 |
+
self.flow = ResidualCouplingBlock(
|
992 |
+
inter_channels,
|
993 |
+
hidden_channels,
|
994 |
+
5,
|
995 |
+
1,
|
996 |
+
n_flow_layer,
|
997 |
+
gin_channels=gin_channels,
|
998 |
+
)
|
999 |
+
self.sdp = StochasticDurationPredictor(
|
1000 |
+
hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
|
1001 |
+
)
|
1002 |
+
self.dp = DurationPredictor(
|
1003 |
+
hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
|
1004 |
+
)
|
1005 |
+
|
1006 |
+
if n_speakers >= 1:
|
1007 |
+
self.emb_g = nn.Embedding(n_speakers, gin_channels)
|
1008 |
+
else:
|
1009 |
+
self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
|
1010 |
+
|
1011 |
+
def forward(
|
1012 |
+
self,
|
1013 |
+
x: torch.Tensor,
|
1014 |
+
x_lengths: torch.Tensor,
|
1015 |
+
y: torch.Tensor,
|
1016 |
+
y_lengths: torch.Tensor,
|
1017 |
+
sid: torch.Tensor,
|
1018 |
+
tone: torch.Tensor,
|
1019 |
+
language: torch.Tensor,
|
1020 |
+
bert: torch.Tensor,
|
1021 |
+
style_vec: torch.Tensor,
|
1022 |
+
) -> tuple[
|
1023 |
+
torch.Tensor,
|
1024 |
+
torch.Tensor,
|
1025 |
+
torch.Tensor,
|
1026 |
+
torch.Tensor,
|
1027 |
+
torch.Tensor,
|
1028 |
+
torch.Tensor,
|
1029 |
+
torch.Tensor,
|
1030 |
+
tuple[torch.Tensor, ...],
|
1031 |
+
tuple[torch.Tensor, ...],
|
1032 |
+
]:
|
1033 |
+
if self.n_speakers > 0:
|
1034 |
+
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
1035 |
+
else:
|
1036 |
+
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
1037 |
+
x, m_p, logs_p, x_mask = self.enc_p(
|
1038 |
+
x, x_lengths, tone, language, bert, style_vec, g=g
|
1039 |
+
)
|
1040 |
+
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
|
1041 |
+
z_p = self.flow(z, y_mask, g=g)
|
1042 |
+
|
1043 |
+
with torch.no_grad():
|
1044 |
+
# negative cross-entropy
|
1045 |
+
s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
|
1046 |
+
neg_cent1 = torch.sum(
|
1047 |
+
-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
|
1048 |
+
) # [b, 1, t_s]
|
1049 |
+
neg_cent2 = torch.matmul(
|
1050 |
+
-0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
|
1051 |
+
) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
|
1052 |
+
neg_cent3 = torch.matmul(
|
1053 |
+
z_p.transpose(1, 2), (m_p * s_p_sq_r)
|
1054 |
+
) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
|
1055 |
+
neg_cent4 = torch.sum(
|
1056 |
+
-0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
|
1057 |
+
) # [b, 1, t_s]
|
1058 |
+
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
|
1059 |
+
if self.use_noise_scaled_mas:
|
1060 |
+
epsilon = (
|
1061 |
+
torch.std(neg_cent)
|
1062 |
+
* torch.randn_like(neg_cent)
|
1063 |
+
* self.current_mas_noise_scale
|
1064 |
+
)
|
1065 |
+
neg_cent = neg_cent + epsilon
|
1066 |
+
|
1067 |
+
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
1068 |
+
attn = (
|
1069 |
+
monotonic_alignment.maximum_path(neg_cent, attn_mask.squeeze(1))
|
1070 |
+
.unsqueeze(1)
|
1071 |
+
.detach()
|
1072 |
+
)
|
1073 |
+
|
1074 |
+
w = attn.sum(2)
|
1075 |
+
|
1076 |
+
l_length_sdp = self.sdp(x, x_mask, w, g=g)
|
1077 |
+
l_length_sdp = l_length_sdp / torch.sum(x_mask)
|
1078 |
+
|
1079 |
+
logw_ = torch.log(w + 1e-6) * x_mask
|
1080 |
+
logw = self.dp(x, x_mask, g=g)
|
1081 |
+
# logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
|
1082 |
+
l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
|
1083 |
+
x_mask
|
1084 |
+
) # for averaging
|
1085 |
+
# l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
|
1086 |
+
|
1087 |
+
l_length = l_length_dp + l_length_sdp
|
1088 |
+
|
1089 |
+
# expand prior
|
1090 |
+
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
|
1091 |
+
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
|
1092 |
+
|
1093 |
+
z_slice, ids_slice = commons.rand_slice_segments(
|
1094 |
+
z, y_lengths, self.segment_size
|
1095 |
+
)
|
1096 |
+
o = self.dec(z_slice, g=g)
|
1097 |
+
return (
|
1098 |
+
o,
|
1099 |
+
l_length,
|
1100 |
+
attn,
|
1101 |
+
ids_slice,
|
1102 |
+
x_mask,
|
1103 |
+
y_mask,
|
1104 |
+
(z, z_p, m_p, logs_p, m_q, logs_q), # type: ignore
|
1105 |
+
(x, logw, logw_), # , logw_sdp),
|
1106 |
+
g,
|
1107 |
+
)
|
1108 |
+
|
1109 |
+
def infer(
|
1110 |
+
self,
|
1111 |
+
x: torch.Tensor,
|
1112 |
+
x_lengths: torch.Tensor,
|
1113 |
+
sid: torch.Tensor,
|
1114 |
+
tone: torch.Tensor,
|
1115 |
+
language: torch.Tensor,
|
1116 |
+
bert: torch.Tensor,
|
1117 |
+
style_vec: torch.Tensor,
|
1118 |
+
noise_scale: float = 0.667,
|
1119 |
+
length_scale: float = 1.0,
|
1120 |
+
noise_scale_w: float = 0.8,
|
1121 |
+
max_len: Optional[int] = None,
|
1122 |
+
sdp_ratio: float = 0.0,
|
1123 |
+
y: Optional[torch.Tensor] = None,
|
1124 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, ...]]:
|
1125 |
+
# x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
|
1126 |
+
# g = self.gst(y)
|
1127 |
+
if self.n_speakers > 0:
|
1128 |
+
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
|
1129 |
+
else:
|
1130 |
+
assert y is not None
|
1131 |
+
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
|
1132 |
+
x, m_p, logs_p, x_mask = self.enc_p(
|
1133 |
+
x, x_lengths, tone, language, bert, style_vec, g=g
|
1134 |
+
)
|
1135 |
+
logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
|
1136 |
+
sdp_ratio
|
1137 |
+
) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
|
1138 |
+
w = torch.exp(logw) * x_mask * length_scale
|
1139 |
+
w_ceil = torch.ceil(w)
|
1140 |
+
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
1141 |
+
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
|
1142 |
+
x_mask.dtype
|
1143 |
+
)
|
1144 |
+
attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
|
1145 |
+
attn = commons.generate_path(w_ceil, attn_mask)
|
1146 |
+
|
1147 |
+
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
|
1148 |
+
1, 2
|
1149 |
+
) # [b, t', t], [b, t, d] -> [b, d, t']
|
1150 |
+
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
|
1151 |
+
1, 2
|
1152 |
+
) # [b, t', t], [b, t, d] -> [b, d, t']
|
1153 |
+
|
1154 |
+
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
1155 |
+
z = self.flow(z_p, y_mask, g=g, reverse=True)
|
1156 |
+
o = self.dec((z * y_mask)[:, :, :max_len], g=g)
|
1157 |
+
return o, attn, y_mask, (z, z_p, m_p, logs_p)
|
style_bert_vits2/models/modules.py
ADDED
@@ -0,0 +1,642 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Any, Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import Conv1d
|
7 |
+
from torch.nn import functional as F
|
8 |
+
from torch.nn.utils import remove_weight_norm, weight_norm
|
9 |
+
|
10 |
+
from style_bert_vits2.models import commons
|
11 |
+
from style_bert_vits2.models.attentions import Encoder
|
12 |
+
from style_bert_vits2.models.transforms import piecewise_rational_quadratic_transform
|
13 |
+
|
14 |
+
|
15 |
+
LRELU_SLOPE = 0.1
|
16 |
+
|
17 |
+
|
18 |
+
class LayerNorm(nn.Module):
|
19 |
+
def __init__(self, channels: int, eps: float = 1e-5) -> None:
|
20 |
+
super().__init__()
|
21 |
+
self.channels = channels
|
22 |
+
self.eps = eps
|
23 |
+
|
24 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
25 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
26 |
+
|
27 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
28 |
+
x = x.transpose(1, -1)
|
29 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
30 |
+
return x.transpose(1, -1)
|
31 |
+
|
32 |
+
|
33 |
+
class ConvReluNorm(nn.Module):
|
34 |
+
def __init__(
|
35 |
+
self,
|
36 |
+
in_channels: int,
|
37 |
+
hidden_channels: int,
|
38 |
+
out_channels: int,
|
39 |
+
kernel_size: int,
|
40 |
+
n_layers: int,
|
41 |
+
p_dropout: float,
|
42 |
+
) -> None:
|
43 |
+
super().__init__()
|
44 |
+
self.in_channels = in_channels
|
45 |
+
self.hidden_channels = hidden_channels
|
46 |
+
self.out_channels = out_channels
|
47 |
+
self.kernel_size = kernel_size
|
48 |
+
self.n_layers = n_layers
|
49 |
+
self.p_dropout = p_dropout
|
50 |
+
assert n_layers > 1, "Number of layers should be larger than 0."
|
51 |
+
|
52 |
+
self.conv_layers = nn.ModuleList()
|
53 |
+
self.norm_layers = nn.ModuleList()
|
54 |
+
self.conv_layers.append(
|
55 |
+
nn.Conv1d(
|
56 |
+
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
|
57 |
+
)
|
58 |
+
)
|
59 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
60 |
+
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
|
61 |
+
for _ in range(n_layers - 1):
|
62 |
+
self.conv_layers.append(
|
63 |
+
nn.Conv1d(
|
64 |
+
hidden_channels,
|
65 |
+
hidden_channels,
|
66 |
+
kernel_size,
|
67 |
+
padding=kernel_size // 2,
|
68 |
+
)
|
69 |
+
)
|
70 |
+
self.norm_layers.append(LayerNorm(hidden_channels))
|
71 |
+
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
72 |
+
self.proj.weight.data.zero_()
|
73 |
+
assert self.proj.bias is not None
|
74 |
+
self.proj.bias.data.zero_()
|
75 |
+
|
76 |
+
def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
|
77 |
+
x_org = x
|
78 |
+
for i in range(self.n_layers):
|
79 |
+
x = self.conv_layers[i](x * x_mask)
|
80 |
+
x = self.norm_layers[i](x)
|
81 |
+
x = self.relu_drop(x)
|
82 |
+
x = x_org + self.proj(x)
|
83 |
+
return x * x_mask
|
84 |
+
|
85 |
+
|
86 |
+
class DDSConv(nn.Module):
|
87 |
+
"""
|
88 |
+
Dialted and Depth-Separable Convolution
|
89 |
+
"""
|
90 |
+
|
91 |
+
def __init__(
|
92 |
+
self, channels: int, kernel_size: int, n_layers: int, p_dropout: float = 0.0
|
93 |
+
) -> None:
|
94 |
+
super().__init__()
|
95 |
+
self.channels = channels
|
96 |
+
self.kernel_size = kernel_size
|
97 |
+
self.n_layers = n_layers
|
98 |
+
self.p_dropout = p_dropout
|
99 |
+
|
100 |
+
self.drop = nn.Dropout(p_dropout)
|
101 |
+
self.convs_sep = nn.ModuleList()
|
102 |
+
self.convs_1x1 = nn.ModuleList()
|
103 |
+
self.norms_1 = nn.ModuleList()
|
104 |
+
self.norms_2 = nn.ModuleList()
|
105 |
+
for i in range(n_layers):
|
106 |
+
dilation = kernel_size**i
|
107 |
+
padding = (kernel_size * dilation - dilation) // 2
|
108 |
+
self.convs_sep.append(
|
109 |
+
nn.Conv1d(
|
110 |
+
channels,
|
111 |
+
channels,
|
112 |
+
kernel_size,
|
113 |
+
groups=channels,
|
114 |
+
dilation=dilation,
|
115 |
+
padding=padding,
|
116 |
+
)
|
117 |
+
)
|
118 |
+
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
119 |
+
self.norms_1.append(LayerNorm(channels))
|
120 |
+
self.norms_2.append(LayerNorm(channels))
|
121 |
+
|
122 |
+
def forward(
|
123 |
+
self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
|
124 |
+
) -> torch.Tensor:
|
125 |
+
if g is not None:
|
126 |
+
x = x + g
|
127 |
+
for i in range(self.n_layers):
|
128 |
+
y = self.convs_sep[i](x * x_mask)
|
129 |
+
y = self.norms_1[i](y)
|
130 |
+
y = F.gelu(y)
|
131 |
+
y = self.convs_1x1[i](y)
|
132 |
+
y = self.norms_2[i](y)
|
133 |
+
y = F.gelu(y)
|
134 |
+
y = self.drop(y)
|
135 |
+
x = x + y
|
136 |
+
return x * x_mask
|
137 |
+
|
138 |
+
|
139 |
+
class WN(torch.nn.Module):
|
140 |
+
def __init__(
|
141 |
+
self,
|
142 |
+
hidden_channels: int,
|
143 |
+
kernel_size: int,
|
144 |
+
dilation_rate: int,
|
145 |
+
n_layers: int,
|
146 |
+
gin_channels: int = 0,
|
147 |
+
p_dropout: float = 0,
|
148 |
+
) -> None:
|
149 |
+
super(WN, self).__init__()
|
150 |
+
assert kernel_size % 2 == 1
|
151 |
+
self.hidden_channels = hidden_channels
|
152 |
+
self.kernel_size = (kernel_size,)
|
153 |
+
self.dilation_rate = dilation_rate
|
154 |
+
self.n_layers = n_layers
|
155 |
+
self.gin_channels = gin_channels
|
156 |
+
self.p_dropout = p_dropout
|
157 |
+
|
158 |
+
self.in_layers = torch.nn.ModuleList()
|
159 |
+
self.res_skip_layers = torch.nn.ModuleList()
|
160 |
+
self.drop = nn.Dropout(p_dropout)
|
161 |
+
|
162 |
+
if gin_channels != 0:
|
163 |
+
cond_layer = torch.nn.Conv1d(
|
164 |
+
gin_channels, 2 * hidden_channels * n_layers, 1
|
165 |
+
)
|
166 |
+
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
|
167 |
+
|
168 |
+
for i in range(n_layers):
|
169 |
+
dilation = dilation_rate**i
|
170 |
+
padding = int((kernel_size * dilation - dilation) / 2)
|
171 |
+
in_layer = torch.nn.Conv1d(
|
172 |
+
hidden_channels,
|
173 |
+
2 * hidden_channels,
|
174 |
+
kernel_size,
|
175 |
+
dilation=dilation,
|
176 |
+
padding=padding,
|
177 |
+
)
|
178 |
+
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
|
179 |
+
self.in_layers.append(in_layer)
|
180 |
+
|
181 |
+
# last one is not necessary
|
182 |
+
if i < n_layers - 1:
|
183 |
+
res_skip_channels = 2 * hidden_channels
|
184 |
+
else:
|
185 |
+
res_skip_channels = hidden_channels
|
186 |
+
|
187 |
+
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
188 |
+
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
|
189 |
+
self.res_skip_layers.append(res_skip_layer)
|
190 |
+
|
191 |
+
def forward(
|
192 |
+
self,
|
193 |
+
x: torch.Tensor,
|
194 |
+
x_mask: torch.Tensor,
|
195 |
+
g: Optional[torch.Tensor] = None,
|
196 |
+
**kwargs: Any,
|
197 |
+
) -> torch.Tensor:
|
198 |
+
output = torch.zeros_like(x)
|
199 |
+
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
200 |
+
|
201 |
+
if g is not None:
|
202 |
+
g = self.cond_layer(g)
|
203 |
+
|
204 |
+
for i in range(self.n_layers):
|
205 |
+
x_in = self.in_layers[i](x)
|
206 |
+
if g is not None:
|
207 |
+
cond_offset = i * 2 * self.hidden_channels
|
208 |
+
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
209 |
+
else:
|
210 |
+
g_l = torch.zeros_like(x_in)
|
211 |
+
|
212 |
+
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
213 |
+
acts = self.drop(acts)
|
214 |
+
|
215 |
+
res_skip_acts = self.res_skip_layers[i](acts)
|
216 |
+
if i < self.n_layers - 1:
|
217 |
+
res_acts = res_skip_acts[:, : self.hidden_channels, :]
|
218 |
+
x = (x + res_acts) * x_mask
|
219 |
+
output = output + res_skip_acts[:, self.hidden_channels :, :]
|
220 |
+
else:
|
221 |
+
output = output + res_skip_acts
|
222 |
+
return output * x_mask
|
223 |
+
|
224 |
+
def remove_weight_norm(self) -> None:
|
225 |
+
if self.gin_channels != 0:
|
226 |
+
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
227 |
+
for l in self.in_layers:
|
228 |
+
torch.nn.utils.remove_weight_norm(l)
|
229 |
+
for l in self.res_skip_layers:
|
230 |
+
torch.nn.utils.remove_weight_norm(l)
|
231 |
+
|
232 |
+
|
233 |
+
class ResBlock1(torch.nn.Module):
|
234 |
+
def __init__(
|
235 |
+
self,
|
236 |
+
channels: int,
|
237 |
+
kernel_size: int = 3,
|
238 |
+
dilation: tuple[int, int, int] = (1, 3, 5),
|
239 |
+
) -> None:
|
240 |
+
super(ResBlock1, self).__init__()
|
241 |
+
self.convs1 = nn.ModuleList(
|
242 |
+
[
|
243 |
+
weight_norm(
|
244 |
+
Conv1d(
|
245 |
+
channels,
|
246 |
+
channels,
|
247 |
+
kernel_size,
|
248 |
+
1,
|
249 |
+
dilation=dilation[0],
|
250 |
+
padding=commons.get_padding(kernel_size, dilation[0]),
|
251 |
+
)
|
252 |
+
),
|
253 |
+
weight_norm(
|
254 |
+
Conv1d(
|
255 |
+
channels,
|
256 |
+
channels,
|
257 |
+
kernel_size,
|
258 |
+
1,
|
259 |
+
dilation=dilation[1],
|
260 |
+
padding=commons.get_padding(kernel_size, dilation[1]),
|
261 |
+
)
|
262 |
+
),
|
263 |
+
weight_norm(
|
264 |
+
Conv1d(
|
265 |
+
channels,
|
266 |
+
channels,
|
267 |
+
kernel_size,
|
268 |
+
1,
|
269 |
+
dilation=dilation[2],
|
270 |
+
padding=commons.get_padding(kernel_size, dilation[2]),
|
271 |
+
)
|
272 |
+
),
|
273 |
+
]
|
274 |
+
)
|
275 |
+
self.convs1.apply(commons.init_weights)
|
276 |
+
|
277 |
+
self.convs2 = nn.ModuleList(
|
278 |
+
[
|
279 |
+
weight_norm(
|
280 |
+
Conv1d(
|
281 |
+
channels,
|
282 |
+
channels,
|
283 |
+
kernel_size,
|
284 |
+
1,
|
285 |
+
dilation=1,
|
286 |
+
padding=commons.get_padding(kernel_size, 1),
|
287 |
+
)
|
288 |
+
),
|
289 |
+
weight_norm(
|
290 |
+
Conv1d(
|
291 |
+
channels,
|
292 |
+
channels,
|
293 |
+
kernel_size,
|
294 |
+
1,
|
295 |
+
dilation=1,
|
296 |
+
padding=commons.get_padding(kernel_size, 1),
|
297 |
+
)
|
298 |
+
),
|
299 |
+
weight_norm(
|
300 |
+
Conv1d(
|
301 |
+
channels,
|
302 |
+
channels,
|
303 |
+
kernel_size,
|
304 |
+
1,
|
305 |
+
dilation=1,
|
306 |
+
padding=commons.get_padding(kernel_size, 1),
|
307 |
+
)
|
308 |
+
),
|
309 |
+
]
|
310 |
+
)
|
311 |
+
self.convs2.apply(commons.init_weights)
|
312 |
+
|
313 |
+
def forward(
|
314 |
+
self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None
|
315 |
+
) -> torch.Tensor:
|
316 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
317 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
318 |
+
if x_mask is not None:
|
319 |
+
xt = xt * x_mask
|
320 |
+
xt = c1(xt)
|
321 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
322 |
+
if x_mask is not None:
|
323 |
+
xt = xt * x_mask
|
324 |
+
xt = c2(xt)
|
325 |
+
x = xt + x
|
326 |
+
if x_mask is not None:
|
327 |
+
x = x * x_mask
|
328 |
+
return x
|
329 |
+
|
330 |
+
def remove_weight_norm(self) -> None:
|
331 |
+
for l in self.convs1:
|
332 |
+
remove_weight_norm(l)
|
333 |
+
for l in self.convs2:
|
334 |
+
remove_weight_norm(l)
|
335 |
+
|
336 |
+
|
337 |
+
class ResBlock2(torch.nn.Module):
|
338 |
+
def __init__(
|
339 |
+
self, channels: int, kernel_size: int = 3, dilation: tuple[int, int] = (1, 3)
|
340 |
+
) -> None:
|
341 |
+
super(ResBlock2, self).__init__()
|
342 |
+
self.convs = nn.ModuleList(
|
343 |
+
[
|
344 |
+
weight_norm(
|
345 |
+
Conv1d(
|
346 |
+
channels,
|
347 |
+
channels,
|
348 |
+
kernel_size,
|
349 |
+
1,
|
350 |
+
dilation=dilation[0],
|
351 |
+
padding=commons.get_padding(kernel_size, dilation[0]),
|
352 |
+
)
|
353 |
+
),
|
354 |
+
weight_norm(
|
355 |
+
Conv1d(
|
356 |
+
channels,
|
357 |
+
channels,
|
358 |
+
kernel_size,
|
359 |
+
1,
|
360 |
+
dilation=dilation[1],
|
361 |
+
padding=commons.get_padding(kernel_size, dilation[1]),
|
362 |
+
)
|
363 |
+
),
|
364 |
+
]
|
365 |
+
)
|
366 |
+
self.convs.apply(commons.init_weights)
|
367 |
+
|
368 |
+
def forward(
|
369 |
+
self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None
|
370 |
+
) -> torch.Tensor:
|
371 |
+
for c in self.convs:
|
372 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
373 |
+
if x_mask is not None:
|
374 |
+
xt = xt * x_mask
|
375 |
+
xt = c(xt)
|
376 |
+
x = xt + x
|
377 |
+
if x_mask is not None:
|
378 |
+
x = x * x_mask
|
379 |
+
return x
|
380 |
+
|
381 |
+
def remove_weight_norm(self) -> None:
|
382 |
+
for l in self.convs:
|
383 |
+
remove_weight_norm(l)
|
384 |
+
|
385 |
+
|
386 |
+
class Log(nn.Module):
|
387 |
+
def forward(
|
388 |
+
self,
|
389 |
+
x: torch.Tensor,
|
390 |
+
x_mask: torch.Tensor,
|
391 |
+
reverse: bool = False,
|
392 |
+
**kwargs: Any,
|
393 |
+
) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
394 |
+
if not reverse:
|
395 |
+
y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
|
396 |
+
logdet = torch.sum(-y, [1, 2])
|
397 |
+
return y, logdet
|
398 |
+
else:
|
399 |
+
x = torch.exp(x) * x_mask
|
400 |
+
return x
|
401 |
+
|
402 |
+
|
403 |
+
class Flip(nn.Module):
|
404 |
+
def forward(
|
405 |
+
self,
|
406 |
+
x: torch.Tensor,
|
407 |
+
*args: Any,
|
408 |
+
reverse: bool = False,
|
409 |
+
**kwargs: Any,
|
410 |
+
) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
411 |
+
x = torch.flip(x, [1])
|
412 |
+
if not reverse:
|
413 |
+
logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
|
414 |
+
return x, logdet
|
415 |
+
else:
|
416 |
+
return x
|
417 |
+
|
418 |
+
|
419 |
+
class ElementwiseAffine(nn.Module):
|
420 |
+
def __init__(self, channels: int) -> None:
|
421 |
+
super().__init__()
|
422 |
+
self.channels = channels
|
423 |
+
self.m = nn.Parameter(torch.zeros(channels, 1))
|
424 |
+
self.logs = nn.Parameter(torch.zeros(channels, 1))
|
425 |
+
|
426 |
+
def forward(
|
427 |
+
self,
|
428 |
+
x: torch.Tensor,
|
429 |
+
x_mask: torch.Tensor,
|
430 |
+
reverse: bool = False,
|
431 |
+
**kwargs: Any,
|
432 |
+
) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
433 |
+
if not reverse:
|
434 |
+
y = self.m + torch.exp(self.logs) * x
|
435 |
+
y = y * x_mask
|
436 |
+
logdet = torch.sum(self.logs * x_mask, [1, 2])
|
437 |
+
return y, logdet
|
438 |
+
else:
|
439 |
+
x = (x - self.m) * torch.exp(-self.logs) * x_mask
|
440 |
+
return x
|
441 |
+
|
442 |
+
|
443 |
+
class ResidualCouplingLayer(nn.Module):
|
444 |
+
def __init__(
|
445 |
+
self,
|
446 |
+
channels: int,
|
447 |
+
hidden_channels: int,
|
448 |
+
kernel_size: int,
|
449 |
+
dilation_rate: int,
|
450 |
+
n_layers: int,
|
451 |
+
p_dropout: float = 0,
|
452 |
+
gin_channels: int = 0,
|
453 |
+
mean_only: bool = False,
|
454 |
+
) -> None:
|
455 |
+
assert channels % 2 == 0, "channels should be divisible by 2"
|
456 |
+
super().__init__()
|
457 |
+
self.channels = channels
|
458 |
+
self.hidden_channels = hidden_channels
|
459 |
+
self.kernel_size = kernel_size
|
460 |
+
self.dilation_rate = dilation_rate
|
461 |
+
self.n_layers = n_layers
|
462 |
+
self.half_channels = channels // 2
|
463 |
+
self.mean_only = mean_only
|
464 |
+
|
465 |
+
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
466 |
+
self.enc = WN(
|
467 |
+
hidden_channels,
|
468 |
+
kernel_size,
|
469 |
+
dilation_rate,
|
470 |
+
n_layers,
|
471 |
+
p_dropout=p_dropout,
|
472 |
+
gin_channels=gin_channels,
|
473 |
+
)
|
474 |
+
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
475 |
+
self.post.weight.data.zero_()
|
476 |
+
assert self.post.bias is not None
|
477 |
+
self.post.bias.data.zero_()
|
478 |
+
|
479 |
+
def forward(
|
480 |
+
self,
|
481 |
+
x: torch.Tensor,
|
482 |
+
x_mask: torch.Tensor,
|
483 |
+
g: Optional[torch.Tensor] = None,
|
484 |
+
reverse: bool = False,
|
485 |
+
) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
486 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
487 |
+
h = self.pre(x0) * x_mask
|
488 |
+
h = self.enc(h, x_mask, g=g)
|
489 |
+
stats = self.post(h) * x_mask
|
490 |
+
if not self.mean_only:
|
491 |
+
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
492 |
+
else:
|
493 |
+
m = stats
|
494 |
+
logs = torch.zeros_like(m)
|
495 |
+
|
496 |
+
if not reverse:
|
497 |
+
x1 = m + x1 * torch.exp(logs) * x_mask
|
498 |
+
x = torch.cat([x0, x1], 1)
|
499 |
+
logdet = torch.sum(logs, [1, 2])
|
500 |
+
return x, logdet
|
501 |
+
else:
|
502 |
+
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
503 |
+
x = torch.cat([x0, x1], 1)
|
504 |
+
return x
|
505 |
+
|
506 |
+
|
507 |
+
class ConvFlow(nn.Module):
|
508 |
+
def __init__(
|
509 |
+
self,
|
510 |
+
in_channels: int,
|
511 |
+
filter_channels: int,
|
512 |
+
kernel_size: int,
|
513 |
+
n_layers: int,
|
514 |
+
num_bins: int = 10,
|
515 |
+
tail_bound: float = 5.0,
|
516 |
+
) -> None:
|
517 |
+
super().__init__()
|
518 |
+
self.in_channels = in_channels
|
519 |
+
self.filter_channels = filter_channels
|
520 |
+
self.kernel_size = kernel_size
|
521 |
+
self.n_layers = n_layers
|
522 |
+
self.num_bins = num_bins
|
523 |
+
self.tail_bound = tail_bound
|
524 |
+
self.half_channels = in_channels // 2
|
525 |
+
|
526 |
+
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
|
527 |
+
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
|
528 |
+
self.proj = nn.Conv1d(
|
529 |
+
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
|
530 |
+
)
|
531 |
+
self.proj.weight.data.zero_()
|
532 |
+
assert self.proj.bias is not None
|
533 |
+
self.proj.bias.data.zero_()
|
534 |
+
|
535 |
+
def forward(
|
536 |
+
self,
|
537 |
+
x: torch.Tensor,
|
538 |
+
x_mask: torch.Tensor,
|
539 |
+
g: Optional[torch.Tensor] = None,
|
540 |
+
reverse: bool = False,
|
541 |
+
) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
542 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
543 |
+
h = self.pre(x0)
|
544 |
+
h = self.convs(h, x_mask, g=g)
|
545 |
+
h = self.proj(h) * x_mask
|
546 |
+
|
547 |
+
b, c, t = x0.shape
|
548 |
+
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
|
549 |
+
|
550 |
+
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
|
551 |
+
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
|
552 |
+
self.filter_channels
|
553 |
+
)
|
554 |
+
unnormalized_derivatives = h[..., 2 * self.num_bins :]
|
555 |
+
|
556 |
+
x1, logabsdet = piecewise_rational_quadratic_transform(
|
557 |
+
x1,
|
558 |
+
unnormalized_widths,
|
559 |
+
unnormalized_heights,
|
560 |
+
unnormalized_derivatives,
|
561 |
+
inverse=reverse,
|
562 |
+
tails="linear",
|
563 |
+
tail_bound=self.tail_bound,
|
564 |
+
)
|
565 |
+
|
566 |
+
x = torch.cat([x0, x1], 1) * x_mask
|
567 |
+
logdet = torch.sum(logabsdet * x_mask, [1, 2])
|
568 |
+
if not reverse:
|
569 |
+
return x, logdet
|
570 |
+
else:
|
571 |
+
return x
|
572 |
+
|
573 |
+
|
574 |
+
class TransformerCouplingLayer(nn.Module):
|
575 |
+
def __init__(
|
576 |
+
self,
|
577 |
+
channels: int,
|
578 |
+
hidden_channels: int,
|
579 |
+
kernel_size: int,
|
580 |
+
n_layers: int,
|
581 |
+
n_heads: int,
|
582 |
+
p_dropout: float = 0,
|
583 |
+
filter_channels: int = 0,
|
584 |
+
mean_only: bool = False,
|
585 |
+
wn_sharing_parameter: Optional[nn.Module] = None,
|
586 |
+
gin_channels: int = 0,
|
587 |
+
) -> None:
|
588 |
+
assert channels % 2 == 0, "channels should be divisible by 2"
|
589 |
+
super().__init__()
|
590 |
+
self.channels = channels
|
591 |
+
self.hidden_channels = hidden_channels
|
592 |
+
self.kernel_size = kernel_size
|
593 |
+
self.n_layers = n_layers
|
594 |
+
self.half_channels = channels // 2
|
595 |
+
self.mean_only = mean_only
|
596 |
+
|
597 |
+
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
598 |
+
self.enc = (
|
599 |
+
Encoder(
|
600 |
+
hidden_channels,
|
601 |
+
filter_channels,
|
602 |
+
n_heads,
|
603 |
+
n_layers,
|
604 |
+
kernel_size,
|
605 |
+
p_dropout,
|
606 |
+
isflow=True,
|
607 |
+
gin_channels=gin_channels,
|
608 |
+
)
|
609 |
+
if wn_sharing_parameter is None
|
610 |
+
else wn_sharing_parameter
|
611 |
+
)
|
612 |
+
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
613 |
+
self.post.weight.data.zero_()
|
614 |
+
assert self.post.bias is not None
|
615 |
+
self.post.bias.data.zero_()
|
616 |
+
|
617 |
+
def forward(
|
618 |
+
self,
|
619 |
+
x: torch.Tensor,
|
620 |
+
x_mask: torch.Tensor,
|
621 |
+
g: Optional[torch.Tensor] = None,
|
622 |
+
reverse: bool = False,
|
623 |
+
) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
624 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
625 |
+
h = self.pre(x0) * x_mask
|
626 |
+
h = self.enc(h, x_mask, g=g)
|
627 |
+
stats = self.post(h) * x_mask
|
628 |
+
if not self.mean_only:
|
629 |
+
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
630 |
+
else:
|
631 |
+
m = stats
|
632 |
+
logs = torch.zeros_like(m)
|
633 |
+
|
634 |
+
if not reverse:
|
635 |
+
x1 = m + x1 * torch.exp(logs) * x_mask
|
636 |
+
x = torch.cat([x0, x1], 1)
|
637 |
+
logdet = torch.sum(logs, [1, 2])
|
638 |
+
return x, logdet
|
639 |
+
else:
|
640 |
+
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
641 |
+
x = torch.cat([x0, x1], 1)
|
642 |
+
return x
|
style_bert_vits2/models/monotonic_alignment.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
以下に記述されている関数のコメントはリファクタリング時に GPT-4 に生成させたもので、
|
3 |
+
コードと完全に一致している保証はない。あくまで参考程度とすること。
|
4 |
+
"""
|
5 |
+
|
6 |
+
from typing import Any
|
7 |
+
|
8 |
+
import numba
|
9 |
+
import torch
|
10 |
+
from numpy import float32, int32, zeros
|
11 |
+
|
12 |
+
|
13 |
+
def maximum_path(neg_cent: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
14 |
+
"""
|
15 |
+
与えられた負の中心とマスクを使用して最大パスを計算する
|
16 |
+
|
17 |
+
Args:
|
18 |
+
neg_cent (torch.Tensor): 負の中心を表すテンソル
|
19 |
+
mask (torch.Tensor): マスクを表すテンソル
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
Tensor: 計算された最大パスを表すテンソル
|
23 |
+
"""
|
24 |
+
|
25 |
+
device = neg_cent.device
|
26 |
+
dtype = neg_cent.dtype
|
27 |
+
neg_cent = neg_cent.data.cpu().numpy().astype(float32)
|
28 |
+
path = zeros(neg_cent.shape, dtype=int32)
|
29 |
+
|
30 |
+
t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
|
31 |
+
t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
|
32 |
+
__maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
|
33 |
+
|
34 |
+
return torch.from_numpy(path).to(device=device, dtype=dtype)
|
35 |
+
|
36 |
+
|
37 |
+
@numba.jit(
|
38 |
+
numba.void(
|
39 |
+
numba.int32[:, :, ::1],
|
40 |
+
numba.float32[:, :, ::1],
|
41 |
+
numba.int32[::1],
|
42 |
+
numba.int32[::1],
|
43 |
+
),
|
44 |
+
nopython=True,
|
45 |
+
nogil=True,
|
46 |
+
) # type: ignore
|
47 |
+
def __maximum_path_jit(paths: Any, values: Any, t_ys: Any, t_xs: Any) -> None:
|
48 |
+
"""
|
49 |
+
与えられたパス、値、およびターゲットの y と x 座標を使用して JIT で最大パスを計算する
|
50 |
+
|
51 |
+
Args:
|
52 |
+
paths: 計算されたパスを格納するための整数型の 3 次元配列
|
53 |
+
values: 値を格納するための浮動小数点型の 3 次元配列
|
54 |
+
t_ys: ターゲットの y 座標を格納するための整数型の 1 次元配列
|
55 |
+
t_xs: ターゲットの x 座標を格納するための整数型の 1 次元配列
|
56 |
+
"""
|
57 |
+
|
58 |
+
b = paths.shape[0]
|
59 |
+
max_neg_val = -1e9
|
60 |
+
for i in range(int(b)):
|
61 |
+
path = paths[i]
|
62 |
+
value = values[i]
|
63 |
+
t_y = t_ys[i]
|
64 |
+
t_x = t_xs[i]
|
65 |
+
|
66 |
+
v_prev = v_cur = 0.0
|
67 |
+
index = t_x - 1
|
68 |
+
|
69 |
+
for y in range(t_y):
|
70 |
+
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
|
71 |
+
if x == y:
|
72 |
+
v_cur = max_neg_val
|
73 |
+
else:
|
74 |
+
v_cur = value[y - 1, x]
|
75 |
+
if x == 0:
|
76 |
+
if y == 0:
|
77 |
+
v_prev = 0.0
|
78 |
+
else:
|
79 |
+
v_prev = max_neg_val
|
80 |
+
else:
|
81 |
+
v_prev = value[y - 1, x - 1]
|
82 |
+
value[y, x] += max(v_prev, v_cur)
|
83 |
+
|
84 |
+
for y in range(t_y - 1, -1, -1):
|
85 |
+
path[y, index] = 1
|
86 |
+
if index != 0 and (
|
87 |
+
index == y or value[y - 1, index] < value[y - 1, index - 1]
|
88 |
+
):
|
89 |
+
index = index - 1
|
style_bert_vits2/models/transforms.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
|
8 |
+
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
9 |
+
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
10 |
+
DEFAULT_MIN_DERIVATIVE = 1e-3
|
11 |
+
|
12 |
+
|
13 |
+
def piecewise_rational_quadratic_transform(
|
14 |
+
inputs: torch.Tensor,
|
15 |
+
unnormalized_widths: torch.Tensor,
|
16 |
+
unnormalized_heights: torch.Tensor,
|
17 |
+
unnormalized_derivatives: torch.Tensor,
|
18 |
+
inverse: bool = False,
|
19 |
+
tails: Optional[str] = None,
|
20 |
+
tail_bound: float = 1.0,
|
21 |
+
min_bin_width: float = DEFAULT_MIN_BIN_WIDTH,
|
22 |
+
min_bin_height: float = DEFAULT_MIN_BIN_HEIGHT,
|
23 |
+
min_derivative: float = DEFAULT_MIN_DERIVATIVE,
|
24 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
25 |
+
|
26 |
+
if tails is None:
|
27 |
+
spline_fn = rational_quadratic_spline
|
28 |
+
spline_kwargs = {}
|
29 |
+
else:
|
30 |
+
spline_fn = unconstrained_rational_quadratic_spline
|
31 |
+
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
|
32 |
+
|
33 |
+
outputs, logabsdet = spline_fn(
|
34 |
+
inputs=inputs,
|
35 |
+
unnormalized_widths=unnormalized_widths,
|
36 |
+
unnormalized_heights=unnormalized_heights,
|
37 |
+
unnormalized_derivatives=unnormalized_derivatives,
|
38 |
+
inverse=inverse,
|
39 |
+
min_bin_width=min_bin_width,
|
40 |
+
min_bin_height=min_bin_height,
|
41 |
+
min_derivative=min_derivative,
|
42 |
+
**spline_kwargs, # type: ignore
|
43 |
+
)
|
44 |
+
return outputs, logabsdet
|
45 |
+
|
46 |
+
|
47 |
+
def searchsorted(
|
48 |
+
bin_locations: torch.Tensor, inputs: torch.Tensor, eps: float = 1e-6
|
49 |
+
) -> torch.Tensor:
|
50 |
+
bin_locations[..., -1] += eps
|
51 |
+
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
|
52 |
+
|
53 |
+
|
54 |
+
def unconstrained_rational_quadratic_spline(
|
55 |
+
inputs: torch.Tensor,
|
56 |
+
unnormalized_widths: torch.Tensor,
|
57 |
+
unnormalized_heights: torch.Tensor,
|
58 |
+
unnormalized_derivatives: torch.Tensor,
|
59 |
+
inverse: bool = False,
|
60 |
+
tails: str = "linear",
|
61 |
+
tail_bound: float = 1.0,
|
62 |
+
min_bin_width: float = DEFAULT_MIN_BIN_WIDTH,
|
63 |
+
min_bin_height: float = DEFAULT_MIN_BIN_HEIGHT,
|
64 |
+
min_derivative: float = DEFAULT_MIN_DERIVATIVE,
|
65 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
66 |
+
|
67 |
+
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
68 |
+
outside_interval_mask = ~inside_interval_mask
|
69 |
+
|
70 |
+
outputs = torch.zeros_like(inputs)
|
71 |
+
logabsdet = torch.zeros_like(inputs)
|
72 |
+
|
73 |
+
if tails == "linear":
|
74 |
+
unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
|
75 |
+
constant = np.log(np.exp(1 - min_derivative) - 1)
|
76 |
+
unnormalized_derivatives[..., 0] = constant
|
77 |
+
unnormalized_derivatives[..., -1] = constant
|
78 |
+
|
79 |
+
outputs[outside_interval_mask] = inputs[outside_interval_mask]
|
80 |
+
logabsdet[outside_interval_mask] = 0
|
81 |
+
else:
|
82 |
+
raise RuntimeError(f"{tails} tails are not implemented.")
|
83 |
+
|
84 |
+
(
|
85 |
+
outputs[inside_interval_mask],
|
86 |
+
logabsdet[inside_interval_mask],
|
87 |
+
) = rational_quadratic_spline(
|
88 |
+
inputs=inputs[inside_interval_mask],
|
89 |
+
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
|
90 |
+
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
|
91 |
+
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
|
92 |
+
inverse=inverse,
|
93 |
+
left=-tail_bound,
|
94 |
+
right=tail_bound,
|
95 |
+
bottom=-tail_bound,
|
96 |
+
top=tail_bound,
|
97 |
+
min_bin_width=min_bin_width,
|
98 |
+
min_bin_height=min_bin_height,
|
99 |
+
min_derivative=min_derivative,
|
100 |
+
)
|
101 |
+
|
102 |
+
return outputs, logabsdet
|
103 |
+
|
104 |
+
|
105 |
+
def rational_quadratic_spline(
|
106 |
+
inputs: torch.Tensor,
|
107 |
+
unnormalized_widths: torch.Tensor,
|
108 |
+
unnormalized_heights: torch.Tensor,
|
109 |
+
unnormalized_derivatives: torch.Tensor,
|
110 |
+
inverse: bool = False,
|
111 |
+
left: float = 0.0,
|
112 |
+
right: float = 1.0,
|
113 |
+
bottom: float = 0.0,
|
114 |
+
top: float = 1.0,
|
115 |
+
min_bin_width: float = DEFAULT_MIN_BIN_WIDTH,
|
116 |
+
min_bin_height: float = DEFAULT_MIN_BIN_HEIGHT,
|
117 |
+
min_derivative: float = DEFAULT_MIN_DERIVATIVE,
|
118 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
119 |
+
|
120 |
+
if torch.min(inputs) < left or torch.max(inputs) > right:
|
121 |
+
raise ValueError("Input to a transform is not within its domain")
|
122 |
+
|
123 |
+
num_bins = unnormalized_widths.shape[-1]
|
124 |
+
|
125 |
+
if min_bin_width * num_bins > 1.0:
|
126 |
+
raise ValueError("Minimal bin width too large for the number of bins")
|
127 |
+
if min_bin_height * num_bins > 1.0:
|
128 |
+
raise ValueError("Minimal bin height too large for the number of bins")
|
129 |
+
|
130 |
+
widths = F.softmax(unnormalized_widths, dim=-1)
|
131 |
+
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
132 |
+
cumwidths = torch.cumsum(widths, dim=-1)
|
133 |
+
cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
|
134 |
+
cumwidths = (right - left) * cumwidths + left
|
135 |
+
cumwidths[..., 0] = left
|
136 |
+
cumwidths[..., -1] = right
|
137 |
+
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
|
138 |
+
|
139 |
+
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
|
140 |
+
|
141 |
+
heights = F.softmax(unnormalized_heights, dim=-1)
|
142 |
+
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
|
143 |
+
cumheights = torch.cumsum(heights, dim=-1)
|
144 |
+
cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
|
145 |
+
cumheights = (top - bottom) * cumheights + bottom
|
146 |
+
cumheights[..., 0] = bottom
|
147 |
+
cumheights[..., -1] = top
|
148 |
+
heights = cumheights[..., 1:] - cumheights[..., :-1]
|
149 |
+
|
150 |
+
if inverse:
|
151 |
+
bin_idx = searchsorted(cumheights, inputs)[..., None]
|
152 |
+
else:
|
153 |
+
bin_idx = searchsorted(cumwidths, inputs)[..., None]
|
154 |
+
|
155 |
+
input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
|
156 |
+
input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
|
157 |
+
|
158 |
+
input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
|
159 |
+
delta = heights / widths
|
160 |
+
input_delta = delta.gather(-1, bin_idx)[..., 0]
|
161 |
+
|
162 |
+
input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
|
163 |
+
input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
|
164 |
+
|
165 |
+
input_heights = heights.gather(-1, bin_idx)[..., 0]
|
166 |
+
|
167 |
+
if inverse:
|
168 |
+
a = (inputs - input_cumheights) * (
|
169 |
+
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
170 |
+
) + input_heights * (input_delta - input_derivatives)
|
171 |
+
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
|
172 |
+
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
173 |
+
)
|
174 |
+
c = -input_delta * (inputs - input_cumheights)
|
175 |
+
|
176 |
+
discriminant = b.pow(2) - 4 * a * c
|
177 |
+
assert (discriminant >= 0).all()
|
178 |
+
|
179 |
+
root = (2 * c) / (-b - torch.sqrt(discriminant))
|
180 |
+
outputs = root * input_bin_widths + input_cumwidths
|
181 |
+
|
182 |
+
theta_one_minus_theta = root * (1 - root)
|
183 |
+
denominator = input_delta + (
|
184 |
+
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
185 |
+
* theta_one_minus_theta
|
186 |
+
)
|
187 |
+
derivative_numerator = input_delta.pow(2) * (
|
188 |
+
input_derivatives_plus_one * root.pow(2)
|
189 |
+
+ 2 * input_delta * theta_one_minus_theta
|
190 |
+
+ input_derivatives * (1 - root).pow(2)
|
191 |
+
)
|
192 |
+
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
193 |
+
|
194 |
+
return outputs, -logabsdet
|
195 |
+
else:
|
196 |
+
theta = (inputs - input_cumwidths) / input_bin_widths
|
197 |
+
theta_one_minus_theta = theta * (1 - theta)
|
198 |
+
|
199 |
+
numerator = input_heights * (
|
200 |
+
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
|
201 |
+
)
|
202 |
+
denominator = input_delta + (
|
203 |
+
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
204 |
+
* theta_one_minus_theta
|
205 |
+
)
|
206 |
+
outputs = input_cumheights + numerator / denominator
|
207 |
+
|
208 |
+
derivative_numerator = input_delta.pow(2) * (
|
209 |
+
input_derivatives_plus_one * theta.pow(2)
|
210 |
+
+ 2 * input_delta * theta_one_minus_theta
|
211 |
+
+ input_derivatives * (1 - theta).pow(2)
|
212 |
+
)
|
213 |
+
logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
|
214 |
+
|
215 |
+
return outputs, logabsdet
|
style_bert_vits2/models/utils/__init__.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import subprocess
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
from numpy.typing import NDArray
|
12 |
+
|
13 |
+
from style_bert_vits2.logging import logger
|
14 |
+
from style_bert_vits2.models.utils import checkpoints # type: ignore
|
15 |
+
from style_bert_vits2.models.utils import safetensors # type: ignore
|
16 |
+
|
17 |
+
|
18 |
+
if TYPE_CHECKING:
|
19 |
+
# tensorboard はライブラリとしてインストールされている場合は依存関係に含まれないため、型チェック時のみインポートする
|
20 |
+
from torch.utils.tensorboard import SummaryWriter
|
21 |
+
|
22 |
+
|
23 |
+
__is_matplotlib_imported = False
|
24 |
+
|
25 |
+
|
26 |
+
def summarize(
|
27 |
+
writer: "SummaryWriter",
|
28 |
+
global_step: int,
|
29 |
+
scalars: dict[str, float] = {},
|
30 |
+
histograms: dict[str, Any] = {},
|
31 |
+
images: dict[str, Any] = {},
|
32 |
+
audios: dict[str, Any] = {},
|
33 |
+
audio_sampling_rate: int = 22050,
|
34 |
+
) -> None:
|
35 |
+
"""
|
36 |
+
指定されたデータを TensorBoard にまとめて追加する
|
37 |
+
|
38 |
+
Args:
|
39 |
+
writer (SummaryWriter): TensorBoard への書き込みを行うオブジェクト
|
40 |
+
global_step (int): グローバルステップ数
|
41 |
+
scalars (dict[str, float]): スカラー値の辞書
|
42 |
+
histograms (dict[str, Any]): ヒストグラムの辞書
|
43 |
+
images (dict[str, Any]): 画像データの辞書
|
44 |
+
audios (dict[str, Any]): 音声データの辞書
|
45 |
+
audio_sampling_rate (int): 音声データのサンプリングレート
|
46 |
+
"""
|
47 |
+
for k, v in scalars.items():
|
48 |
+
writer.add_scalar(k, v, global_step)
|
49 |
+
for k, v in histograms.items():
|
50 |
+
writer.add_histogram(k, v, global_step)
|
51 |
+
for k, v in images.items():
|
52 |
+
writer.add_image(k, v, global_step, dataformats="HWC")
|
53 |
+
for k, v in audios.items():
|
54 |
+
writer.add_audio(k, v, global_step, audio_sampling_rate)
|
55 |
+
|
56 |
+
|
57 |
+
def is_resuming(dir_path: Union[str, Path]) -> bool:
|
58 |
+
"""
|
59 |
+
指定されたディレクトリパスに再開可能なモデルが存在するかどうかを返す
|
60 |
+
|
61 |
+
Args:
|
62 |
+
dir_path: チェックするディレクトリのパス
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
bool: 再開可能なモデルが存在するかどうか
|
66 |
+
"""
|
67 |
+
# JP-ExtraバージョンではDURがなくWDがあったり変わるため、Gのみで判断する
|
68 |
+
g_list = glob.glob(os.path.join(dir_path, "G_*.pth"))
|
69 |
+
# d_list = glob.glob(os.path.join(dir_path, "D_*.pth"))
|
70 |
+
# dur_list = glob.glob(os.path.join(dir_path, "DUR_*.pth"))
|
71 |
+
return len(g_list) > 0
|
72 |
+
|
73 |
+
|
74 |
+
def plot_spectrogram_to_numpy(spectrogram: NDArray[Any]) -> NDArray[Any]:
|
75 |
+
"""
|
76 |
+
指定されたスペクトログラムを画像データに変換する
|
77 |
+
|
78 |
+
Args:
|
79 |
+
spectrogram (NDArray[Any]): スペクトログラム
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
NDArray[Any]: 画像データ
|
83 |
+
"""
|
84 |
+
|
85 |
+
global __is_matplotlib_imported
|
86 |
+
if not __is_matplotlib_imported:
|
87 |
+
import matplotlib
|
88 |
+
|
89 |
+
matplotlib.use("Agg")
|
90 |
+
__is_matplotlib_imported = True
|
91 |
+
mpl_logger = logging.getLogger("matplotlib")
|
92 |
+
mpl_logger.setLevel(logging.WARNING)
|
93 |
+
import matplotlib.pylab as plt
|
94 |
+
import numpy as np
|
95 |
+
|
96 |
+
fig, ax = plt.subplots(figsize=(10, 2))
|
97 |
+
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
98 |
+
plt.colorbar(im, ax=ax)
|
99 |
+
plt.xlabel("Frames")
|
100 |
+
plt.ylabel("Channels")
|
101 |
+
plt.tight_layout()
|
102 |
+
|
103 |
+
fig.canvas.draw()
|
104 |
+
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") # type: ignore
|
105 |
+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
106 |
+
plt.close()
|
107 |
+
return data
|
108 |
+
|
109 |
+
|
110 |
+
def plot_alignment_to_numpy(
|
111 |
+
alignment: NDArray[Any], info: Optional[str] = None
|
112 |
+
) -> NDArray[Any]:
|
113 |
+
"""
|
114 |
+
指定されたアライメントを画像データに変換する
|
115 |
+
|
116 |
+
Args:
|
117 |
+
alignment (NDArray[Any]): アライメント
|
118 |
+
info (Optional[str]): 画像に追加する情報
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
NDArray[Any]: 画像データ
|
122 |
+
"""
|
123 |
+
|
124 |
+
global __is_matplotlib_imported
|
125 |
+
if not __is_matplotlib_imported:
|
126 |
+
import matplotlib
|
127 |
+
|
128 |
+
matplotlib.use("Agg")
|
129 |
+
__is_matplotlib_imported = True
|
130 |
+
mpl_logger = logging.getLogger("matplotlib")
|
131 |
+
mpl_logger.setLevel(logging.WARNING)
|
132 |
+
import matplotlib.pylab as plt
|
133 |
+
|
134 |
+
fig, ax = plt.subplots(figsize=(6, 4))
|
135 |
+
im = ax.imshow(
|
136 |
+
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
|
137 |
+
)
|
138 |
+
fig.colorbar(im, ax=ax)
|
139 |
+
xlabel = "Decoder timestep"
|
140 |
+
if info is not None:
|
141 |
+
xlabel += "\n\n" + info
|
142 |
+
plt.xlabel(xlabel)
|
143 |
+
plt.ylabel("Encoder timestep")
|
144 |
+
plt.tight_layout()
|
145 |
+
|
146 |
+
fig.canvas.draw()
|
147 |
+
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") # type: ignore
|
148 |
+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
149 |
+
plt.close()
|
150 |
+
return data
|
151 |
+
|
152 |
+
|
153 |
+
def load_wav_to_torch(full_path: Union[str, Path]) -> tuple[torch.FloatTensor, int]:
|
154 |
+
"""
|
155 |
+
指定された音声ファイルを読み込み、PyTorch のテンソルに変換して返す
|
156 |
+
|
157 |
+
Args:
|
158 |
+
full_path (Union[str, Path]): 音声ファイルのパス
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
tuple[torch.FloatTensor, int]: 音声データのテンソルとサンプリングレート
|
162 |
+
"""
|
163 |
+
|
164 |
+
# この関数は学習時以外使われないため、ライブラリとしての style_bert_vits2 が
|
165 |
+
# 重たい scipy に依存しないように遅延 import する
|
166 |
+
try:
|
167 |
+
from scipy.io.wavfile import read
|
168 |
+
except ImportError:
|
169 |
+
raise ImportError("scipy is required to load wav file")
|
170 |
+
|
171 |
+
sampling_rate, data = read(full_path)
|
172 |
+
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
|
173 |
+
|
174 |
+
|
175 |
+
def load_filepaths_and_text(
|
176 |
+
filename: Union[str, Path], split: str = "|"
|
177 |
+
) -> list[list[str]]:
|
178 |
+
"""
|
179 |
+
指定されたファイルからファイルパスとテキストを読み込む
|
180 |
+
|
181 |
+
Args:
|
182 |
+
filename (Union[str, Path]): ファイルのパス
|
183 |
+
split (str): ファイルの区切り文字 (デフォルト: "|")
|
184 |
+
|
185 |
+
Returns:
|
186 |
+
list[list[str]]: ファイルパスとテキストのリスト
|
187 |
+
"""
|
188 |
+
|
189 |
+
with open(filename, encoding="utf-8") as f:
|
190 |
+
filepaths_and_text = [line.strip().split(split) for line in f]
|
191 |
+
return filepaths_and_text
|
192 |
+
|
193 |
+
|
194 |
+
def get_logger(
|
195 |
+
model_dir_path: Union[str, Path], filename: str = "train.log"
|
196 |
+
) -> logging.Logger:
|
197 |
+
"""
|
198 |
+
ロガーを取得する
|
199 |
+
|
200 |
+
Args:
|
201 |
+
model_dir_path (Union[str, Path]): ログを保存するディレクトリのパス
|
202 |
+
filename (str): ログファイルの名前 (デフォルト: "train.log")
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
logging.Logger: ロガー
|
206 |
+
"""
|
207 |
+
|
208 |
+
global logger
|
209 |
+
logger = logging.getLogger(os.path.basename(model_dir_path))
|
210 |
+
logger.setLevel(logging.DEBUG)
|
211 |
+
|
212 |
+
formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
|
213 |
+
if not os.path.exists(model_dir_path):
|
214 |
+
os.makedirs(model_dir_path)
|
215 |
+
h = logging.FileHandler(os.path.join(model_dir_path, filename))
|
216 |
+
h.setLevel(logging.DEBUG)
|
217 |
+
h.setFormatter(formatter)
|
218 |
+
logger.addHandler(h)
|
219 |
+
return logger
|
220 |
+
|
221 |
+
|
222 |
+
def get_steps(model_path: Union[str, Path]) -> Optional[int]:
|
223 |
+
"""
|
224 |
+
モデルのパスからイテレーション回数を取得する
|
225 |
+
|
226 |
+
Args:
|
227 |
+
model_path (Union[str, Path]): モデルのパス
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
Optional[int]: イテレーション回数
|
231 |
+
"""
|
232 |
+
|
233 |
+
matches = re.findall(r"\d+", model_path) # type: ignore
|
234 |
+
return matches[-1] if matches else None
|
235 |
+
|
236 |
+
|
237 |
+
def check_git_hash(model_dir_path: Union[str, Path]) -> None:
|
238 |
+
"""
|
239 |
+
モデルのディレクトリに .git ディレクトリが存在する場合、ハッシュ値を比較する
|
240 |
+
|
241 |
+
Args:
|
242 |
+
model_dir_path (Union[str, Path]): モデルのディレクトリのパス
|
243 |
+
"""
|
244 |
+
|
245 |
+
source_dir = os.path.dirname(os.path.realpath(__file__))
|
246 |
+
if not os.path.exists(os.path.join(source_dir, ".git")):
|
247 |
+
logger.warning(
|
248 |
+
f"{source_dir} is not a git repository, therefore hash value comparison will be ignored."
|
249 |
+
)
|
250 |
+
return
|
251 |
+
|
252 |
+
cur_hash = subprocess.getoutput("git rev-parse HEAD")
|
253 |
+
|
254 |
+
path = os.path.join(model_dir_path, "githash")
|
255 |
+
if os.path.exists(path):
|
256 |
+
with open(path, encoding="utf-8") as f:
|
257 |
+
saved_hash = f.read()
|
258 |
+
if saved_hash != cur_hash:
|
259 |
+
logger.warning(
|
260 |
+
f"git hash values are different. {saved_hash[:8]}(saved) != {cur_hash[:8]}(current)"
|
261 |
+
)
|
262 |
+
else:
|
263 |
+
with open(path, "w", encoding="utf-8") as f:
|
264 |
+
f.write(cur_hash)
|
style_bert_vits2/models/utils/checkpoints.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Any, Optional, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from style_bert_vits2.logging import logger
|
10 |
+
|
11 |
+
|
12 |
+
def load_checkpoint(
|
13 |
+
checkpoint_path: Union[str, Path],
|
14 |
+
model: torch.nn.Module,
|
15 |
+
optimizer: Optional[torch.optim.Optimizer] = None,
|
16 |
+
skip_optimizer: bool = False,
|
17 |
+
for_infer: bool = False,
|
18 |
+
) -> tuple[torch.nn.Module, Optional[torch.optim.Optimizer], float, int]:
|
19 |
+
"""
|
20 |
+
指定されたパスからチェックポイントを読み込み、モデルとオプティマイザーを更新する。
|
21 |
+
|
22 |
+
Args:
|
23 |
+
checkpoint_path (Union[str, Path]): チェックポイントファイルのパス
|
24 |
+
model (torch.nn.Module): 更新するモデル
|
25 |
+
optimizer (Optional[torch.optim.Optimizer]): 更新するオプティマイザー。None の場合は更新しない
|
26 |
+
skip_optimizer (bool): オプティマイザーの更新をスキップするかどうかのフラグ
|
27 |
+
for_infer (bool): 推論用に読み込むかどうかのフラグ
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
tuple[torch.nn.Module, Optional[torch.optim.Optimizer], float, int]: 更新されたモデルとオプティマイザー、学習率、イテレーション回数
|
31 |
+
"""
|
32 |
+
|
33 |
+
assert os.path.isfile(checkpoint_path)
|
34 |
+
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
35 |
+
iteration = checkpoint_dict["iteration"]
|
36 |
+
learning_rate = checkpoint_dict["learning_rate"]
|
37 |
+
logger.info(
|
38 |
+
f"Loading model and optimizer at iteration {iteration} from {checkpoint_path}"
|
39 |
+
)
|
40 |
+
if (
|
41 |
+
optimizer is not None
|
42 |
+
and not skip_optimizer
|
43 |
+
and checkpoint_dict["optimizer"] is not None
|
44 |
+
):
|
45 |
+
optimizer.load_state_dict(checkpoint_dict["optimizer"])
|
46 |
+
elif optimizer is None and not skip_optimizer:
|
47 |
+
# else: Disable this line if Infer and resume checkpoint,then enable the line upper
|
48 |
+
new_opt_dict = optimizer.state_dict() # type: ignore
|
49 |
+
new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
|
50 |
+
new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
|
51 |
+
new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
|
52 |
+
optimizer.load_state_dict(new_opt_dict) # type: ignore
|
53 |
+
|
54 |
+
saved_state_dict = checkpoint_dict["model"]
|
55 |
+
if hasattr(model, "module"):
|
56 |
+
state_dict = model.module.state_dict()
|
57 |
+
else:
|
58 |
+
state_dict = model.state_dict()
|
59 |
+
|
60 |
+
new_state_dict = {}
|
61 |
+
for k, v in state_dict.items():
|
62 |
+
try:
|
63 |
+
# assert "emb_g" not in k
|
64 |
+
new_state_dict[k] = saved_state_dict[k]
|
65 |
+
assert saved_state_dict[k].shape == v.shape, (
|
66 |
+
saved_state_dict[k].shape,
|
67 |
+
v.shape,
|
68 |
+
)
|
69 |
+
except:
|
70 |
+
# For upgrading from the old version
|
71 |
+
if "ja_bert_proj" in k:
|
72 |
+
v = torch.zeros_like(v)
|
73 |
+
logger.warning(
|
74 |
+
f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility"
|
75 |
+
)
|
76 |
+
elif "enc_q" in k and for_infer:
|
77 |
+
continue
|
78 |
+
else:
|
79 |
+
logger.error(f"{k} is not in the checkpoint {checkpoint_path}")
|
80 |
+
|
81 |
+
new_state_dict[k] = v
|
82 |
+
|
83 |
+
if hasattr(model, "module"):
|
84 |
+
model.module.load_state_dict(new_state_dict, strict=False)
|
85 |
+
else:
|
86 |
+
model.load_state_dict(new_state_dict, strict=False)
|
87 |
+
|
88 |
+
logger.info(f"Loaded '{checkpoint_path}' (iteration {iteration})")
|
89 |
+
|
90 |
+
return model, optimizer, learning_rate, iteration
|
91 |
+
|
92 |
+
|
93 |
+
def save_checkpoint(
|
94 |
+
model: torch.nn.Module,
|
95 |
+
optimizer: Union[torch.optim.Optimizer, torch.optim.AdamW],
|
96 |
+
learning_rate: float,
|
97 |
+
iteration: int,
|
98 |
+
checkpoint_path: Union[str, Path],
|
99 |
+
) -> None:
|
100 |
+
"""
|
101 |
+
モデルとオプティマイザーの状態を指定されたパスに保存する。
|
102 |
+
|
103 |
+
Args:
|
104 |
+
model (torch.nn.Module): 保存するモデル
|
105 |
+
optimizer (Union[torch.optim.Optimizer, torch.optim.AdamW]): 保存するオプティマイザー
|
106 |
+
learning_rate (float): 学習率
|
107 |
+
iteration (int): イテレーション回数
|
108 |
+
checkpoint_path (Union[str, Path]): 保存先のパス
|
109 |
+
"""
|
110 |
+
logger.info(
|
111 |
+
f"Saving model and optimizer state at iteration {iteration} to {checkpoint_path}"
|
112 |
+
)
|
113 |
+
if hasattr(model, "module"):
|
114 |
+
state_dict = model.module.state_dict()
|
115 |
+
else:
|
116 |
+
state_dict = model.state_dict()
|
117 |
+
torch.save(
|
118 |
+
{
|
119 |
+
"model": state_dict,
|
120 |
+
"iteration": iteration,
|
121 |
+
"optimizer": optimizer.state_dict(),
|
122 |
+
"learning_rate": learning_rate,
|
123 |
+
},
|
124 |
+
checkpoint_path,
|
125 |
+
)
|
126 |
+
|
127 |
+
|
128 |
+
def clean_checkpoints(
|
129 |
+
model_dir_path: Union[str, Path] = "logs/44k/",
|
130 |
+
n_ckpts_to_keep: int = 2,
|
131 |
+
sort_by_time: bool = True,
|
132 |
+
) -> None:
|
133 |
+
"""
|
134 |
+
指定されたディレクトリから古いチェックポイントを削除して空き容量を確保する
|
135 |
+
|
136 |
+
Args:
|
137 |
+
model_dir_path (Union[str, Path]): モデルが保存されているディレクトリのパス
|
138 |
+
n_ckpts_to_keep (int): 保持するチェックポイントの数(G_0.pth と D_0.pth を除く)
|
139 |
+
sort_by_time (bool): True の場合、時間順に削除。False の場合、名前順に削除
|
140 |
+
"""
|
141 |
+
|
142 |
+
ckpts_files = [
|
143 |
+
f
|
144 |
+
for f in os.listdir(model_dir_path)
|
145 |
+
if os.path.isfile(os.path.join(model_dir_path, f))
|
146 |
+
]
|
147 |
+
|
148 |
+
def name_key(_f: str) -> int:
|
149 |
+
return int(re.compile("._(\\d+)\\.pth").match(_f).group(1)) # type: ignore
|
150 |
+
|
151 |
+
def time_key(_f: str) -> float:
|
152 |
+
return os.path.getmtime(os.path.join(model_dir_path, _f))
|
153 |
+
|
154 |
+
sort_key = time_key if sort_by_time else name_key
|
155 |
+
|
156 |
+
def x_sorted(_x: str) -> list[str]:
|
157 |
+
return sorted(
|
158 |
+
[f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
|
159 |
+
key=sort_key,
|
160 |
+
)
|
161 |
+
|
162 |
+
to_del = [
|
163 |
+
os.path.join(model_dir_path, fn)
|
164 |
+
for fn in (
|
165 |
+
x_sorted("G_")[:-n_ckpts_to_keep]
|
166 |
+
+ x_sorted("D_")[:-n_ckpts_to_keep]
|
167 |
+
+ x_sorted("WD_")[:-n_ckpts_to_keep]
|
168 |
+
+ x_sorted("DUR_")[:-n_ckpts_to_keep]
|
169 |
+
)
|
170 |
+
]
|
171 |
+
|
172 |
+
def del_info(fn: str) -> None:
|
173 |
+
return logger.info(f"Free up space by deleting ckpt {fn}")
|
174 |
+
|
175 |
+
def del_routine(x: str) -> list[Any]:
|
176 |
+
return [os.remove(x), del_info(x)]
|
177 |
+
|
178 |
+
[del_routine(fn) for fn in to_del]
|
179 |
+
|
180 |
+
|
181 |
+
def get_latest_checkpoint_path(
|
182 |
+
model_dir_path: Union[str, Path], regex: str = "G_*.pth"
|
183 |
+
) -> str:
|
184 |
+
"""
|
185 |
+
指定されたディレクトリから最新のチェックポイントのパスを取得する
|
186 |
+
|
187 |
+
Args:
|
188 |
+
model_dir_path (Union[str, Path]): モデルが保存されているディレクトリのパス
|
189 |
+
regex (str): チェックポイントのファイル名の正規表現
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
str: 最新のチェックポイントのパス
|
193 |
+
"""
|
194 |
+
|
195 |
+
f_list = glob.glob(os.path.join(str(model_dir_path), regex))
|
196 |
+
f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
|
197 |
+
try:
|
198 |
+
x = f_list[-1]
|
199 |
+
except IndexError:
|
200 |
+
raise ValueError(f"No checkpoint found in {model_dir_path} with regex {regex}")
|
201 |
+
|
202 |
+
return x
|
style_bert_vits2/models/utils/safetensors.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Any, Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from safetensors import safe_open
|
6 |
+
from safetensors.torch import save_file
|
7 |
+
|
8 |
+
from style_bert_vits2.logging import logger
|
9 |
+
|
10 |
+
|
11 |
+
def load_safetensors(
|
12 |
+
checkpoint_path: Union[str, Path],
|
13 |
+
model: torch.nn.Module,
|
14 |
+
for_infer: bool = False,
|
15 |
+
) -> tuple[torch.nn.Module, Optional[int]]:
|
16 |
+
"""
|
17 |
+
指定されたパスから safetensors モデルを読み込み、モデルとイテレーションを返す。
|
18 |
+
|
19 |
+
Args:
|
20 |
+
checkpoint_path (Union[str, Path]): モデルのチェックポイントファイルのパス
|
21 |
+
model (torch.nn.Module): 読み込む対象のモデル
|
22 |
+
for_infer (bool): 推論用に読み込むかどうかのフラグ
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
tuple[torch.nn.Module, Optional[int]]: 読み込まれたモデルとイテレーション回数(存在する場合)
|
26 |
+
"""
|
27 |
+
|
28 |
+
tensors: dict[str, Any] = {}
|
29 |
+
iteration: Optional[int] = None
|
30 |
+
with safe_open(str(checkpoint_path), framework="pt", device="cpu") as f: # type: ignore
|
31 |
+
for key in f.keys():
|
32 |
+
if key == "iteration":
|
33 |
+
iteration = f.get_tensor(key).item()
|
34 |
+
tensors[key] = f.get_tensor(key)
|
35 |
+
if hasattr(model, "module"):
|
36 |
+
result = model.module.load_state_dict(tensors, strict=False)
|
37 |
+
else:
|
38 |
+
result = model.load_state_dict(tensors, strict=False)
|
39 |
+
for key in result.missing_keys:
|
40 |
+
if key.startswith("enc_q") and for_infer:
|
41 |
+
continue
|
42 |
+
logger.warning(f"Missing key: {key}")
|
43 |
+
for key in result.unexpected_keys:
|
44 |
+
if key == "iteration":
|
45 |
+
continue
|
46 |
+
logger.warning(f"Unexpected key: {key}")
|
47 |
+
if iteration is None:
|
48 |
+
logger.info(f"Loaded '{checkpoint_path}'")
|
49 |
+
else:
|
50 |
+
logger.info(f"Loaded '{checkpoint_path}' (iteration {iteration})")
|
51 |
+
|
52 |
+
return model, iteration
|
53 |
+
|
54 |
+
|
55 |
+
def save_safetensors(
|
56 |
+
model: torch.nn.Module,
|
57 |
+
iteration: int,
|
58 |
+
checkpoint_path: Union[str, Path],
|
59 |
+
is_half: bool = False,
|
60 |
+
for_infer: bool = False,
|
61 |
+
) -> None:
|
62 |
+
"""
|
63 |
+
モデルを safetensors 形式で保存する。
|
64 |
+
|
65 |
+
Args:
|
66 |
+
model (torch.nn.Module): 保存するモデル
|
67 |
+
iteration (int): イテレーション回数
|
68 |
+
checkpoint_path (Union[str, Path]): 保存先のパス
|
69 |
+
is_half (bool): モデルを半精度で保存するかどうかのフラグ
|
70 |
+
for_infer (bool): 推論用に保存するかどうかのフラグ
|
71 |
+
"""
|
72 |
+
|
73 |
+
if hasattr(model, "module"):
|
74 |
+
state_dict = model.module.state_dict()
|
75 |
+
else:
|
76 |
+
state_dict = model.state_dict()
|
77 |
+
keys = []
|
78 |
+
for k in state_dict:
|
79 |
+
if "enc_q" in k and for_infer:
|
80 |
+
continue
|
81 |
+
keys.append(k)
|
82 |
+
|
83 |
+
new_dict = (
|
84 |
+
{k: state_dict[k].half() for k in keys}
|
85 |
+
if is_half
|
86 |
+
else {k: state_dict[k] for k in keys}
|
87 |
+
)
|
88 |
+
new_dict["iteration"] = torch.LongTensor([iteration])
|
89 |
+
logger.info(f"Saved safetensors to {checkpoint_path}")
|
90 |
+
|
91 |
+
save_file(new_dict, checkpoint_path)
|
style_bert_vits2/nlp/__init__.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TYPE_CHECKING, Optional
|
2 |
+
|
3 |
+
from style_bert_vits2.constants import Languages
|
4 |
+
from style_bert_vits2.nlp.symbols import (
|
5 |
+
LANGUAGE_ID_MAP,
|
6 |
+
LANGUAGE_TONE_START_MAP,
|
7 |
+
SYMBOLS,
|
8 |
+
)
|
9 |
+
|
10 |
+
|
11 |
+
# __init__.py は配下のモジュールをインポートした時点で実行される
|
12 |
+
# PyTorch のインポートは重いので、型チェック時以外はインポートしない
|
13 |
+
if TYPE_CHECKING:
|
14 |
+
import torch
|
15 |
+
|
16 |
+
|
17 |
+
__symbol_to_id = {s: i for i, s in enumerate(SYMBOLS)}
|
18 |
+
|
19 |
+
|
20 |
+
def extract_bert_feature(
|
21 |
+
text: str,
|
22 |
+
word2ph: list[int],
|
23 |
+
language: Languages,
|
24 |
+
device: str,
|
25 |
+
assist_text: Optional[str] = None,
|
26 |
+
assist_text_weight: float = 0.7,
|
27 |
+
) -> "torch.Tensor":
|
28 |
+
"""
|
29 |
+
テキストから BERT の特徴量を抽出する
|
30 |
+
|
31 |
+
Args:
|
32 |
+
text (str): テキスト
|
33 |
+
word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト
|
34 |
+
language (Languages): テキストの言語
|
35 |
+
device (str): 推論に利用するデバイス
|
36 |
+
assist_text (Optional[str], optional): 補助テキスト (デフォルト: None)
|
37 |
+
assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7)
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
torch.Tensor: BERT の特徴量
|
41 |
+
"""
|
42 |
+
|
43 |
+
if language == Languages.JP:
|
44 |
+
from style_bert_vits2.nlp.japanese.bert_feature import extract_bert_feature
|
45 |
+
elif language == Languages.EN:
|
46 |
+
from style_bert_vits2.nlp.english.bert_feature import extract_bert_feature
|
47 |
+
elif language == Languages.ZH:
|
48 |
+
from style_bert_vits2.nlp.chinese.bert_feature import extract_bert_feature
|
49 |
+
else:
|
50 |
+
raise ValueError(f"Language {language} not supported")
|
51 |
+
|
52 |
+
return extract_bert_feature(text, word2ph, device, assist_text, assist_text_weight)
|
53 |
+
|
54 |
+
|
55 |
+
def clean_text(
|
56 |
+
text: str,
|
57 |
+
language: Languages,
|
58 |
+
use_jp_extra: bool = True,
|
59 |
+
raise_yomi_error: bool = False,
|
60 |
+
) -> tuple[str, list[str], list[int], list[int]]:
|
61 |
+
"""
|
62 |
+
テキストをクリーニングし、音素に変換する
|
63 |
+
|
64 |
+
Args:
|
65 |
+
text (str): クリーニングするテキスト
|
66 |
+
language (Languages): テキストの言語
|
67 |
+
use_jp_extra (bool, optional): テキストが日本語の場合に JP-Extra モデルを利用するかどうか。Defaults to True.
|
68 |
+
raise_yomi_error (bool, optional): False の場合、読めない文字が消えたような扱いとして処理される。Defaults to False.
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
tuple[str, list[str], list[int], list[int]]: クリーニングされたテキストと、音素・アクセント・元のテキストの各文字に音素が何個割り当てられるかのリスト
|
72 |
+
"""
|
73 |
+
|
74 |
+
# Changed to import inside if condition to avoid unnecessary import
|
75 |
+
if language == Languages.JP:
|
76 |
+
from style_bert_vits2.nlp.japanese.g2p import g2p
|
77 |
+
from style_bert_vits2.nlp.japanese.normalizer import normalize_text
|
78 |
+
|
79 |
+
norm_text = normalize_text(text)
|
80 |
+
phones, tones, word2ph = g2p(norm_text, use_jp_extra, raise_yomi_error)
|
81 |
+
elif language == Languages.EN:
|
82 |
+
from style_bert_vits2.nlp.english.g2p import g2p
|
83 |
+
from style_bert_vits2.nlp.english.normalizer import normalize_text
|
84 |
+
|
85 |
+
norm_text = normalize_text(text)
|
86 |
+
phones, tones, word2ph = g2p(norm_text)
|
87 |
+
elif language == Languages.ZH:
|
88 |
+
from style_bert_vits2.nlp.chinese.g2p import g2p
|
89 |
+
from style_bert_vits2.nlp.chinese.normalizer import normalize_text
|
90 |
+
|
91 |
+
norm_text = normalize_text(text)
|
92 |
+
phones, tones, word2ph = g2p(norm_text)
|
93 |
+
else:
|
94 |
+
raise ValueError(f"Language {language} not supported")
|
95 |
+
|
96 |
+
return norm_text, phones, tones, word2ph
|
97 |
+
|
98 |
+
|
99 |
+
def cleaned_text_to_sequence(
|
100 |
+
cleaned_phones: list[str], tones: list[int], language: Languages
|
101 |
+
) -> tuple[list[int], list[int], list[int]]:
|
102 |
+
"""
|
103 |
+
音素リスト・アクセントリスト・言語を、テキスト内の対応する ID に変換する
|
104 |
+
|
105 |
+
Args:
|
106 |
+
cleaned_phones (list[str]): clean_text() でクリーニングされた音素のリスト
|
107 |
+
tones (list[int]): 各音素のアクセント
|
108 |
+
language (Languages): テキストの言語
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
tuple[list[int], list[int], list[int]]: List of integers corresponding to the symbols in the text
|
112 |
+
"""
|
113 |
+
|
114 |
+
phones = [__symbol_to_id[symbol] for symbol in cleaned_phones]
|
115 |
+
tone_start = LANGUAGE_TONE_START_MAP[language]
|
116 |
+
tones = [i + tone_start for i in tones]
|
117 |
+
lang_id = LANGUAGE_ID_MAP[language]
|
118 |
+
lang_ids = [lang_id for i in phones]
|
119 |
+
|
120 |
+
return phones, tones, lang_ids
|