willwade committed on
Commit
56c77d7
·
verified ·
1 Parent(s): 02f98a8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +531 -0
app.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import gradio as gr
4
+ import shutil
5
+ import subprocess
6
+ import requests
7
+ import tarfile
8
+ from pathlib import Path
9
+ import soundfile as sf
10
+ import sherpa_onnx
11
+ from deep_translator import GoogleTranslator
12
+ import numpy as np
13
+ from iso639 import Lang
14
+ import pycountry
15
+
16
# Load model JSON
# NOTE(review): MODEL_JSON_URL points at the GitHub HTML page (/blob/); the
# request below rewrites it to /raw/ so the raw JSON body is fetched.
MODEL_JSON_URL = "https://github.com/willwade/tts-wrapper/blob/main/tts_wrapper/engines/sherpaonnx/merged_models.json"
MODEL_JSON_PATH = "./models.json"

# Download the model catalogue once and cache it locally; later runs reuse
# the cached copy instead of hitting the network again.
if not os.path.exists(MODEL_JSON_PATH):
    response = requests.get(MODEL_JSON_URL.replace("/blob/", "/raw/"))
    with open(MODEL_JSON_PATH, "w") as f:
        f.write(response.text)

# models: dict mapping model id -> model metadata (name, url, language, ...)
with open(MODEL_JSON_PATH, "r") as f:
    models = json.load(f)
28
+
29
def get_model_display_info(model_info):
    """Build the human-readable dropdown label for a model entry.

    Format: "<language> (<code>) | <voice> (<developer>[ - <quality>])".
    Handles both key spellings found in the merged model metadata
    ('language_name'/'Language Name', 'lang_code'/'Iso Code').
    """
    primary_lang = model_info.get('language', [{}])[0]
    language = primary_lang.get('language_name', primary_lang.get('Language Name', 'Unknown'))
    code = primary_lang.get('lang_code', primary_lang.get('Iso Code', 'Unknown'))

    voice = model_info.get('name', model_info.get('id', 'Unknown'))
    maker = model_info.get('developer', '')
    # MMS voices carry no explicit quality field; tag them "MMS" instead.
    grade = model_info.get('quality', 'MMS' if 'mms' in voice.lower() else '')

    details = f"{maker} - {grade}" if grade else maker
    return f"{language} ({code}) | {voice} ({details})"
49
+
50
# Group models by "<language name> (<code>)" so the dropdown is ordered
# by language first, then by voice label within each language.
models_by_lang = {}
for model_id, model_info in models.items():
    primary_lang = model_info.get('language', [{}])[0]
    lang_name = primary_lang.get('language_name', primary_lang.get('Language Name', 'Unknown'))
    lang_code = primary_lang.get('lang_code', primary_lang.get('Iso Code', 'Unknown'))
    group_key = f"{lang_name} ({lang_code})"
    models_by_lang.setdefault(group_key, []).append(
        (get_model_display_info(model_info), model_id)
    )

# Flatten the groups into dropdown labels, keeping a label -> model id map
# so the UI callback can recover the model id from the selected label.
dropdown_choices = []
models_by_display = {}
for _lang, entries in sorted(models_by_lang.items()):
    for label, model_id in sorted(entries):
        dropdown_choices.append(label)
        models_by_display[label] = model_id
73
+
74
def get_language_code(model_info):
    """Return the lower-cased ISO code of the model's first language, or None.

    Returns None when the metadata has no 'language' list at all, and an
    empty string when the entry exists but carries no code.
    """
    languages = model_info.get("language")
    if not languages:
        return None
    first = languages[0]
    # Metadata uses two key spellings depending on its origin.
    code = first.get("lang_code", first.get("Iso Code", ""))
    return code.lower()
