Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import gradio as gr
|
4 |
+
import shutil
|
5 |
+
import subprocess
|
6 |
+
import requests
|
7 |
+
import tarfile
|
8 |
+
from pathlib import Path
|
9 |
+
import soundfile as sf
|
10 |
+
import sherpa_onnx
|
11 |
+
from deep_translator import GoogleTranslator
|
12 |
+
import numpy as np
|
13 |
+
from iso639 import Lang
|
14 |
+
import pycountry
|
15 |
+
|
16 |
+
# Location of the model catalogue. The GitHub "blob" page URL is kept here and
# rewritten to its raw-content form when the file is fetched.
MODEL_JSON_URL = "https://github.com/willwade/tts-wrapper/blob/main/tts_wrapper/engines/sherpaonnx/merged_models.json"
MODEL_JSON_PATH = "./models.json"

# Download the catalogue once and cache it next to the app.
if not os.path.exists(MODEL_JSON_PATH):
    response = requests.get(MODEL_JSON_URL.replace("/blob/", "/raw/"), timeout=30)
    # Fail loudly instead of caching an error page as models.json: a bad cache
    # file would break every later start-up until it is deleted by hand.
    response.raise_for_status()
    catalogue_text = response.text
    json.loads(catalogue_text)  # validate before anything is written to disk
    with open(MODEL_JSON_PATH, "w", encoding="utf-8") as f:
        f.write(catalogue_text)

# Catalogue of available models, keyed by model id.
with open(MODEL_JSON_PATH, "r", encoding="utf-8") as f:
    models = json.load(f)
|
28 |
+
|
29 |
+
def get_model_display_info(model_info):
    """Create a display string for a model.

    Args:
        model_info: Catalogue entry (dict) describing one model.

    Returns:
        A string of the form
        ``"<language name> (<language code>) | <voice> (<developer>[ - <quality>])"``.
        Missing language or voice fields are shown as ``'Unknown'``; a missing
        developer or quality is rendered as an empty string.
    """
    # Only the first listed language is used. ``or [{}]`` also covers a
    # present-but-empty (or null) list, which would otherwise raise on ``[0]``.
    lang_info = (model_info.get('language') or [{}])[0]
    # The catalogue uses two key spellings; try both.
    lang_name = lang_info.get('language_name', lang_info.get('Language Name', 'Unknown'))
    lang_code = lang_info.get('lang_code', lang_info.get('Iso Code', 'Unknown'))

    voice_name = model_info.get('name', model_info.get('id', 'Unknown'))
    developer = model_info.get('developer', '')
    # Entries without an explicit quality are labelled "MMS" when the voice
    # name identifies them as MMS models.
    quality = model_info.get('quality', 'MMS' if 'mms' in voice_name.lower() else '')

    model_display = f"{voice_name} ({developer}"
    if quality:
        model_display += f" - {quality}"
    model_display += ")"

    return f"{lang_name} ({lang_code}) | {model_display}"
|
49 |
+
|
50 |
+
# Group models by language: "<language name> (<code>)" -> [(display name, model id), ...]
models_by_lang = {}
_seen_display_names = set()
for model_id, model_info in models.items():
    # Only the first listed language is used; a missing, null or empty list
    # falls back to an empty record instead of raising on ``[0]``.
    lang_info = (model_info.get('language') or [{}])[0]
    lang_name = lang_info.get('language_name', lang_info.get('Language Name', 'Unknown'))
    lang_code = lang_info.get('lang_code', lang_info.get('Iso Code', 'Unknown'))
    group_key = f"{lang_name} ({lang_code})"

    display_name = get_model_display_info(model_info)
    # Display names double as lookup keys. Two models rendering to the same
    # text would make one of them unselectable, so disambiguate with the id.
    if display_name in _seen_display_names:
        display_name = f"{display_name} [{model_id}]"
    _seen_display_names.add(display_name)

    models_by_lang.setdefault(group_key, []).append((display_name, model_id))

# Flatten the groups into the dropdown list (sorted by language, then by
# display name) and build the reverse mapping used to resolve a selection.
dropdown_choices = []
models_by_display = {}  # Map display names to model IDs
for _, model_list in sorted(models_by_lang.items()):
    for display_name, model_id in sorted(model_list):
        dropdown_choices.append(display_name)
        models_by_display[display_name] = model_id
