Vijish committed
Commit dc3db46
1 Parent(s): 4d2fd1c

Update handler.py

Files changed (1)
  1. handler.py +169 -174
handler.py CHANGED
@@ -37,201 +37,196 @@ limitation = os.getenv("SYSTEM") == "spaces"
 
 config = Config()
 
-# Edge TTS
-tts_voice_list = edge_tts.list_voices()
-tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"] # Specific voices
+class EndpointHandler:
+    def __init__(self, model_dir=None):
+        self.model_dir = model_dir
+        self.hubert_model = self.load_hubert()
+        self.rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
+        self.voice_mapping = {
+            "Mongolian Male": "mn-MN-BataaNeural",
+            "Mongolian Female": "mn-MN-YesuiNeural"
+        }
+        # Edge TTS
+        self.tts_voice_list = edge_tts.list_voices()
+        self.tts_voices = ["mn-MN-BataaNeural", "mn-MN-YesuiNeural"] # Specific voices
 
-# RVC models
-model_root = "weights"
-models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
-models.sort()
+        # RVC models
+        self.model_root = "weights"
+        self.models = [d for d in os.listdir(self.model_root) if os.path.isdir(f"{self.model_root}/{d}")]
+        self.models.sort()
 
-def get_unique_filename(extension):
-    return f"{uuid.uuid4()}.{extension}"
+    def get_unique_filename(self, extension):
+        return f"{uuid.uuid4()}.{extension}"
 
-def model_data(model_name):
-    # global n_spk, tgt_sr, net_g, vc, cpt, version, index_file
-    pth_path = [
-        f"{model_root}/{model_name}/{f}"
-        for f in os.listdir(f"{model_root}/{model_name}")
-        if f.endswith(".pth")
-    ][0]
-    print(f"Loading {pth_path}")
-    cpt = torch.load(pth_path, map_location="cpu")
-    tgt_sr = cpt["config"][-1]
-    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
-    if_f0 = cpt.get("f0", 1)
-    version = cpt.get("version", "v1")
-    if version == "v1":
-        if if_f0 == 1:
-            net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
+    def model_data(self, model_name):
+        # global n_spk, tgt_sr, net_g, vc, cpt, version, index_file
+        pth_path = [
+            f"{self.model_root}/{model_name}/{f}"
+            for f in os.listdir(f"{self.model_root}/{model_name}")
+            if f.endswith(".pth")
+        ][0]
+        print(f"Loading {pth_path}")
+        cpt = torch.load(pth_path, map_location="cpu")
+        tgt_sr = cpt["config"][-1]
+        cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
+        if_f0 = cpt.get("f0", 1)
+        version = cpt.get("version", "v1")
+        if version == "v1":
+            if if_f0 == 1:
+                net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
+            else:
+                net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
+        elif version == "v2":
+            if if_f0 == 1:
+                net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half)
+            else:
+                net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
+        else:
+            raise ValueError("Unknown version")
+        del net_g.enc_q
+        net_g.load_state_dict(cpt["weight"], strict=False)
+        print("Model loaded")
+        net_g.eval().to(config.device)
+        if config.is_half:
+            net_g = net_g.half()
         else:
-            net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
-    elif version == "v2":
-        if if_f0 == 1:
-            net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half)
+            net_g = net_g.float()
+        vc = VC(tgt_sr, config)
+        # n_spk = cpt["config"][-3]
+
+        index_files = [
+            f"{self.model_root}/{model_name}/{f}"
+            for f in os.listdir(f"{self.model_root}/{model_name}")
+            if f.endswith(".index")
+        ]
+        if len(index_files) == 0:
+            print("No index file found")
+            index_file = ""
         else:
-            net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
-    else:
-        raise ValueError("Unknown version")
-    del net_g.enc_q
-    net_g.load_state_dict(cpt["weight"], strict=False)
-    print("Model loaded")
-    net_g.eval().to(config.device)
-    if config.is_half:
-        net_g = net_g.half()
-    else:
-        net_g = net_g.float()
-    vc = VC(tgt_sr, config)
-    # n_spk = cpt["config"][-3]
+            index_file = index_files[0]
+            print(f"Index file found: {index_file}")
 
-    index_files = [
-        f"{model_root}/{model_name}/{f}"
-        for f in os.listdir(f"{model_root}/{model_name}")
-        if f.endswith(".index")
-    ]
-    if len(index_files) == 0:
-        print("No index file found")
-        index_file = ""
-    else:
-        index_file = index_files[0]
-        print(f"Index file found: {index_file}")
+        return tgt_sr, net_g, vc, version, index_file, if_f0
 
-    return tgt_sr, net_g, vc, version, index_file, if_f0
+    def load_hubert(self):
+        # global hubert_model
+        models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
+            ["hubert_base.pt"],
+            suffix="",
+        )
+        hubert_model = models[0]
+        hubert_model = hubert_model.to(config.device)
+        if config.is_half:
+            hubert_model = hubert_model.half()
+        else:
+            hubert_model = hubert_model.float()
+        return hubert_model.eval()
 
+    def get_model_names(self):
+        return [d for d in os.listdir(self.model_root) if os.path.isdir(f"{self.model_root}/{d}")]
 
-def load_hubert():
-    # global hubert_model
-    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
-        ["hubert_base.pt"],
-        suffix="",
-    )
-    hubert_model = models[0]
-    hubert_model = hubert_model.to(config.device)
-    if config.is_half:
-        hubert_model = hubert_model.half()
-    else:
-        hubert_model = hubert_model.float()
-    return hubert_model.eval()
+    def tts(self, model_name, tts_text, tts_voice, index_rate, use_uploaded_voice, uploaded_voice):
+        # Default values for parameters used in EdgeTTS
+        speed = 0 # Default speech speed
+        f0_up_key = 0 # Default pitch adjustment
+        f0_method = "rmvpe" # Default pitch extraction method
+        protect = 0.33 # Default protect value
+        filter_radius = 3
+        resample_sr = 0
+        rms_mix_rate = 0.25
+        edge_time = 0 # Initialize edge_time
 
-def get_model_names():
-    model_root = "weights" # Assuming this is where your models are stored
-    return [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")]
+        edge_output_filename = self.get_unique_filename("mp3")
 
-def tts(model_name, tts_text, tts_voice, index_rate, use_uploaded_voice, uploaded_voice):
-    # Default values for parameters used in EdgeTTS
-    speed = 0 # Default speech speed
-    f0_up_key = 0 # Default pitch adjustment
-    f0_method = "rmvpe" # Default pitch extraction method
-    protect = 0.33 # Default protect value
-    filter_radius = 3
-    resample_sr = 0
-    rms_mix_rate = 0.25
-    edge_time = 0 # Initialize edge_time
+        try:
+            if use_uploaded_voice:
+                if uploaded_voice is None:
+                    return "No voice file uploaded.", None, None
+
+                # Process the uploaded voice file
+                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
+                    tmp_file.write(uploaded_voice)
+                    uploaded_file_path = tmp_file.name
 
-    edge_output_filename = get_unique_filename("mp3")
+                audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
+            else:
+                # EdgeTTS processing
+                if limitation and len(tts_text) > 4000:
+                    return (
+                        f"Text characters should be at most 280 in this huggingface space, but got {len(tts_text)} characters.",
+                        None,
+                        None,
+                    )
+
+                # Invoke Edge TTS
+                t0 = time.time()
+                speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
+                edge_tts.Communicate(tts_text, tts_voice, rate=speed_str).save(edge_output_filename)
+                t1 = time.time()
+                edge_time = t1 - t0
 
-    try:
-        if use_uploaded_voice:
-            if uploaded_voice is None:
-                return "No voice file uploaded.", None, None
-
-            # Process the uploaded voice file
-            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
-                tmp_file.write(uploaded_voice)
-                uploaded_file_path = tmp_file.name
+                audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
 
-            audio, sr = librosa.load(uploaded_file_path, sr=16000, mono=True)
-        else:
-            # EdgeTTS processing
-            if limitation and len(tts_text) > 4000:
+            # Common processing after loading the audio
+            duration = len(audio) / sr
+            print(f"Audio duration: {duration}s")
+            if limitation and duration >= 20:
                 return (
-                    f"Text characters should be at most 280 in this huggingface space, but got {len(tts_text)} characters.",
+                    f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
                     None,
                     None,
                 )
 
-
-            # Invoke Edge TTS
-            t0 = time.time()
-            speed_str = f"+{speed}%" if speed >= 0 else f"{speed}%"
-            edge_tts.Communicate(tts_text, tts_voice, rate=speed_str).save(edge_output_filename)
-            t1 = time.time()
-            edge_time = t1 - t0
 
-            audio, sr = librosa.load(edge_output_filename, sr=16000, mono=True)
+            f0_up_key = int(f0_up_key)
+            tgt_sr, net_g, vc, version, index_file, if_f0 = self.model_data(model_name)
 
-        # Common processing after loading the audio
-        duration = len(audio) / sr
-        print(f"Audio duration: {duration}s")
-        if limitation and duration >= 20:
-            return (
-                f"Audio should be less than 20 seconds in this huggingface space, but got {duration}s.",
-                None,
+            # Setup for RMVPE or other pitch extraction methods
+            if f0_method == "rmvpe":
+                vc.model_rmvpe = self.rmvpe_model
+
+            # Perform voice conversion pipeline
+            times = [0, 0, 0]
+            audio_opt = vc.pipeline(
+                self.hubert_model,
+                net_g,
+                0,
+                audio,
+                edge_output_filename if not use_uploaded_voice else uploaded_file_path,
+                times,
+                f0_up_key,
+                f0_method,
+                index_file,
+                index_rate,
+                if_f0,
+                filter_radius,
+                tgt_sr,
+                resample_sr,
+                rms_mix_rate,
+                version,
+                protect,
                 None,
             )
 
-        f0_up_key = int(f0_up_key)
-        tgt_sr, net_g, vc, version, index_file, if_f0 = model_data(model_name)
-
-        # Setup for RMVPE or other pitch extraction methods
-        if f0_method == "rmvpe":
-            vc.model_rmvpe = rmvpe_model
-
-        # Perform voice conversion pipeline
-        times = [0, 0, 0]
-        audio_opt = vc.pipeline(
-            hubert_model,
-            net_g,
-            0,
-            audio,
-            edge_output_filename if not use_uploaded_voice else uploaded_file_path,
-            times,
-            f0_up_key,
-            f0_method,
-            index_file,
-            index_rate,
-            if_f0,
-            filter_radius,
-            tgt_sr,
-            resample_sr,
-            rms_mix_rate,
-            version,
-            protect,
-            None,
-        )
-
-        if tgt_sr != resample_sr and resample_sr >= 16000:
-            tgt_sr = resample_sr
-
-        info = f"Success. Time: tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
-        print(info)
-        return (
-            info,
-            edge_output_filename if not use_uploaded_voice else None,
-            (tgt_sr, audio_opt),
-            edge_output_filename
-        )
-
-    except EOFError:
-        info = "Output not valid. This may occur when input text and speaker do not match."
-        print(info)
-        return info, None, None
-    except Exception as e:
-        traceback_info = traceback.format_exc()
-        print(traceback_info)
-        return str(e), None, None
-
-voice_mapping = {
-    "Mongolian Male": "mn-MN-BataaNeural",
-    "Mongolian Female": "mn-MN-YesuiNeural"
-}
-
-hubert_model = load_hubert()
-rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
+            if tgt_sr != resample_sr and resample_sr >= 16000:
+                tgt_sr = resample_sr
+
+            info = f"Success. Time: tts: {edge_time}s, npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s"
+            print(info)
+            return (
+                info,
+                edge_output_filename if not use_uploaded_voice else None,
+                (tgt_sr, audio_opt),
+                edge_output_filename
+            )
 
-class EndpointHandler:
-    def __init__(self, model_dir=None):
-        self.model_dir = model_dir
+        except EOFError:
+            info = "Output not valid. This may occur when input text and speaker do not match."
+            print(info)
+            return info, None, None
+        except Exception as e:
+            traceback_info = traceback.format_exc()
+            print(traceback_info)
+            return str(e), None, None
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         try:
@@ -265,11 +260,11 @@ class EndpointHandler:
             use_uploaded_voice = json_data["use_uploaded_voice"]
             voice_upload_file = json_data.get("voice_upload_file", None)
 
-            edge_tts_voice = voice_mapping.get(selected_voice)
+            edge_tts_voice = self.voice_mapping.get(selected_voice)
             if not edge_tts_voice:
                 raise ValueError(f"Invalid voice '{selected_voice}'.")
 
-            info, edge_tts_output_path, tts_output_data, edge_output_file = tts(
+            info, edge_tts_output_path, tts_output_data, edge_output_file = self.tts(
                 model_name,
                 tts_text,
                 edge_tts_voice,
@@ -299,6 +294,6 @@ class EndpointHandler:
            raise ValueError("Invalid JSON structure.")

    def save_audio_data_to_file(self, audio_data, sample_rate=40000):
-        file_path = self.get_unique_filename('wav')
+        file_path = self.get_unique_filename('wav')
         wavfile.write(file_path, sample_rate, audio_data)
         return file_path
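
For reference, below is a minimal local smoke-test sketch of the refactored handler. It is illustrative only, not part of this commit: it assumes handler.py, the weights/ folder, hubert_base.pt and rmvpe.pt are available locally, and that EndpointHandler.__call__ accepts a plain dict. Only the "use_uploaded_voice" and "voice_upload_file" keys are visible in this diff; the remaining key names and values are assumptions.

# Hypothetical local smoke test for the refactored EndpointHandler (not part of this commit).
# Assumes handler.py and its model files (weights/, hubert_base.pt, rmvpe.pt) are present,
# and that __call__ accepts a plain dict with the keys below; key names other than
# "use_uploaded_voice" and "voice_upload_file" are guesses, not confirmed by this diff.
from handler import EndpointHandler

handler = EndpointHandler()

payload = {
    "model_name": "example_model",         # hypothetical folder under weights/
    "tts_text": "Сайн байна уу?",          # text passed to Edge TTS ("Hello?")
    "selected_voice": "Mongolian Female",  # mapped to "mn-MN-YesuiNeural" via voice_mapping
    "index_rate": 0.75,                    # hypothetical value forwarded to tts()
    "use_uploaded_voice": False,
    "voice_upload_file": None,
}

result = handler(payload)  # EndpointHandler.__call__(data: Dict[str, Any])
print(result)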