83
+
84
# Special cases for codes not in ISO standard
# Maps ISO 639-3 (or otherwise non-standard) source codes to the closest
# code that Google Translate accepts; checked before any ISO lookup.
SPECIAL_CODES = {
    "cmn": "zh",  # Mandarin Chinese
    "yue": "zh",  # Cantonese
    "pi": "el",   # Pali (using Greek for this model)
    "guj": "gu",  # Gujarati
}
91
+
92
def get_translate_code(iso_code):
    """Convert an ISO language code to a Google Translate language code.

    Resolution order:
      1. hand-maintained SPECIAL_CODES overrides,
      2. iso639 lookup of the ISO 639-1 (2-letter) code,
      3. pycountry lookup (alpha_3 -> alpha_2),
      4. the code itself when it is already two letters.

    Args:
        iso_code: ISO code, possibly with a script/dialect suffix ("zh-Hans").

    Returns:
        A Google-Translate-compatible code, or None when none can be found.
    """
    if not iso_code:
        return None

    # Remove any script or dialect specifiers, e.g. "zh-Hans" -> "zh".
    base_code = iso_code.split('-')[0].lower()

    # Check special cases first
    if base_code in SPECIAL_CODES:
        return SPECIAL_CODES[base_code]

    try:
        # Try to get the ISO 639-1 (2-letter) code.
        return Lang(base_code).pt1
    except Exception:
        # Bug fix: a bare `except:` here would also swallow
        # KeyboardInterrupt/SystemExit; `except Exception` keeps the same
        # fallback behavior for lookup failures only.
        pass

    try:
        # Fall back to pycountry's alpha_3 -> alpha_2 mapping.
        lang = pycountry.languages.get(alpha_3=base_code)
        if lang and hasattr(lang, 'alpha_2'):
            return lang.alpha_2
    except Exception:
        pass

    # If all else fails, accept a code that already looks like ISO 639-1.
    if len(base_code) == 2:
        return base_code

    return None
122
+
123
def translate_text(input_text, source_lang="en", target_lang="en"):
    """Translate input_text with Google Translator.

    The target may be any ISO code: it is first mapped to a Google
    Translate code, and the original code is retried as a fallback if the
    mapped code is rejected. On total failure an explanatory string is
    returned instead of raising.

    Args:
        input_text: Text to translate.
        source_lang: Source language code (default "en").
        target_lang: Target ISO language code (default "en").

    Returns:
        The translated text annotated with the original, the unmodified
        input when source and target match, or an error string.
    """
    if source_lang == target_lang:
        return input_text

    # Keep the caller's code around so error messages and the retry below
    # can refer to it; initialized before the try so the except can't see
    # an unbound name.
    mapped_lang = target_lang
    try:
        # Convert the ISO code to a Google Translate code; fall back to the
        # original code when no mapping exists (the old code passed None on).
        mapped_lang = get_translate_code(target_lang) or target_lang

        try:
            translated = GoogleTranslator(source=source_lang, target=mapped_lang).translate(input_text)
            return f"{translated} (translated from: {input_text})"
        except Exception as first_error:
            # Bug fix: the previous fallback repeated the exact same call
            # with the mapped code, making the retry a no-op. Retry with the
            # *original* target code instead.
            try:
                translated = GoogleTranslator(source=source_lang, target=target_lang).translate(input_text)
                return f"{translated} (translated from: {input_text})"
            except Exception:
                raise first_error

    except Exception as e:
        print(f"Translation error: {str(e)} for target language: {target_lang}")
        print(f"Attempted to use language code: {mapped_lang}")
        return f"Translation Error: Could not translate to {target_lang}. Original text: {input_text}"
146
+
147
def _stream_download(url, dest_path, what="model"):
    """Stream `url` to `dest_path` in chunks, printing coarse progress.

    Raises:
        Exception: on any non-200 HTTP response.
    """
    response = requests.get(url, stream=True)
    if response.status_code != 200:
        raise Exception(f"Failed to download {what} from {url}. Status code: {response.status_code}")

    total_size = int(response.headers.get('content-length', 0))
    block_size = 8192
    downloaded = 0

    print(f"Total size: {total_size / (1024*1024):.1f} MB")
    with open(dest_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=block_size):
            if chunk:
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    percent = int((downloaded / total_size) * 100)
                    if percent % 10 == 0:
                        print(f" {percent}%", end="", flush=True)
    print("\nDownload complete")