|
73 |
+
|
74 |
+
def get_language_code(model_info):
    """Get the lower-cased language code of a model's first listed language.

    Args:
        model_info: Catalogue entry (dict) describing one model.

    Returns:
        ``None`` when the entry lists no language; otherwise the lower-cased
        code, or ``""`` when the language record carries no code.
    """
    languages = model_info.get("language")
    if not languages:
        return None

    lang_info = languages[0]
    # Try both key spellings. ``or`` (rather than a ``get`` default) lets a
    # null or empty "lang_code" fall through to "Iso Code" instead of crashing
    # on ``.lower()``.
    lang_code = lang_info.get("lang_code") or lang_info.get("Iso Code") or ""
    return str(lang_code).lower()
|
83 |
+
|
84 |
+
# Overrides consulted before any ISO 639 lookup, for codes whose translation
# target differs from what a plain conversion to a two-letter code would give.
SPECIAL_CODES = {
    # deep_translator's GoogleTranslator lists Chinese only under the regional
    # codes "zh-CN" / "zh-TW"; a bare "zh" is rejected as unsupported.
    "cmn": "zh-CN",  # Mandarin Chinese -> Simplified Chinese
    "yue": "zh-TW",  # Cantonese -> Traditional Chinese (closest listed target)
    "pi": "el",      # Pali (using Greek for this model) -- NOTE(review): confirm this mapping is intentional
    "guj": "gu",     # Gujarati
}
|
91 |
+
|
92 |
+
def get_translate_code(iso_code):
    """Convert an ISO language code to a Google Translate code.

    Lookup order: the ``SPECIAL_CODES`` overrides, the ``iso639`` package,
    ``pycountry`` (treating the code as ISO 639-3), and finally the code itself
    when it is already two letters long.

    Args:
        iso_code: Language code, optionally followed by ``-<script/dialect>``.

    Returns:
        The translation code, or ``None`` when none can be determined.
    """
    if not iso_code:
        return None

    # Remove any script or dialect specifiers
    base_code = iso_code.split('-')[0].lower()

    # Check special cases first
    if base_code in SPECIAL_CODES:
        return SPECIAL_CODES[base_code]

    # Try to get the ISO 639-1 (2-letter) code. The lookup is best-effort:
    # an unknown code simply moves on to the next strategy. ``Exception``
    # (not a bare ``except``) keeps KeyboardInterrupt/SystemExit propagating.
    try:
        two_letter = Lang(base_code).pt1
    except Exception:
        two_letter = None
    # Fall through on an empty result so the later strategies still get a
    # chance, instead of handing an empty string back to the caller.
    if two_letter:
        return two_letter

    # If that fails, try to find a matching language in pycountry
    try:
        language = pycountry.languages.get(alpha_3=base_code)
    except Exception:
        language = None
    alpha_2 = getattr(language, 'alpha_2', None)
    if alpha_2:
        return alpha_2

    # If all else fails, a two-letter code is passed through unchanged
    if len(base_code) == 2:
        return base_code

    return None
|
122 |
+
|
123 |
+
def translate_text(input_text, source_lang="en", target_lang="en"):
    """Translate text using Google Translator.

    The target is first tried as the code produced by ``get_translate_code``
    and then, if that fails, as the code exactly as the caller supplied it.

    Args:
        input_text: Text to translate.
        source_lang: Source language code.
        target_lang: Target language code.

    Returns:
        ``input_text`` unchanged when source and target are equal; otherwise
        ``"<translation> (translated from: <input_text>)"`` on success, or a
        string starting with ``"Translation Error:"`` when every attempt fails.
    """
    if source_lang == target_lang:
        return input_text

    requested_lang = target_lang
    first_error = None

    # Convert ISO code to Google Translate code
    try:
        mapped_lang = get_translate_code(requested_lang)
    except Exception as error:
        mapped_lang = None
        first_error = error

    # Mapped code first, then the caller's original code. Previously the
    # "retry with the original" step reused the already-mapped value and so
    # merely repeated the failed request.
    candidates = []
    for code in (mapped_lang, requested_lang):
        if code and code not in candidates:
            candidates.append(code)

    for code in candidates:
        try:
            translated = GoogleTranslator(source=source_lang, target=code).translate(input_text)
        except Exception as error:
            # Report the first failure; later ones are usually consequences.
            if first_error is None:
                first_error = error
            continue
        return f"{translated} (translated from: {input_text})"

    attempted = mapped_lang or requested_lang
    print(f"Translation error: {first_error} for target language: {requested_lang}")
    print(f"Attempted to use language code: {attempted}")
    return f"Translation Error: Could not translate to {attempted}. Original text: {input_text}"
