Spaces:
Running
Running
mrfakename
commited on
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- inference-cli.py +77 -116
- inference-cli.toml +1 -1
inference-cli.py
CHANGED
@@ -93,17 +93,6 @@ wave_path = Path(output_dir)/"out.wav"
|
|
93 |
spectrogram_path = Path(output_dir)/"out.png"
|
94 |
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
95 |
|
96 |
-
SPLIT_WORDS = [
|
97 |
-
"but", "however", "nevertheless", "yet", "still",
|
98 |
-
"therefore", "thus", "hence", "consequently",
|
99 |
-
"moreover", "furthermore", "additionally",
|
100 |
-
"meanwhile", "alternatively", "otherwise",
|
101 |
-
"namely", "specifically", "for example", "such as",
|
102 |
-
"in fact", "indeed", "notably",
|
103 |
-
"in contrast", "on the other hand", "conversely",
|
104 |
-
"in conclusion", "to summarize", "finally"
|
105 |
-
]
|
106 |
-
|
107 |
device = (
|
108 |
"cuda"
|
109 |
if torch.cuda.is_available()
|
@@ -167,103 +156,36 @@ F5TTS_model_cfg = dict(
|
|
167 |
)
|
168 |
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
169 |
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
current_word_part = ""
|
185 |
-
word_batches = []
|
186 |
-
for word in words:
|
187 |
-
if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
|
188 |
-
current_word_part += word + ' '
|
189 |
-
else:
|
190 |
-
if current_word_part:
|
191 |
-
# Try to find a suitable split word
|
192 |
-
for split_word in split_words:
|
193 |
-
split_index = current_word_part.rfind(' ' + split_word + ' ')
|
194 |
-
if split_index != -1:
|
195 |
-
word_batches.append(current_word_part[:split_index].strip())
|
196 |
-
current_word_part = current_word_part[split_index:].strip() + ' '
|
197 |
-
break
|
198 |
-
else:
|
199 |
-
# If no suitable split word found, just append the current part
|
200 |
-
word_batches.append(current_word_part.strip())
|
201 |
-
current_word_part = ""
|
202 |
-
current_word_part += word + ' '
|
203 |
-
if current_word_part:
|
204 |
-
word_batches.append(current_word_part.strip())
|
205 |
-
return word_batches
|
206 |
|
207 |
for sentence in sentences:
|
208 |
-
if len(
|
209 |
-
|
210 |
else:
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
if len(colon_parts) > 1:
|
221 |
-
for part in colon_parts:
|
222 |
-
if len(part.encode('utf-8')) <= max_chars:
|
223 |
-
batches.append(part)
|
224 |
-
else:
|
225 |
-
# If colon part is still too long, split by comma
|
226 |
-
comma_parts = re.split('[,,]', part)
|
227 |
-
if len(comma_parts) > 1:
|
228 |
-
current_comma_part = ""
|
229 |
-
for comma_part in comma_parts:
|
230 |
-
if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
|
231 |
-
current_comma_part += comma_part + ','
|
232 |
-
else:
|
233 |
-
if current_comma_part:
|
234 |
-
batches.append(current_comma_part.rstrip(','))
|
235 |
-
current_comma_part = comma_part + ','
|
236 |
-
if current_comma_part:
|
237 |
-
batches.append(current_comma_part.rstrip(','))
|
238 |
-
else:
|
239 |
-
# If no comma, split by words
|
240 |
-
batches.extend(split_by_words(part))
|
241 |
-
else:
|
242 |
-
# If no colon, split by comma
|
243 |
-
comma_parts = re.split('[,,]', sentence)
|
244 |
-
if len(comma_parts) > 1:
|
245 |
-
current_comma_part = ""
|
246 |
-
for comma_part in comma_parts:
|
247 |
-
if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
|
248 |
-
current_comma_part += comma_part + ','
|
249 |
-
else:
|
250 |
-
if current_comma_part:
|
251 |
-
batches.append(current_comma_part.rstrip(','))
|
252 |
-
current_comma_part = comma_part + ','
|
253 |
-
if current_comma_part:
|
254 |
-
batches.append(current_comma_part.rstrip(','))
|
255 |
-
else:
|
256 |
-
# If no comma, split by words
|
257 |
-
batches.extend(split_by_words(sentence))
|
258 |
-
else:
|
259 |
-
current_batch = sentence
|
260 |
-
|
261 |
-
if current_batch:
|
262 |
-
batches.append(current_batch)
|
263 |
-
|
264 |
-
return batches
|
265 |
|
266 |
-
def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
|
267 |
if model == "F5-TTS":
|
268 |
ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
|
269 |
elif model == "E2-TTS":
|
@@ -321,8 +243,44 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
|
|
321 |
generated_waves.append(generated_wave)
|
322 |
spectrograms.append(generated_mel_spec[0].cpu().numpy())
|
323 |
|
324 |
-
# Combine all generated waves
|
325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
|
327 |
with open(wave_path, "wb") as f:
|
328 |
sf.write(f.name, final_wave, target_sample_rate)
|
@@ -343,11 +301,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
|
|
343 |
print(spectrogram_path)
|
344 |
|
345 |
|
346 |
-
def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence,
|
347 |
-
if not custom_split_words.strip():
|
348 |
-
custom_words = [word.strip() for word in custom_split_words.split(',')]
|
349 |
-
global SPLIT_WORDS
|
350 |
-
SPLIT_WORDS = custom_words
|
351 |
|
352 |
print(gen_text)
|
353 |
|
@@ -355,7 +309,7 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
|
|
355 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
356 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
357 |
|
358 |
-
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=
|
359 |
non_silent_wave = AudioSegment.silent(duration=0)
|
360 |
for non_silent_seg in non_silent_segs:
|
361 |
non_silent_wave += non_silent_seg
|
@@ -387,16 +341,23 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
|
|
387 |
else:
|
388 |
print("Using custom reference text...")
|
389 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
390 |
# Split the input text into batches
|
391 |
audio, sr = torchaudio.load(ref_audio)
|
392 |
-
max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (
|
393 |
-
gen_text_batches =
|
394 |
print('ref_text', ref_text)
|
395 |
for i, gen_text in enumerate(gen_text_batches):
|
396 |
print(f'gen_text {i}', gen_text)
|
397 |
|
398 |
print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
|
399 |
-
return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
|
400 |
|
401 |
|
402 |
-
infer(ref_audio, ref_text, gen_text, model, remove_silence
|
|
|
93 |
spectrogram_path = Path(output_dir)/"out.png"
|
94 |
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
device = (
|
97 |
"cuda"
|
98 |
if torch.cuda.is_available()
|
|
|
156 |
)
|
157 |
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
158 |
|
159 |
+
|
160 |
+
def chunk_text(text, max_chars=135):
|
161 |
+
"""
|
162 |
+
Splits the input text into chunks, each with a maximum number of characters.
|
163 |
+
Args:
|
164 |
+
text (str): The text to be split.
|
165 |
+
max_chars (int): The maximum number of characters per chunk.
|
166 |
+
Returns:
|
167 |
+
List[str]: A list of text chunks.
|
168 |
+
"""
|
169 |
+
chunks = []
|
170 |
+
current_chunk = ""
|
171 |
+
# Split the text into sentences based on punctuation followed by whitespace
|
172 |
+
sentences = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
for sentence in sentences:
|
175 |
+
if len(current_chunk.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
|
176 |
+
current_chunk += sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
|
177 |
else:
|
178 |
+
if current_chunk:
|
179 |
+
chunks.append(current_chunk.strip())
|
180 |
+
current_chunk = sentence + " " if sentence and len(sentence[-1].encode('utf-8')) == 1 else sentence
|
181 |
+
|
182 |
+
if current_chunk:
|
183 |
+
chunks.append(current_chunk.strip())
|
184 |
+
|
185 |
+
return chunks
|
186 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
|
188 |
+
def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
|
189 |
if model == "F5-TTS":
|
190 |
ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
|
191 |
elif model == "E2-TTS":
|
|
|
243 |
generated_waves.append(generated_wave)
|
244 |
spectrograms.append(generated_mel_spec[0].cpu().numpy())
|
245 |
|
246 |
+
# Combine all generated waves with cross-fading
|
247 |
+
if cross_fade_duration <= 0:
|
248 |
+
# Simply concatenate
|
249 |
+
final_wave = np.concatenate(generated_waves)
|
250 |
+
else:
|
251 |
+
final_wave = generated_waves[0]
|
252 |
+
for i in range(1, len(generated_waves)):
|
253 |
+
prev_wave = final_wave
|
254 |
+
next_wave = generated_waves[i]
|
255 |
+
|
256 |
+
# Calculate cross-fade samples, ensuring it does not exceed wave lengths
|
257 |
+
cross_fade_samples = int(cross_fade_duration * target_sample_rate)
|
258 |
+
cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
|
259 |
+
|
260 |
+
if cross_fade_samples <= 0:
|
261 |
+
# No overlap possible, concatenate
|
262 |
+
final_wave = np.concatenate([prev_wave, next_wave])
|
263 |
+
continue
|
264 |
+
|
265 |
+
# Overlapping parts
|
266 |
+
prev_overlap = prev_wave[-cross_fade_samples:]
|
267 |
+
next_overlap = next_wave[:cross_fade_samples]
|
268 |
+
|
269 |
+
# Fade out and fade in
|
270 |
+
fade_out = np.linspace(1, 0, cross_fade_samples)
|
271 |
+
fade_in = np.linspace(0, 1, cross_fade_samples)
|
272 |
+
|
273 |
+
# Cross-faded overlap
|
274 |
+
cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
|
275 |
+
|
276 |
+
# Combine
|
277 |
+
new_wave = np.concatenate([
|
278 |
+
prev_wave[:-cross_fade_samples],
|
279 |
+
cross_faded_overlap,
|
280 |
+
next_wave[cross_fade_samples:]
|
281 |
+
])
|
282 |
+
|
283 |
+
final_wave = new_wave
|
284 |
|
285 |
with open(wave_path, "wb") as f:
|
286 |
sf.write(f.name, final_wave, target_sample_rate)
|
|
|
301 |
print(spectrogram_path)
|
302 |
|
303 |
|
304 |
+
def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
|
|
|
|
|
|
|
|
|
305 |
|
306 |
print(gen_text)
|
307 |
|
|
|
309 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
|
310 |
aseg = AudioSegment.from_file(ref_audio_orig)
|
311 |
|
312 |
+
non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
|
313 |
non_silent_wave = AudioSegment.silent(duration=0)
|
314 |
for non_silent_seg in non_silent_segs:
|
315 |
non_silent_wave += non_silent_seg
|
|
|
341 |
else:
|
342 |
print("Using custom reference text...")
|
343 |
|
344 |
+
# Add the functionality to ensure it ends with ". "
|
345 |
+
if not ref_text.endswith(". ") and not ref_text.endswith("。"):
|
346 |
+
if ref_text.endswith("."):
|
347 |
+
ref_text += " "
|
348 |
+
else:
|
349 |
+
ref_text += ". "
|
350 |
+
|
351 |
# Split the input text into batches
|
352 |
audio, sr = torchaudio.load(ref_audio)
|
353 |
+
max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
|
354 |
+
gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
|
355 |
print('ref_text', ref_text)
|
356 |
for i, gen_text in enumerate(gen_text_batches):
|
357 |
print(f'gen_text {i}', gen_text)
|
358 |
|
359 |
print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
|
360 |
+
return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
|
361 |
|
362 |
|
363 |
+
infer(ref_audio, ref_text, gen_text, model, remove_silence)
|
inference-cli.toml
CHANGED
@@ -6,5 +6,5 @@ ref_text = "Some call me nature, others call me mother nature."
|
|
6 |
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
|
7 |
# File with text to generate. Ignores the text above.
|
8 |
gen_file = ""
|
9 |
-
remove_silence =
|
10 |
output_dir = "tests"
|
|
|
6 |
gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
|
7 |
# File with text to generate. Ignores the text above.
|
8 |
gen_file = ""
|
9 |
+
remove_silence = false
|
10 |
output_dir = "tests"
|