spark-tts commited on
Commit
ad5cf60
·
1 Parent(s): 6f15685
Files changed (3) hide show
  1. .gitignore +1 -1
  2. sparktts/utils/token_parser.py +8 -0
  3. webui.py +192 -0
.gitignore CHANGED
@@ -7,7 +7,7 @@ results/
7
  demo/
8
  # C extensions
9
  *.so
10
-
11
  # Distribution / packaging
12
  .Python
13
  build/
 
7
  demo/
8
  # C extensions
9
  *.so
10
+ .gradio/
11
  # Distribution / packaging
12
  .Python
13
  build/
sparktts/utils/token_parser.py CHANGED
@@ -19,6 +19,14 @@ LEVELS_MAP = {
19
  "very_high": 4,
20
  }
21
 
 
 
 
 
 
 
 
 
22
  GENDER_MAP = {
23
  "female": 0,
24
  "male": 1,
 
19
  "very_high": 4,
20
  }
21
 
22
+ LEVELS_MAP_UI = {
23
+ 1: 'very_low',
24
+ 2: 'low',
25
+ 3: 'moderate',
26
+ 4: 'high',
27
+ 5: 'very_high'
28
+ }
29
+
30
  GENDER_MAP = {
31
  "female": 0,
32
  "male": 1,
webui.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 SparkAudio
2
+ # 2025 Xinsheng Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import torch
18
+ import soundfile as sf
19
+ import logging
20
+ import gradio as gr
21
+ from datetime import datetime
22
+ from cli.SparkTTS import SparkTTS
23
+ from sparktts.utils.token_parser import LEVELS_MAP_UI
24
+
25
+
26
+ def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0):
27
+ """Load the model once at the beginning."""
28
+ logging.info(f"Loading model from: {model_dir}")
29
+ device = torch.device(f"cuda:{device}")
30
+ model = SparkTTS(model_dir, device)
31
+ return model
32
+
33
+
34
+ def run_tts(
35
+ text,
36
+ model,
37
+ prompt_text=None,
38
+ prompt_speech=None,
39
+ gender=None,
40
+ pitch=None,
41
+ speed=None,
42
+ save_dir="example/results",
43
+ ):
44
+ """Perform TTS inference and save the generated audio."""
45
+ logging.info(f"Saving audio to: {save_dir}")
46
+
47
+ if prompt_text is not None:
48
+ prompt_text = None if len(prompt_text) <= 1 else prompt_text
49
+
50
+ # Ensure the save directory exists
51
+ os.makedirs(save_dir, exist_ok=True)
52
+
53
+ # Generate unique filename using timestamp
54
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
55
+ save_path = os.path.join(save_dir, f"{timestamp}.wav")
56
+
57
+ logging.info("Starting inference...")
58
+
59
+ # Perform inference and save the output audio
60
+ with torch.no_grad():
61
+ wav = model.inference(
62
+ text,
63
+ prompt_speech,
64
+ prompt_text,
65
+ gender,
66
+ pitch,
67
+ speed,
68
+ )
69
+
70
+ sf.write(save_path, wav, samplerate=16000)
71
+
72
+ logging.info(f"Audio saved at: {save_path}")
73
+
74
+ return save_path, model # Return model along with audio path
75
+
76
+
77
+ def voice_clone(text, model, prompt_text, prompt_wav_upload, prompt_wav_record):
78
+ """Gradio interface for TTS with prompt speech input."""
79
+ # Determine prompt speech (from audio file or recording)
80
+ prompt_speech = prompt_wav_upload if prompt_wav_upload else prompt_wav_record
81
+ prompt_text = None if len(prompt_text) < 2 else prompt_text
82
+ audio_output_path, model = run_tts(
83
+ text, model, prompt_text=prompt_text, prompt_speech=prompt_speech
84
+ )
85
+
86
+ return audio_output_path, model
87
+
88
+
89
+ def voice_creation(text, model, gender, pitch, speed):
90
+ """Gradio interface for TTS with control over voice attributes."""
91
+ pitch = LEVELS_MAP_UI[int(pitch)]
92
+ speed = LEVELS_MAP_UI[int(speed)]
93
+ audio_output_path, model = run_tts(
94
+ text, model, gender=gender, pitch=pitch, speed=speed
95
+ )
96
+ return audio_output_path, model
97
+
98
+
99
+ def build_ui(model_dir, device=0):
100
+ with gr.Blocks() as demo:
101
+ # Initialize model
102
+ model = initialize_model(model_dir, device=device)
103
+ # Use HTML for centered title
104
+ gr.HTML('<h1 style="text-align: center;">Spark-TTS by SparkAudio</h1>')
105
+ with gr.Tabs():
106
+ # Voice Clone Tab
107
+ with gr.TabItem("Voice Clone"):
108
+ gr.Markdown(
109
+ "### Upload reference audio or recording (上传参考音频或者录音)"
110
+ )
111
+
112
+ with gr.Row():
113
+ prompt_wav_upload = gr.Audio(
114
+ sources="upload",
115
+ type="filepath",
116
+ label="Choose the prompt audio file, ensuring the sampling rate is no lower than 16kHz.",
117
+ )
118
+ prompt_wav_record = gr.Audio(
119
+ sources="microphone",
120
+ type="filepath",
121
+ label="Record the prompt audio file.",
122
+ )
123
+
124
+ with gr.Row():
125
+ text_input = gr.Textbox(
126
+ label="Text", lines=3, placeholder="Enter text here"
127
+ )
128
+ prompt_text_input = gr.Textbox(
129
+ label="Text of prompt speech (Optional; recommended for cloning in the same language.)",
130
+ lines=3,
131
+ placeholder="Enter text of the prompt speech.",
132
+ )
133
+
134
+ audio_output = gr.Audio(
135
+ label="Generated Audio", autoplay=True, streaming=True
136
+ )
137
+
138
+ generate_buttom_clone = gr.Button("Generate")
139
+
140
+ generate_buttom_clone.click(
141
+ voice_clone,
142
+ inputs=[
143
+ text_input,
144
+ gr.State(model),
145
+ prompt_text_input,
146
+ prompt_wav_upload,
147
+ prompt_wav_record,
148
+ ],
149
+ outputs=[audio_output, gr.State(model)],
150
+ )
151
+
152
+ # Voice Creation Tab
153
+ with gr.TabItem("Voice Creation"):
154
+ gr.Markdown(
155
+ "### Create your own voice based on the following parameters"
156
+ )
157
+
158
+ with gr.Row():
159
+ with gr.Column():
160
+ gender = gr.Radio(
161
+ choices=["male", "female"], value="male", label="Gender"
162
+ )
163
+ pitch = gr.Slider(
164
+ minimum=1, maximum=5, step=1, value=3, label="Pitch"
165
+ )
166
+ speed = gr.Slider(
167
+ minimum=1, maximum=5, step=1, value=3, label="Speed"
168
+ )
169
+ with gr.Column():
170
+ text_input_creation = gr.Textbox(
171
+ label="Input Text",
172
+ lines=3,
173
+ placeholder="Enter text here",
174
+ value="You can generate a customized voice by adjusting parameters such as pitch and speed.",
175
+ )
176
+ create_button = gr.Button("Create Voice")
177
+
178
+ audio_output = gr.Audio(
179
+ label="Generated Audio", autoplay=True, streaming=True
180
+ )
181
+ create_button.click(
182
+ voice_creation,
183
+ inputs=[text_input_creation, gr.State(model), gender, pitch, speed],
184
+ outputs=[audio_output, gr.State(model)],
185
+ )
186
+
187
+ return demo
188
+
189
+
190
+ if __name__ == "__main__":
191
+ demo = build_ui(model_dir="pretrained_models/Spark-TTS-0.5B", device=5)
192
+ demo.launch()