|
146 |
+
|
147 |
+
def download_and_extract_model(url, destination):
    """Download and extract the model files.

    Hugging Face ``/tree/main/`` URLs are rewritten to ``/resolve/main/``.
    MMS models (URL contains ``mms-tts-multilingual-models-onnx``) are fetched
    as two plain files, ``model.onnx`` and ``tokens.txt``. Every other model is
    fetched as a ``.tar.bz2`` archive that is unpacked into *destination* and
    then deleted.

    Args:
        url: Model location taken from the catalogue.
        destination: Directory that receives the model files; created if needed.

    Raises:
        RuntimeError: On a non-200 response, a Git LFS pointer in place of the
            real file, or an archive member that would land outside
            *destination*.
        Exception: Errors from the archive extraction are logged and re-raised.
    """
    print(f"Downloading from URL: {url}")
    print(f"Destination: {destination}")

    # Both branches write into the directory, so make sure it exists up front.
    os.makedirs(destination, exist_ok=True)

    # Convert Hugging Face URL format if needed
    if "huggingface.co" in url:
        # Replace /tree/main/ with /resolve/main/ for direct file download
        base_url = url.replace("/tree/main/", "/resolve/main/")

        if "mms-tts-multilingual-models-onnx" in url:
            # MMS models have both model.onnx and tokens.txt
            print("Downloading model.onnx...")
            _download_to_file(f"{base_url}/model.onnx", os.path.join(destination, "model.onnx"), "model")
            print("\nModel download complete")

            print("Downloading tokens.txt...")
            _download_to_file(
                f"{base_url}/tokens.txt",
                os.path.join(destination, "tokens.txt"),
                "tokens",
                show_progress=False,
            )
            print("Tokens download complete")
            return

        # Other models are stored as tar.bz2 files
        url = f"{base_url}.tar.bz2"

    tar_path = os.path.join(destination, "model.tar.bz2")

    print("Downloading model archive...")
    _download_to_file(url, tar_path, "model")
    print("\nDownload complete")

    print(f"Extracting {tar_path} to {destination}")
    try:
        _extract_archive(tar_path, destination)
        print("Extraction complete")
    except Exception as e:
        print(f"Error during extraction: {str(e)}")
        raise
    finally:
        # Remove the archive on failure too, so a corrupt download is not
        # left behind in the model directory.
        if os.path.exists(tar_path):
            os.remove(tar_path)

    print("Contents of destination directory:")
    for root, dirs, files in os.walk(destination):
        print(f"\nDirectory: {root}")
        if dirs:
            print("  Subdirectories:", dirs)
        if files:
            print("  Files:", files)


def _download_to_file(url, path, what, show_progress=True, timeout=60):
    """Stream *url* into *path* in one request, optionally printing progress."""
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise RuntimeError(f"Failed to download {what} from {url}. Status code: {response.status_code}")

        chunks = response.iter_content(chunk_size=8192)
        # Inspect only the first chunk for a Git LFS pointer; reading
        # ``response.content`` would pull the whole file into memory.
        chunk = next(chunks, None)
        if chunk is not None and chunk.startswith(b"version https://git-lfs.github.com/spec/v1"):
            raise RuntimeError(f"Received Git LFS pointer instead of file content from {url}")

        total_size = int(response.headers.get('content-length', 0))
        if show_progress:
            print(f"Total size: {total_size / (1024*1024):.1f} MB")

        downloaded = 0
        next_mark = 0  # next percentage (multiple of 10) still to be reported
        with open(path, "wb") as f:
            while chunk is not None:
                if chunk:
                    f.write(chunk)
                    downloaded += len(chunk)
                    if show_progress and total_size > 0:
                        percent = downloaded * 100 // total_size
                        # Report each 10% step once, not once per chunk.
                        if percent >= next_mark:
                            step = percent // 10 * 10
                            print(f" {step}%", end="", flush=True)
                            next_mark = step + 10
                chunk = next(chunks, None)