def download_and_extract_model(url, destination):
    """Download and extract the model files into `destination`.

    MMS models (hosted as loose model.onnx/tokens.txt files) are downloaded
    directly; every other model is fetched as a .tar.bz2 archive and
    extracted. Hugging Face /tree/main/ URLs are rewritten to /resolve/main/
    for direct file downloads.

    Args:
        url: Source URL (Hugging Face repo path or direct archive URL).
        destination: Local directory to populate (created if missing).

    Raises:
        Exception: on download failure, a Git LFS pointer response, or
            archive extraction failure.
    """
    print(f"Downloading from URL: {url}")
    print(f"Destination: {destination}")

    # Convert Hugging Face URL format if needed
    if "huggingface.co" in url:
        # Replace /tree/main/ with /resolve/main/ for direct file download
        base_url = url.replace("/tree/main/", "/resolve/main/")

        # Check if this is an MMS model
        if "mms-tts-multilingual-models-onnx" in url:
            # MMS models ship model.onnx and tokens.txt as separate files.
            os.makedirs(destination, exist_ok=True)

            print("Downloading model.onnx...")
            _stream_download(f"{base_url}/model.onnx",
                             os.path.join(destination, "model.onnx"),
                             what="model")
            print("Model download complete")

            print("Downloading tokens.txt...")
            tokens_url = f"{base_url}/tokens.txt"
            tokens_path = os.path.join(destination, "tokens.txt")
            response = requests.get(tokens_url)
            if response.status_code != 200:
                raise Exception(f"Failed to download tokens from {tokens_url}. Status code: {response.status_code}")
            with open(tokens_path, "wb") as f:
                f.write(response.content)
            print("Tokens download complete")
            return

        # Other models are stored as tar.bz2 files
        url = f"{base_url}.tar.bz2"

    # Create model directory if it doesn't exist
    os.makedirs(destination, exist_ok=True)

    tar_path = os.path.join(destination, "model.tar.bz2")

    # Bug fix: the original downloaded the archive twice — once fully into
    # memory (reading .content of a streamed response) just to inspect it
    # for a Git LFS pointer, then again to disk. Download once to disk,
    # then inspect the first bytes of the file.
    print("Downloading model archive...")
    _stream_download(url, tar_path, what="model")

    with open(tar_path, "rb") as f:
        head = f.read(100).decode('utf-8', errors='ignore')
    if head.startswith('version https://git-lfs.github.com/spec/v1'):
        raise Exception(f"Received Git LFS pointer instead of file content from {url}")

    # Extract the tar.bz2 file
    print(f"Extracting {tar_path} to {destination}")
    try:
        with tarfile.open(tar_path, "r:bz2") as tar:
            tar.extractall(path=destination)
        os.remove(tar_path)
        print("Extraction complete")
    except Exception as e:
        print(f"Error during extraction: {str(e)}")
        raise

    # Debug listing of everything that landed in the destination.
    print("Contents of destination directory:")
    for root, dirs, files in os.walk(destination):
        print(f"\nDirectory: {root}")
        if dirs:
            print(" Subdirectories:", dirs)
        if files:
            print(" Files:", files)
258
+
259
def find_model_files(model_dir):
    """Locate TTS model files under `model_dir` (searched recursively).

    Returns a dict with keys 'model' (any .onnx file), 'tokens'
    (tokens.txt) and — for non-MMS models — 'lexicon' (lexicon.txt); an
    empty lexicon.txt is created next to the model when none was shipped.
    Returns {} when no .onnx file is found at all.
    """
    # MMS models are identified purely by their directory name.
    is_mms = 'mms' in os.path.basename(model_dir).lower()

    found = {}
    for root, _, filenames in os.walk(model_dir):
        for name in filenames:
            full_path = os.path.join(root, name)
            if name.endswith('.onnx'):
                found['model'] = full_path
            elif name == 'tokens.txt':
                found['tokens'] = full_path
            elif name == 'lexicon.txt' and not is_mms:
                found['lexicon'] = full_path

    if 'model' not in found:
        return {}

    # Non-MMS (VITS) models expect a lexicon file; create an empty one
    # beside the model when the archive did not include it.
    if not is_mms and 'lexicon' not in found:
        lexicon_path = os.path.join(os.path.dirname(found['model']), 'lexicon.txt')
        with open(lexicon_path, 'w', encoding='utf-8') as f:
            pass  # Create empty file
        found['lexicon'] = lexicon_path

    return found
