CaiRou-Huang committed (verified)
Commit 448c16f · 1 Parent(s): 93c3567

Upload 5 files

common/constants.py ADDED
@@ -0,0 +1,20 @@
+ import enum
+
+ DEFAULT_STYLE: str = "Neutral"
+ DEFAULT_STYLE_WEIGHT: float = 5.0
+
+
+ class Languages(str, enum.Enum):
+     JP = "JP"
+     EN = "EN"
+     ZH = "ZH"
+
+
+ DEFAULT_SDP_RATIO: float = 0.2
+ DEFAULT_NOISE: float = 0.6
+ DEFAULT_NOISEW: float = 0.8
+ DEFAULT_LENGTH: float = 1.0
+ DEFAULT_LINE_SPLIT: bool = True
+ DEFAULT_SPLIT_INTERVAL: float = 0.5
+ DEFAULT_ASSIST_TEXT_WEIGHT: float = 0.7
+ DEFAULT_ASSIST_TEXT_WEIGHT: float = 1.0  # NOTE: redefines the line above, so 1.0 is the effective value
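For context, a minimal sketch of how these defaults and the Languages enum are typically consumed; the build_request helper below is hypothetical and not part of this commit:

# Hypothetical usage sketch (build_request is an assumed caller, not part of this commit).
from common.constants import DEFAULT_SDP_RATIO, DEFAULT_STYLE, DEFAULT_STYLE_WEIGHT, Languages

def build_request(text: str, language: Languages = Languages.JP) -> dict:
    # Collect synthesis parameters, falling back to the shared defaults.
    return {
        "text": text,
        "language": language.value,
        "sdp_ratio": DEFAULT_SDP_RATIO,
        "style": DEFAULT_STYLE,
        "style_weight": DEFAULT_STYLE_WEIGHT,
    }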
common/log.py ADDED
@@ -0,0 +1,16 @@
+ """
+ Logger wrapper
+ """
+ from loguru import logger
+
+ from .stdout_wrapper import SAFE_STDOUT
+
+ # Remove all default handlers
+ logger.remove()
+
+ # Custom format, attached to standard output
+ log_format = (
+     "<g>{time:MM-DD HH:mm:ss}</g> |<lvl>{level:^8}</lvl>| {file}:{line} | {message}"
+ )
+
+ logger.add(SAFE_STDOUT, format=log_format, backtrace=True, diagnose=True)
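The other modules in this commit import this pre-configured logger instead of setting up loguru themselves; a minimal usage sketch:

# Usage sketch (the messages are illustrative, not part of this commit).
from common.log import logger

logger.info("Loading model...")
logger.warning("style_vectors.npy not found, using default styles")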
common/stdout_wrapper.py ADDED
@@ -0,0 +1,34 @@
+ import sys
+ import tempfile
+
+
+ class StdoutWrapper:
+     def __init__(self):
+         self.temp_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)
+         self.original_stdout = sys.stdout
+
+     def write(self, message: str):
+         self.temp_file.write(message)
+         self.temp_file.flush()
+         print(message, end="", file=self.original_stdout)
+
+     def flush(self):
+         self.temp_file.flush()
+
+     def read(self):
+         self.temp_file.seek(0)
+         return self.temp_file.read()
+
+     def close(self):
+         self.temp_file.close()
+
+     def fileno(self):
+         return self.temp_file.fileno()
+
+
+ try:
+     import google.colab
+
+     SAFE_STDOUT = StdoutWrapper()
+ except ImportError:
+     SAFE_STDOUT = sys.stdout
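This wrapper appears to exist because, on Google Colab, sys.stdout is not a regular file object, while code elsewhere in this commit passes SAFE_STDOUT straight to subprocess.run, which needs a real file descriptor via fileno(). A minimal sketch of that pattern under this assumption:

# Usage sketch (the command is illustrative, not part of this commit).
import subprocess
from common.stdout_wrapper import SAFE_STDOUT

# SAFE_STDOUT provides fileno(), so subprocess can write to it like a real file.
subprocess.run(["python", "--version"], stdout=SAFE_STDOUT)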
common/subprocess_utils.py ADDED
@@ -0,0 +1,32 @@
+ import subprocess
+ import sys
+
+ from .log import logger
+ from .stdout_wrapper import SAFE_STDOUT
+
+ python = sys.executable
+
+
+ def run_script_with_log(cmd: list[str], ignore_warning=False) -> tuple[bool, str]:
+     logger.info(f"Running: {' '.join(cmd)}")
+     result = subprocess.run(
+         [python] + cmd,
+         stdout=SAFE_STDOUT,  # type: ignore
+         stderr=subprocess.PIPE,
+         text=True,
+     )
+     if result.returncode != 0:
+         logger.error(f"Error: {' '.join(cmd)}\n{result.stderr}")
+         return False, result.stderr
+     elif result.stderr and not ignore_warning:
+         logger.warning(f"Warning: {' '.join(cmd)}\n{result.stderr}")
+         return True, result.stderr
+     logger.success(f"Success: {' '.join(cmd)}")
+     return True, ""
+
+
+ def second_elem_of(original_function):
+     def inner_function(*args, **kwargs):
+         return original_function(*args, **kwargs)[1]
+
+     return inner_function
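run_script_with_log re-runs a script with the current Python interpreter and returns (success, stderr text); second_elem_of strips the boolean so a caller that only wants the message can reuse the same function. A minimal sketch, where train.py and its arguments are hypothetical:

# Usage sketch (train.py and its arguments are assumed, not part of this commit).
from common.subprocess_utils import run_script_with_log, second_elem_of

success, message = run_script_with_log(["train.py", "--config", "config.json"])

# Wrapped variant that returns only the message, e.g. for a UI callback expecting one string.
run_and_get_message = second_elem_of(run_script_with_log)
message_only = run_and_get_message(["train.py", "--config", "config.json"])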
common/tts_model.py ADDED
@@ -0,0 +1,250 @@
+ import numpy as np
+ import gradio as gr
+ import torch
+ import os
+ import warnings
+ from gradio.processing_utils import convert_to_16_bit_wav
+ from typing import Dict, List, Optional, Union
+
+ import utils
+ from infer import get_net_g, infer
+ from models import SynthesizerTrn
+ from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra
+
+ from .log import logger
+ from .constants import (
+     DEFAULT_ASSIST_TEXT_WEIGHT,
+     DEFAULT_LENGTH,
+     DEFAULT_LINE_SPLIT,
+     DEFAULT_NOISE,
+     DEFAULT_NOISEW,
+     DEFAULT_SDP_RATIO,
+     DEFAULT_SPLIT_INTERVAL,
+     DEFAULT_STYLE,
+     DEFAULT_STYLE_WEIGHT,
+ )
+
+
+ class Model:
+     def __init__(
+         self, model_path: str, config_path: str, style_vec_path: str, device: str
+     ):
+         self.model_path: str = model_path
+         self.config_path: str = config_path
+         self.device: str = device
+         self.style_vec_path: str = style_vec_path
+         self.hps: utils.HParams = utils.get_hparams_from_file(self.config_path)
+         self.spk2id: Dict[str, int] = self.hps.data.spk2id
+         self.id2spk: Dict[int, str] = {v: k for k, v in self.spk2id.items()}
+
+         self.num_styles: int = self.hps.data.num_styles
+         if hasattr(self.hps.data, "style2id"):
+             self.style2id: Dict[str, int] = self.hps.data.style2id
+         else:
+             self.style2id: Dict[str, int] = {str(i): i for i in range(self.num_styles)}
+         if len(self.style2id) != self.num_styles:
+             raise ValueError(
+                 f"Number of styles ({self.num_styles}) does not match the number of style2id ({len(self.style2id)})"
+             )
+
+         self.style_vectors: np.ndarray = np.load(self.style_vec_path)
+         if self.style_vectors.shape[0] != self.num_styles:
+             raise ValueError(
+                 f"The number of styles ({self.num_styles}) does not match the number of style vectors ({self.style_vectors.shape[0]})"
+             )
+
+         self.net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None
+
+     def load_net_g(self):
+         self.net_g = get_net_g(
+             model_path=self.model_path,
+             version=self.hps.version,
+             device=self.device,
+             hps=self.hps,
+         )
+
+     def get_style_vector(self, style_id: int, weight: float = 1.0) -> np.ndarray:
+         mean = self.style_vectors[0]
+         style_vec = self.style_vectors[style_id]
+         style_vec = mean + (style_vec - mean) * weight
+         return style_vec
+
+     def get_style_vector_from_audio(
+         self, audio_path: str, weight: float = 1.0
+     ) -> np.ndarray:
+         from style_gen import get_style_vector
+
+         xvec = get_style_vector(audio_path)
+         mean = self.style_vectors[0]
+         xvec = mean + (xvec - mean) * weight
+         return xvec
+
+     def infer(
+         self,
+         text: str,
+         language: str = "JP",
+         sid: int = 0,
+         reference_audio_path: Optional[str] = None,
+         sdp_ratio: float = DEFAULT_SDP_RATIO,
+         noise: float = DEFAULT_NOISE,
+         noisew: float = DEFAULT_NOISEW,
+         length: float = DEFAULT_LENGTH,
+         line_split: bool = DEFAULT_LINE_SPLIT,
+         split_interval: float = DEFAULT_SPLIT_INTERVAL,
+         assist_text: Optional[str] = None,
+         assist_text_weight: float = DEFAULT_ASSIST_TEXT_WEIGHT,
+         use_assist_text: bool = False,
+         style: str = DEFAULT_STYLE,
+         style_weight: float = DEFAULT_STYLE_WEIGHT,
+         given_tone: Optional[list[int]] = None,
+     ) -> tuple[int, np.ndarray]:
+         logger.info(f"Start generating audio data from text:\n{text}")
+         if language != "JP" and self.hps.version.endswith("JP-Extra"):
+             raise ValueError(
+                 "The model is trained with JP-Extra, but the language is not JP"
+             )
+         if reference_audio_path == "":
+             reference_audio_path = None
+         if assist_text == "" or not use_assist_text:
+             assist_text = None
+
+         if self.net_g is None:
+             self.load_net_g()
+         if reference_audio_path is None:
+             style_id = self.style2id[style]
+             style_vector = self.get_style_vector(style_id, style_weight)
+         else:
+             style_vector = self.get_style_vector_from_audio(
+                 reference_audio_path, style_weight
+             )
+         if not line_split:
+             with torch.no_grad():
+                 audio = infer(
+                     text=text,
+                     sdp_ratio=sdp_ratio,
+                     noise_scale=noise,
+                     noise_scale_w=noisew,
+                     length_scale=length,
+                     sid=sid,
+                     language=language,
+                     hps=self.hps,
+                     net_g=self.net_g,
+                     device=self.device,
+                     assist_text=assist_text,
+                     assist_text_weight=assist_text_weight,
+                     style_vec=style_vector,
+                     given_tone=given_tone,
+                 )
+         else:
+             texts = text.split("\n")
+             texts = [t for t in texts if t != ""]
+             audios = []
+             with torch.no_grad():
+                 for i, t in enumerate(texts):
+                     audios.append(
+                         infer(
+                             text=t,
+                             sdp_ratio=sdp_ratio,
+                             noise_scale=noise,
+                             noise_scale_w=noisew,
+                             length_scale=length,
+                             sid=sid,
+                             language=language,
+                             hps=self.hps,
+                             net_g=self.net_g,
+                             device=self.device,
+                             assist_text=assist_text,
+                             assist_text_weight=assist_text_weight,
+                             style_vec=style_vector,
+                         )
+                     )
+                     if i != len(texts) - 1:
+                         audios.append(np.zeros(int(44100 * split_interval)))  # NOTE: silence length assumes a hard-coded 44100 Hz, not self.hps.data.sampling_rate
+             audio = np.concatenate(audios)
+         with warnings.catch_warnings():
+             warnings.simplefilter("ignore")
+             audio = convert_to_16_bit_wav(audio)
+         logger.info("Audio data generated successfully")
+         return (self.hps.data.sampling_rate, audio)
+
+
+ class ModelHolder:
+     def __init__(self, root_dir: str, device: str):
+         self.root_dir: str = root_dir
+         self.device: str = device
+         self.model_files_dict: Dict[str, List[str]] = {}
+         self.current_model: Optional[Model] = None
+         self.model_names: List[str] = []
+         self.models: List[Model] = []
+         self.refresh()
+
+     def refresh(self):
+         self.model_files_dict = {}
+         self.model_names = []
+         self.current_model = None
+         model_dirs = [
+             d
+             for d in os.listdir(self.root_dir)
+             if os.path.isdir(os.path.join(self.root_dir, d))
+         ]
+         for model_name in model_dirs:
+             model_dir = os.path.join(self.root_dir, model_name)
+             model_files = [
+                 os.path.join(model_dir, f)
+                 for f in os.listdir(model_dir)
+                 if f.endswith(".pth") or f.endswith(".pt") or f.endswith(".safetensors")
+             ]
+             if len(model_files) == 0:
+                 logger.warning(
+                     f"No model files found in {self.root_dir}/{model_name}, so skip it"
+                 )
+                 continue
+             self.model_files_dict[model_name] = model_files
+             self.model_names.append(model_name)
+
+     def load_model_gr(
+         self, model_name: str, model_path: str
+     ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]:
+         if model_name not in self.model_files_dict:
+             raise ValueError(f"Model `{model_name}` is not found")
+         if model_path not in self.model_files_dict[model_name]:
+             raise ValueError(f"Model file `{model_path}` is not found")
+         if (
+             self.current_model is not None
+             and self.current_model.model_path == model_path
+         ):
+             # Already loaded
+             speakers = list(self.current_model.spk2id.keys())
+             styles = list(self.current_model.style2id.keys())
+             return (
+                 gr.Dropdown(choices=styles, value=styles[0]),
+                 gr.Button(interactive=True, value="音声合成"),
+                 gr.Dropdown(choices=speakers, value=speakers[0]),
+             )
+         self.current_model = Model(
+             model_path=model_path,
+             config_path=os.path.join(self.root_dir, model_name, "config.json"),
+             style_vec_path=os.path.join(self.root_dir, model_name, "style_vectors.npy"),
+             device=self.device,
+         )
+         speakers = list(self.current_model.spk2id.keys())
+         styles = list(self.current_model.style2id.keys())
+         return (
+             gr.Dropdown(choices=styles, value=styles[0]),
+             gr.Button(interactive=True, value="音声合成"),
+             gr.Dropdown(choices=speakers, value=speakers[0]),
+         )
+
+     def update_model_files_gr(self, model_name: str) -> gr.Dropdown:
+         model_files = self.model_files_dict[model_name]
+         return gr.Dropdown(choices=model_files, value=model_files[0])
+
+     def update_model_names_gr(self) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]:
+         self.refresh()
+         initial_model_name = self.model_names[0]
+         initial_model_files = self.model_files_dict[initial_model_name]
+         return (
+             gr.Dropdown(choices=self.model_names, value=initial_model_name),
+             gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]),
+             gr.Button(interactive=False),  # For tts_button
+         )
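Putting the pieces together, a minimal end-to-end sketch; the model_assets directory, device string, and input text are assumptions, not part of this commit:

# Usage sketch (paths, device, and text are assumed, not part of this commit).
from common.tts_model import ModelHolder

holder = ModelHolder(root_dir="model_assets", device="cpu")
model_name = holder.model_names[0]
model_path = holder.model_files_dict[model_name][0]

# load_model_gr also returns Gradio components for the webui; here only its side effect
# of setting holder.current_model is needed.
holder.load_model_gr(model_name, model_path)
sr, audio = holder.current_model.infer("こんにちは。これはテスト音声です。")

Note that get_style_vector blends the mean style vector (index 0) with the chosen style as mean + (style - mean) * weight, so style_weight scales how far the synthesized voice departs from the neutral style.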