def _extract_archive(tar_path, destination):
    """Unpack a .tar.bz2 archive, refusing members that escape *destination*."""
    with tarfile.open(tar_path, "r:bz2") as tar:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters exist on this interpreter: let the stdlib
            # reject absolute paths, ".." traversal and special files.
            tar.extractall(path=destination, filter="data")
            return

        # Older interpreters: check every member path by hand first.
        root = os.path.realpath(destination)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(root, member.name))
            if os.path.commonpath([root, target]) != root:
                raise RuntimeError(f"Unsafe path in archive: {member.name}")
        tar.extractall(path=destination)
|
258 |
+
|
259 |
+
def find_model_files(model_dir):
    """Find model files in the given directory and its subdirectories.

    Args:
        model_dir: Root directory of one downloaded model. A directory whose
            name contains ``mms`` (case-insensitive) is treated as an MMS model.

    Returns:
        A dict with the keys ``'model'`` (a ``*.onnx`` file), ``'tokens'``
        (``tokens.txt``, when present) and, for non-MMS models only,
        ``'lexicon'`` (``lexicon.txt``). Returns ``{}`` when no ``*.onnx`` file
        exists. When several candidates exist, the first one in sorted,
        top-down order is used.

    Side effects:
        For a non-MMS model without a ``lexicon.txt``, an empty one is created
        next to the model file and returned under ``'lexicon'``.
    """
    model_files = {}

    # normpath drops a trailing separator; without it basename() returns ''
    # for "models/mms_eng/" and the MMS check silently fails.
    is_mms = 'mms' in os.path.basename(os.path.normpath(model_dir)).lower()

    for root, dirs, files in os.walk(model_dir):
        # Sort in place so the traversal, and therefore which file wins when
        # several match, does not depend on the directory listing order.
        dirs.sort()
        for file in sorted(files):
            file_path = os.path.join(root, file)

            if file.endswith('.onnx'):
                model_files.setdefault('model', file_path)
            elif file == 'tokens.txt':
                model_files.setdefault('tokens', file_path)
            # Lexicon file (only for non-MMS models)
            elif file == 'lexicon.txt' and not is_mms:
                model_files.setdefault('lexicon', file_path)

    if 'model' not in model_files:
        return {}

    # Create empty lexicon file if needed (only for non-MMS models)
    if not is_mms and 'lexicon' not in model_files:
        lexicon_path = os.path.join(os.path.dirname(model_files['model']), 'lexicon.txt')
        with open(lexicon_path, 'w', encoding='utf-8'):
            pass  # Create empty file
        model_files['lexicon'] = lexicon_path

    return model_files