291
+
292
def generate_audio(text, model_info):
    """Generate audio from text using the specified model.

    Downloads and caches the model on first use, builds a sherpa-onnx VITS
    engine, and synthesizes `text`.

    Args:
        text: Text to synthesize (already translated where needed).
        model_info: Model metadata dict; only 'id' and 'url' are used here.

    Returns:
        (samples, sample_rate): float32 numpy array peak-normalized to
        [-1, 1], and the engine's sample rate in Hz. Note the caller
        (tts_interface) swaps this into Gradio's (sample_rate, samples).

    Raises:
        ValueError: when required model files are missing or no audio is
            produced; other engine errors are logged and re-raised.
    """
    try:
        # Models are cached under ./models/<model id>.
        model_dir = os.path.join("./models", model_info['id'])

        print(f"\nLooking for model in: {model_dir}")

        # Download model if it doesn't exist
        if not os.path.exists(model_dir):
            print(f"Model directory doesn't exist, downloading {model_info['id']}...")
            os.makedirs(model_dir, exist_ok=True)
            download_and_extract_model(model_info['url'], model_dir)

        # Debug listing of what was downloaded/extracted.
        print(f"Contents of {model_dir}:")
        for item in os.listdir(model_dir):
            item_path = os.path.join(model_dir, item)
            if os.path.isdir(item_path):
                print(f" Directory: {item}")
                print(f" Contents: {os.listdir(item_path)}")
            else:
                print(f" File: {item}")

        # Find and validate model files
        model_files = find_model_files(model_dir)
        if not model_files or 'model' not in model_files:
            raise ValueError(f"Could not find required model files in {model_dir}")

        print("\nFound model files:")
        print(f"Model: {model_files['model']}")
        print(f"Tokens: {model_files.get('tokens', 'Not found')}")
        print(f"Lexicon: {model_files.get('lexicon', 'Not required for MMS')}\n")

        # MMS models are detected purely from the directory name.
        is_mms = 'mms' in os.path.basename(model_dir).lower()

        # Create configuration based on model type.
        # NOTE(review): OfflineTtsVitsModelConfig arguments are passed
        # positionally; the assumed order (model, lexicon, tokens, data_dir,
        # dict_dir, noise_scale, noise_scale_w, length_scale) should be
        # confirmed against the installed sherpa-onnx version.
        if is_mms:
            if 'tokens' not in model_files or not os.path.exists(model_files['tokens']):
                raise ValueError("tokens.txt is required for MMS models")

            # MMS models use tokens.txt and no lexicon
            vits_config = sherpa_onnx.OfflineTtsVitsModelConfig(
                model_files['model'],   # model
                '',                     # lexicon
                model_files['tokens'],  # tokens
                '',                     # data_dir
                '',                     # dict_dir
                0.667,                  # noise_scale
                0.8,                    # noise_scale_w
                1.0                     # length_scale
            )
        else:
            # Non-MMS models use lexicon.txt
            if 'tokens' not in model_files or not os.path.exists(model_files['tokens']):
                raise ValueError("tokens.txt is required for VITS models")

            # Use bundled espeak-ng data only when the archive ships it.
            espeak_data = os.path.join(os.path.dirname(model_files['model']), 'espeak-ng-data')
            data_dir = espeak_data if os.path.exists(espeak_data) else ''

            # Get lexicon path if it exists
            lexicon = model_files.get('lexicon', '') if os.path.exists(model_files.get('lexicon', '')) else ''

            # Create VITS model config
            vits_config = sherpa_onnx.OfflineTtsVitsModelConfig(
                model_files['model'],   # model
                lexicon,                # lexicon
                model_files['tokens'],  # tokens
                data_dir,               # data_dir
                '',                     # dict_dir
                0.667,                  # noise_scale
                0.8,                    # noise_scale_w
                1.0                     # length_scale
            )

        # Create the model config with VITS
        model_config = sherpa_onnx.OfflineTtsModelConfig()
        model_config.vits = vits_config

        # Create TTS configuration
        config = sherpa_onnx.OfflineTtsConfig(
            model=model_config,
            max_num_sentences=2
        )

        # Initialize TTS engine
        tts = sherpa_onnx.OfflineTts(config)

        # Generate audio
        audio_data = tts.generate(text)

        # Ensure we have valid audio data
        if audio_data is None or len(audio_data.samples) == 0:
            raise ValueError("Failed to generate audio - no data generated")

        # Convert the sample list to float32 and peak-normalize to [-1, 1].
        audio_array = np.array(audio_data.samples, dtype=np.float32)
        if np.any(audio_array):  # Check if array is not all zeros
            audio_array = audio_array / np.abs(audio_array).max()
        else:
            raise ValueError("Generated audio is empty")

        # Returned as (samples, rate); tts_interface reorders for Gradio.
        return (audio_array, audio_data.sample_rate)

    except Exception as e:
        error_msg = str(e)
        # Check for OOV or token conversion errors
        if "out of vocabulary" in error_msg.lower() or "token" in error_msg.lower():
            error_msg = f"Text contains unsupported characters: {error_msg}"
        # NOTE(review): the same message is printed twice here in the
        # original; preserved since this is a documentation-only pass.
        print(f"Error generating audio: {error_msg}")
        print(f"Error in TTS generation: {error_msg}")
        raise
