Commit 8db57fe (parent: fc905b2), committed by spark-tts

fix webui OOM bug

Files changed (2):
  1. README.md +2 -2
  2. webui.py +92 -31
README.md CHANGED
@@ -103,9 +103,9 @@ python -m cli.inference \
     --prompt_speech_path "path/to/prompt_audio"
 ```
 
-**UI Usage**
+**Web UI Usage**
 
-You can start the UI interface by running `python webui.py`, which allows you to perform Voice Cloning and Voice Creation. Voice Cloning supports uploading reference audio or directly recording the audio.
+You can start the UI interface by running `python webui.py --device 0`, which allows you to perform Voice Cloning and Voice Creation. Voice Cloning supports uploading reference audio or directly recording the audio.
 
 
 | **Voice Cloning** | **Voice Creation** |
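
For reference, the same demo can also be started from Python instead of the command line. This is a minimal sketch based only on what is visible in the webui.py diff below (`build_ui(model_dir, device)` and `demo.launch(server_name, server_port)`); it is not an official entry point, and the import assumes you run it from the repository root.

```python
# Minimal sketch, not an official API: it mirrors the defaults that appear
# in the webui.py diff below (model_dir, device, server_name, server_port).
from webui import build_ui  # assumes the repo root is on sys.path

demo = build_ui(model_dir="pretrained_models/Spark-TTS-0.5B", device=0)
demo.launch(server_name="0.0.0.0", server_port=7860)
```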
webui.py CHANGED
@@ -17,6 +17,7 @@ import os
 import torch
 import soundfile as sf
 import logging
+import argparse
 import gradio as gr
 from datetime import datetime
 from cli.SparkTTS import SparkTTS
@@ -71,35 +72,53 @@ def run_tts(
 
     logging.info(f"Audio saved at: {save_path}")
 
-    return save_path, model  # Return model along with audio path
-
-
-def voice_clone(text, model, prompt_text, prompt_wav_upload, prompt_wav_record):
-    """Gradio interface for TTS with prompt speech input."""
-    # Determine prompt speech (from audio file or recording)
-    prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
-    prompt_text = None if len(prompt_text) < 2 else prompt_text
-    audio_output_path, model = run_tts(
-        text, model, prompt_text=prompt_text, prompt_speech=prompt_speech
-    )
-
-    return audio_output_path, model
-
-
-def voice_creation(text, model, gender, pitch, speed):
-    """Gradio interface for TTS with control over voice attributes."""
-    pitch = LEVELS_MAP_UI[int(pitch)]
-    speed = LEVELS_MAP_UI[int(speed)]
-    audio_output_path, model = run_tts(
-        text, model, gender=gender, pitch=pitch, speed=speed
-    )
-    return audio_output_path, model
+    return save_path
 
 
 def build_ui(model_dir, device=0):
+
+    # Initialize model
+    model = initialize_model(model_dir, device=device)
+
+    # Define callback function for voice cloning
+    def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
+        """
+        Gradio callback to clone voice using text and optional prompt speech.
+        - text: The input text to be synthesised.
+        - prompt_text: Additional textual info for the prompt (optional).
+        - prompt_wav_upload/prompt_wav_record: Audio files used as reference.
+        """
+        prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
+        prompt_text_clean = None if len(prompt_text) < 2 else prompt_text
+
+        audio_output_path = run_tts(
+            text,
+            model,
+            prompt_text=prompt_text_clean,
+            prompt_speech=prompt_speech
+        )
+        return audio_output_path
+
+    # Define callback function for creating new voices
+    def voice_creation(text, gender, pitch, speed):
+        """
+        Gradio callback to create a synthetic voice with adjustable parameters.
+        - text: The input text for synthesis.
+        - gender: 'male' or 'female'.
+        - pitch/speed: Ranges mapped by LEVELS_MAP_UI.
+        """
+        pitch_val = LEVELS_MAP_UI[int(pitch)]
+        speed_val = LEVELS_MAP_UI[int(speed)]
+        audio_output_path = run_tts(
+            text,
+            model,
+            gender=gender,
+            pitch=pitch_val,
+            speed=speed_val
+        )
+        return audio_output_path
+
     with gr.Blocks() as demo:
-        # Initialize model
-        model = initialize_model(model_dir, device=device)
         # Use HTML for centered title
         gr.HTML('<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>')
         with gr.Tabs():
@@ -141,12 +160,11 @@ def build_ui(model_dir, device=0):
                     voice_clone,
                     inputs=[
                         text_input,
-                        gr.State(model),
                         prompt_text_input,
                         prompt_wav_upload,
                         prompt_wav_record,
                     ],
-                    outputs=[audio_output, gr.State(model)],
+                    outputs=[audio_output],
                 )
 
             # Voice Creation Tab
@@ -180,13 +198,56 @@ def build_ui(model_dir, device=0):
                 )
                 create_button.click(
                     voice_creation,
-                    inputs=[text_input_creation, gr.State(model), gender, pitch, speed],
-                    outputs=[audio_output, gr.State(model)],
+                    inputs=[text_input_creation, gender, pitch, speed],
+                    outputs=[audio_output],
                 )
 
     return demo
 
 
+def parse_arguments():
+    """
+    Parse command-line arguments such as model directory and device ID.
+    """
+    parser = argparse.ArgumentParser(description="Spark TTS Gradio server.")
+    parser.add_argument(
+        "--model_dir",
+        type=str,
+        default="pretrained_models/Spark-TTS-0.5B",
+        help="Path to the model directory."
+    )
+    parser.add_argument(
+        "--device",
+        type=int,
+        default=0,
+        help="ID of the GPU device to use (e.g., 0 for cuda:0)."
+    )
+    parser.add_argument(
+        "--server_name",
+        type=str,
+        default="0.0.0.0",
+        help="Server host/IP for Gradio app."
+    )
+    parser.add_argument(
+        "--server_port",
+        type=int,
+        default=7860,
+        help="Server port for Gradio app."
+    )
+    return parser.parse_args()
+
 if __name__ == "__main__":
-    demo = build_ui(model_dir="pretrained_models/Spark-TTS-0.5B", device=0)
-    demo.launch(server_name="0.0.0.0")
+    # Parse command-line arguments
+    args = parse_arguments()
+
+    # Build the Gradio demo by specifying the model directory and GPU device
+    demo = build_ui(
+        model_dir=args.model_dir,
+        device=args.device
+    )
+
+    # Launch Gradio with the specified server name and port
+    demo.launch(
+        server_name=args.server_name,
+        server_port=args.server_port
+    )