|
291 |
+
|
292 |
+
def generate_audio(text, model_info):
    """Generate audio from text using the specified model.

    The model is downloaded into ``./models/<id>`` on first use.

    Args:
        text: Text to synthesise.
        model_info: Catalogue entry (dict); its ``'id'`` and ``'url'`` keys
            are read here.

    Returns:
        ``(samples, sample_rate)`` where ``samples`` is a float32 NumPy array
        scaled so that its peak absolute value is 1. Note the order: the
        samples come first, the sample rate second.

    Raises:
        ValueError: When required model files are missing or no usable audio
            was produced.
        Exception: Any other error is logged and re-raised unchanged.
    """
    try:
        model_dir = os.path.join("./models", model_info['id'])

        print(f"\nLooking for model in: {model_dir}")

        # Download model if it doesn't exist
        if not os.path.exists(model_dir):
            print(f"Model directory doesn't exist, downloading {model_info['id']}...")
            os.makedirs(model_dir, exist_ok=True)
            try:
                download_and_extract_model(model_info['url'], model_dir)
            except Exception:
                # The existence of the directory is what marks a model as
                # downloaded, so a half-populated one must not survive:
                # otherwise every later call skips the download and fails.
                shutil.rmtree(model_dir, ignore_errors=True)
                raise

        print(f"Contents of {model_dir}:")
        for item in os.listdir(model_dir):
            item_path = os.path.join(model_dir, item)
            if os.path.isdir(item_path):
                print(f"  Directory: {item}")
                print(f"    Contents: {os.listdir(item_path)}")
            else:
                print(f"  File: {item}")

        # Find and validate model files
        model_files = find_model_files(model_dir)
        if not model_files or 'model' not in model_files:
            raise ValueError(f"Could not find required model files in {model_dir}")

        print("\nFound model files:")
        print(f"Model: {model_files['model']}")
        print(f"Tokens: {model_files.get('tokens', 'Not found')}")
        print(f"Lexicon: {model_files.get('lexicon', 'Not required for MMS')}\n")

        # Check if this is an MMS model
        is_mms = 'mms' in os.path.basename(model_dir).lower()

        # tokens.txt is mandatory for both model families
        tokens = model_files.get('tokens')
        if not tokens or not os.path.exists(tokens):
            family = "MMS" if is_mms else "VITS"
            raise ValueError(f"tokens.txt is required for {family} models")

        if is_mms:
            # MMS models use tokens.txt and no lexicon
            lexicon = ''
            data_dir = ''
        else:
            # Set data dir if it exists
            espeak_data = os.path.join(os.path.dirname(model_files['model']), 'espeak-ng-data')
            data_dir = espeak_data if os.path.exists(espeak_data) else ''

            # Get lexicon path if it exists
            lexicon = model_files.get('lexicon', '')
            if not os.path.exists(lexicon):
                lexicon = ''

        # Create VITS model config (positional order as in the original calls)
        vits_config = sherpa_onnx.OfflineTtsVitsModelConfig(
            model_files['model'],  # model
            lexicon,               # lexicon
            tokens,                # tokens
            data_dir,              # data_dir
            '',                    # dict_dir
            0.667,                 # noise_scale
            0.8,                   # noise_scale_w
            1.0                    # length_scale
        )

        # Create the model config with VITS
        model_config = sherpa_onnx.OfflineTtsModelConfig()
        model_config.vits = vits_config

        # Create TTS configuration
        config = sherpa_onnx.OfflineTtsConfig(
            model=model_config,
            max_num_sentences=2
        )

        # Initialize TTS engine
        tts = sherpa_onnx.OfflineTts(config)

        # Generate audio
        audio_data = tts.generate(text)

        # Ensure we have valid audio data
        if audio_data is None or len(audio_data.samples) == 0:
            raise ValueError("Failed to generate audio - no data generated")

        # Convert samples list to numpy array and normalize to peak 1.0
        audio_array = np.array(audio_data.samples, dtype=np.float32)
        if not np.any(audio_array):
            # All-zero output cannot be normalised (division by zero)
            raise ValueError("Generated audio is empty")
        audio_array = audio_array / np.abs(audio_array).max()

        return (audio_array, audio_data.sample_rate)

    except Exception as e:
        error_msg = str(e)
        # Check for OOV or token conversion errors (affects the log line only;
        # the original exception is re-raised as is)
        if "out of vocabulary" in error_msg.lower() or "token" in error_msg.lower():
            error_msg = f"Text contains unsupported characters: {error_msg}"
        print(f"Error generating audio: {error_msg}")
        raise