405
+
406
def tts_interface(selected_model, text, translate_enabled, status_output):
    """Gradio handler: synthesize `text` with the selected model.

    Args:
        selected_model: Display label chosen in the dropdown.
        text: Input text (assumed English before translation).
        translate_enabled: Whether to translate text to the model's language.
        status_output: Unused; present to match the Gradio inputs wiring.

    Returns:
        ((sample_rate, samples), status_message) on success, or
        (None, error_message) on failure.
    """
    try:
        if not text.strip():
            return None, "Please enter some text"

        # Get model ID from the display name mapping
        model_id = models_by_display.get(selected_model)
        if not model_id or model_id not in models:
            return None, "Please select a model"

        model_info = models[model_id]

        # Check if this is an MMS model
        is_mms = 'mms' in model_id.lower()

        # Get the language code and check if translation is needed
        lang_code = get_language_code(model_info)
        translate_code = get_translate_code(lang_code)

        if is_mms:
            # MMS models always need the text in the target language.
            if not translate_code:
                return None, f"Cannot determine translation target language from code: {lang_code}"
            print(f"MMS model detected, translating to {translate_code}")
            text = translate_text(text, "en", translate_code)
        elif translate_enabled and translate_code and translate_code != "en":
            # For other models, translate only when requested and meaningful.
            # (Bug fix: a second `if not translate_code` check that used to
            # live here was unreachable — the elif already guarantees it.)
            print(f"Will translate to {translate_code} (from ISO code {lang_code})")
            text = translate_text(text, "en", translate_code)

        try:
            # Language info for the success message.
            lang_info = model_info.get('language', [{}])[0]
            lang_name = lang_info.get('language_name', 'Unknown')
            voice_name = model_info.get('name', model_id)

            # Generate audio
            audio_data, sample_rate = generate_audio(text, model_info)

            # Gradio's numpy audio format is (sample_rate, samples).
            return (sample_rate, audio_data), f"Generated speech using {voice_name} ({lang_name})"

        except ValueError as e:
            # Handle known errors with user-friendly messages
            error_msg = str(e)
            if "cannot process some words" in error_msg.lower():
                return None, error_msg
            return None, f"Error: {error_msg}"

    except Exception as e:
        print(f"Error in TTS generation: {str(e)}")
        return None, f"Error: {str(e)}"
461
+
462
# Gradio Interface
# Builds the demo UI: a model picker + text box on the left, audio output
# and status on the right, with a cancelable generate event.
with gr.Blocks() as app:
    gr.Markdown("# Sherpa-ONNX TTS Demo")
    with gr.Row():
        with gr.Column():
            # Voice picker: labels come from dropdown_choices, which map
            # back to model ids via models_by_display in tts_interface.
            model_dropdown = gr.Dropdown(
                choices=dropdown_choices,
                label="Select Model",
                value=dropdown_choices[0] if dropdown_choices else None
            )
            text_input = gr.Textbox(
                label="Text to speak",
                placeholder="Enter text here...",
                lines=3
            )
            translate_checkbox = gr.Checkbox(
                label="Translate to model language",
                value=False
            )
            with gr.Row():
                generate_btn = gr.Button("Generate Audio")
                stop_btn = gr.Button("Stop")

        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Audio",
                type="numpy"
            )
            status_text = gr.Textbox(
                label="Status",
                interactive=False
            )

    # Handle model selection to update translate checkbox
    def update_translate_checkbox(selected_model):
        """Update visibility of translate checkbox based on selected model's language."""
        try:
            # Find the model info for the selected model
            for lang_group in models_by_lang.values():
                for display_name, model_id in lang_group:
                    if display_name == selected_model:
                        model_info = models[model_id]
                        lang_info = model_info.get('language', [{}])[0]
                        lang_code = lang_info.get('lang_code', '')
                        # NOTE(review): returning a plain {"visible": ...}
                        # dict relies on older Gradio behavior; newer
                        # versions expect gr.update(visible=...) — confirm
                        # against the installed gradio version.
                        return {"visible": lang_code != 'en'}
            return {"visible": False}
        except Exception as e:
            print(f"Error updating translate checkbox: {str(e)}")
            return {"visible": False}

    model_dropdown.change(
        fn=update_translate_checkbox,
        inputs=[model_dropdown],
        outputs=[translate_checkbox]
    )

    # Set up event handlers
    gen_event = generate_btn.click(
        fn=tts_interface,
        inputs=[model_dropdown, text_input, translate_checkbox, status_text],
        outputs=[audio_output, status_text]
    )

    # Stop cancels an in-flight generation event (fn=None: UI-only action).
    stop_btn.click(
        fn=None,
        cancels=gen_event,
        queue=False
    )

app.launch()