spark-tts committed
Commit · 8db57fe
1 Parent(s): fc905b2

fix webui OOM bug
README.md CHANGED

@@ -103,9 +103,9 @@ python -m cli.inference \
     --prompt_speech_path "path/to/prompt_audio"
 ```

-**UI Usage**
+**Web UI Usage**

-You can start the UI interface by running `python webui.py`, which allows you to perform Voice Cloning and Voice Creation. Voice Cloning supports uploading reference audio or directly recording the audio.
+You can start the UI interface by running `python webui.py --device 0`, which allows you to perform Voice Cloning and Voice Creation. Voice Cloning supports uploading reference audio or directly recording the audio.


 | **Voice Cloning** | **Voice Creation** |
webui.py CHANGED

@@ -17,6 +17,7 @@ import os
 import torch
 import soundfile as sf
 import logging
+import argparse
 import gradio as gr
 from datetime import datetime
 from cli.SparkTTS import SparkTTS

@@ -71,35 +72,53 @@ def run_tts(

     logging.info(f"Audio saved at: {save_path}")

-    return save_path
-
-
-def voice_clone(text, model, prompt_text, prompt_wav_upload, prompt_wav_record):
-    """Gradio interface for TTS with prompt speech input."""
-    # Determine prompt speech (from audio file or recording)
-    prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
-    prompt_text = None if len(prompt_text) < 2 else prompt_text
-    audio_output_path, model = run_tts(
-        text, model, prompt_text=prompt_text, prompt_speech=prompt_speech
-    )
-
-    return audio_output_path, model
-
-
-def voice_creation(text, model, gender, pitch, speed):
-    """Gradio interface for TTS with control over voice attributes."""
-    pitch = LEVELS_MAP_UI[int(pitch)]
-    speed = LEVELS_MAP_UI[int(speed)]
-    audio_output_path, model = run_tts(
-        text, model, gender=gender, pitch=pitch, speed=speed
-    )
-    return audio_output_path, model
+    return save_path


 def build_ui(model_dir, device=0):
+
+    # Initialize model
+    model = initialize_model(model_dir, device=device)
+
+    # Define callback function for voice cloning
+    def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
+        """
+        Gradio callback to clone voice using text and optional prompt speech.
+        - text: The input text to be synthesised.
+        - prompt_text: Additional textual info for the prompt (optional).
+        - prompt_wav_upload/prompt_wav_record: Audio files used as reference.
+        """
+        prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
+        prompt_text_clean = None if len(prompt_text) < 2 else prompt_text
+
+        audio_output_path = run_tts(
+            text,
+            model,
+            prompt_text=prompt_text_clean,
+            prompt_speech=prompt_speech
+        )
+        return audio_output_path
+
+    # Define callback function for creating new voices
+    def voice_creation(text, gender, pitch, speed):
+        """
+        Gradio callback to create a synthetic voice with adjustable parameters.
+        - text: The input text for synthesis.
+        - gender: 'male' or 'female'.
+        - pitch/speed: Ranges mapped by LEVELS_MAP_UI.
+        """
+        pitch_val = LEVELS_MAP_UI[int(pitch)]
+        speed_val = LEVELS_MAP_UI[int(speed)]
+        audio_output_path = run_tts(
+            text,
+            model,
+            gender=gender,
+            pitch=pitch_val,
+            speed=speed_val
+        )
+        return audio_output_path
+
     with gr.Blocks() as demo:
-        # Initialize model
-        model = initialize_model(model_dir, device=device)
         # Use HTML for centered title
         gr.HTML('<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>')
         with gr.Tabs():
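The substance of the OOM fix is visible in the hunk above: the model is now created once in `build_ui`, and the `voice_clone` / `voice_creation` callbacks simply close over that single instance instead of the model being passed in through `gr.State` and returned from every callback. Since Gradio keeps `gr.State` values per user session, holding a large TTS model in state can multiply memory use; a closure over one shared instance avoids that. A minimal sketch of the pattern, independent of Spark-TTS (the `FakeTTS` class and its `synthesize` method are hypothetical stand-ins, not part of the repository):

```python
import gradio as gr


class FakeTTS:
    """Hypothetical stand-in for a heavyweight TTS model: load it once, reuse it."""

    def synthesize(self, text: str) -> str:
        return f"synthesized: {text}"


def build_ui():
    # Single shared instance, created once when the UI is built.
    model = FakeTTS()

    def generate(text):
        # The callback closes over `model`: no gr.State(model) in inputs,
        # and the model is never returned as an output.
        return model.synthesize(text)

    with gr.Blocks() as demo:
        text_in = gr.Textbox(label="Text")
        result = gr.Textbox(label="Result")
        gr.Button("Generate").click(generate, inputs=[text_in], outputs=[result])
    return demo


if __name__ == "__main__":
    build_ui().launch()
```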
@@ -141,12 +160,11 @@ def build_ui(model_dir, device=0):
                     voice_clone,
                     inputs=[
                         text_input,
-                        gr.State(model),
                         prompt_text_input,
                         prompt_wav_upload,
                         prompt_wav_record,
                     ],
-                    outputs=[audio_output
+                    outputs=[audio_output],
                 )

             # Voice Creation Tab

@@ -180,13 +198,56 @@ def build_ui(model_dir, device=0):
                 )
                 create_button.click(
                     voice_creation,
-                    inputs=[text_input_creation,
-                    outputs=[audio_output
+                    inputs=[text_input_creation, gender, pitch, speed],
+                    outputs=[audio_output],
                 )

     return demo


+def parse_arguments():
+    """
+    Parse command-line arguments such as model directory and device ID.
+    """
+    parser = argparse.ArgumentParser(description="Spark TTS Gradio server.")
+    parser.add_argument(
+        "--model_dir",
+        type=str,
+        default="pretrained_models/Spark-TTS-0.5B",
+        help="Path to the model directory."
+    )
+    parser.add_argument(
+        "--device",
+        type=int,
+        default=0,
+        help="ID of the GPU device to use (e.g., 0 for cuda:0)."
+    )
+    parser.add_argument(
+        "--server_name",
+        type=str,
+        default="0.0.0.0",
+        help="Server host/IP for Gradio app."
+    )
+    parser.add_argument(
+        "--server_port",
+        type=int,
+        default=7860,
+        help="Server port for Gradio app."
+    )
+    return parser.parse_args()
+
 if __name__ == "__main__":
-
-
+    # Parse command-line arguments
+    args = parse_arguments()
+
+    # Build the Gradio demo by specifying the model directory and GPU device
+    demo = build_ui(
+        model_dir=args.model_dir,
+        device=args.device
+    )
+
+    # Launch Gradio with the specified server name and port
+    demo.launch(
+        server_name=args.server_name,
+        server_port=args.server_port
+    )
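With the argparse options above, the web UI can be launched with an explicit model path, GPU, and bind address. A typical invocation; every value shown is just the default from `parse_arguments`, so all flags are optional:

```
python webui.py \
    --model_dir pretrained_models/Spark-TTS-0.5B \
    --device 0 \
    --server_name 0.0.0.0 \
    --server_port 7860
```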