|
405 |
+
|
406 |
+
def tts_interface(selected_model, text, translate_enabled, status_output):
    """Gradio callback: optionally translate *text*, then synthesise it.

    Args:
        selected_model: Display name chosen in the model dropdown.
        text: Text entered by the user (assumed to be English when translated).
        translate_enabled: Whether to translate for non-MMS models. MMS models
            are always translated.
        status_output: Current status text; accepted because the UI passes it,
            but not used.

    Returns:
        ``((sample_rate, samples), status_message)`` on success, or
        ``(None, message)`` when input is missing or an error occurred.
    """
    try:
        if not text or not text.strip():
            return None, "Please enter some text"

        # Get model ID from the display name mapping
        model_id = models_by_display.get(selected_model)
        if not model_id or model_id not in models:
            return None, "Please select a model"

        model_info = models[model_id]

        # Check if this is an MMS model
        is_mms = 'mms' in model_id.lower()

        # Get the language code and check if translation is needed
        lang_code = get_language_code(model_info)
        translate_code = get_translate_code(lang_code)

        # For MMS models, we always need to translate
        if is_mms and not translate_code:
            return None, f"Cannot determine translation target language from code: {lang_code}"

        if is_mms:
            print(f"MMS model detected, translating to {translate_code}")
            needs_translation = True
        else:
            # For other models, check if translation is enabled and needed
            needs_translation = bool(translate_enabled and translate_code and translate_code != "en")
            if needs_translation:
                print(f"Will translate to {translate_code} (from ISO code {lang_code})")

        if needs_translation:
            translated = translate_text(text, "en", translate_code)
            # translate_text returns the input untouched when no translation
            # was necessary; only post-process an actually changed result.
            if translated != text:
                # A failed translation comes back as a message, not as text
                # to be spoken: surface it in the status box instead.
                if translated.startswith("Translation Error:"):
                    return None, translated
                # Drop the "(translated from: ...)" annotation so that only
                # the translation itself reaches the synthesiser.
                annotation = f" (translated from: {text})"
                if translated.endswith(annotation):
                    translated = translated[:-len(annotation)]
                text = translated

        lang_info = (model_info.get('language') or [{}])[0]
        lang_name = lang_info.get('language_name', 'Unknown')
        voice_name = model_info.get('name', model_id)

        try:
            # generate_audio returns (samples, rate); Gradio wants (rate, samples)
            audio_data, sample_rate = generate_audio(text, model_info)
        except ValueError as e:
            # Handle known errors with user-friendly messages
            error_msg = str(e)
            if "cannot process some words" in error_msg.lower():
                return None, error_msg
            return None, f"Error: {error_msg}"

        return (sample_rate, audio_data), f"Generated speech using {voice_name} ({lang_name})"

    except Exception as e:
        error_msg = str(e)
        print(f"Error in TTS generation: {error_msg}")
        return None, f"Error: {error_msg}"
|
461 |
+
|
462 |
+
# Gradio Interface: inputs on the left, generated audio and status on the right.
with gr.Blocks() as app:
    gr.Markdown("# Sherpa-ONNX TTS Demo")
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=dropdown_choices,
                label="Select Model",
                value=dropdown_choices[0] if dropdown_choices else None
            )
            text_input = gr.Textbox(
                label="Text to speak",
                placeholder="Enter text here...",
                lines=3
            )
            translate_checkbox = gr.Checkbox(
                label="Translate to model language",
                value=False
            )
            with gr.Row():
                generate_btn = gr.Button("Generate Audio")
                stop_btn = gr.Button("Stop")

        with gr.Column():
            audio_output = gr.Audio(
                label="Generated Audio",
                type="numpy"
            )
            status_text = gr.Textbox(
                label="Status",
                interactive=False
            )

    # Handle model selection to update translate checkbox
    def update_translate_checkbox(selected_model):
        """Show the translate checkbox unless the selected model's language code is 'en'.

        Returns a ``gr.update(...)`` object. A plain ``{"visible": ...}`` dict
        is not recognised as a property update and would be treated as the
        checkbox's new value. An unknown selection or any error hides the
        checkbox.
        """
        try:
            model_info = models.get(models_by_display.get(selected_model))
            if model_info is None:
                return gr.update(visible=False)
            lang_info = (model_info.get('language') or [{}])[0]
            lang_code = lang_info.get('lang_code', '')
            return gr.update(visible=lang_code != 'en')
        except Exception as e:
            print(f"Error updating translate checkbox: {str(e)}")
            return gr.update(visible=False)

    model_dropdown.change(
        fn=update_translate_checkbox,
        inputs=[model_dropdown],
        outputs=[translate_checkbox]
    )

    # Set up event handlers
    gen_event = generate_btn.click(
        fn=tts_interface,
        inputs=[model_dropdown, text_input, translate_checkbox, status_text],
        outputs=[audio_output, status_text]
    )

    # "Stop" runs no function of its own; it only cancels a running generation.
    stop_btn.click(
        fn=None,
        cancels=gen_event,
        queue=False
    )

app.launch()
|