mrfakename committed on
Commit
257b408
·
verified ·
1 Parent(s): 971a624

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.

Files changed (2) hide show
  1. inference-cli.py +77 -116
  2. inference-cli.toml +1 -1
inference-cli.py CHANGED
@@ -93,17 +93,6 @@ wave_path = Path(output_dir)/"out.wav"
93
  spectrogram_path = Path(output_dir)/"out.png"
94
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
95
 
96
- SPLIT_WORDS = [
97
- "but", "however", "nevertheless", "yet", "still",
98
- "therefore", "thus", "hence", "consequently",
99
- "moreover", "furthermore", "additionally",
100
- "meanwhile", "alternatively", "otherwise",
101
- "namely", "specifically", "for example", "such as",
102
- "in fact", "indeed", "notably",
103
- "in contrast", "on the other hand", "conversely",
104
- "in conclusion", "to summarize", "finally"
105
- ]
106
-
107
  device = (
108
  "cuda"
109
  if torch.cuda.is_available()
@@ -167,103 +156,36 @@ F5TTS_model_cfg = dict(
167
  )
168
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
169
 
170
- def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
171
- if len(text.encode('utf-8')) <= max_chars:
172
- return [text]
173
- if text[-1] not in ['。', '.', '!', '!', '?', '?']:
174
- text += '.'
175
-
176
- sentences = re.split('([。.!?!?])', text)
177
- sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
178
-
179
- batches = []
180
- current_batch = ""
181
-
182
- def split_by_words(text):
183
- words = text.split()
184
- current_word_part = ""
185
- word_batches = []
186
- for word in words:
187
- if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
188
- current_word_part += word + ' '
189
- else:
190
- if current_word_part:
191
- # Try to find a suitable split word
192
- for split_word in split_words:
193
- split_index = current_word_part.rfind(' ' + split_word + ' ')
194
- if split_index != -1:
195
- word_batches.append(current_word_part[:split_index].strip())
196
- current_word_part = current_word_part[split_index:].strip() + ' '
197
- break
198
- else:
199
- # If no suitable split word found, just append the current part
200
- word_batches.append(current_word_part.strip())
201
- current_word_part = ""
202
- current_word_part += word + ' '
203
- if current_word_part:
204
- word_batches.append(current_word_part.strip())
205
- return word_batches
206
 
207
  for sentence in sentences:
208
- if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
209
- current_batch += sentence
210
  else:
211
- # If adding this sentence would exceed the limit
212
- if current_batch:
213
- batches.append(current_batch)
214
- current_batch = ""
215
-
216
- # If the sentence itself is longer than max_chars, split it
217
- if len(sentence.encode('utf-8')) > max_chars:
218
- # First, try to split by colon
219
- colon_parts = sentence.split(':')
220
- if len(colon_parts) > 1:
221
- for part in colon_parts:
222
- if len(part.encode('utf-8')) <= max_chars:
223
- batches.append(part)
224
- else:
225
- # If colon part is still too long, split by comma
226
- comma_parts = re.split('[,,]', part)
227
- if len(comma_parts) > 1:
228
- current_comma_part = ""
229
- for comma_part in comma_parts:
230
- if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
231
- current_comma_part += comma_part + ','
232
- else:
233
- if current_comma_part:
234
- batches.append(current_comma_part.rstrip(','))
235
- current_comma_part = comma_part + ','
236
- if current_comma_part:
237
- batches.append(current_comma_part.rstrip(','))
238
- else:
239
- # If no comma, split by words
240
- batches.extend(split_by_words(part))
241
- else:
242
- # If no colon, split by comma
243
- comma_parts = re.split('[,,]', sentence)
244
- if len(comma_parts) > 1:
245
- current_comma_part = ""
246
- for comma_part in comma_parts:
247
- if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
248
- current_comma_part += comma_part + ','
249
- else:
250
- if current_comma_part:
251
- batches.append(current_comma_part.rstrip(','))
252
- current_comma_part = comma_part + ','
253
- if current_comma_part:
254
- batches.append(current_comma_part.rstrip(','))
255
- else:
256
- # If no comma, split by words
257
- batches.extend(split_by_words(sentence))
258
- else:
259
- current_batch = sentence
260
-
261
- if current_batch:
262
- batches.append(current_batch)
263
-
264
- return batches
265
 
266
- def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
267
  if model == "F5-TTS":
268
  ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
269
  elif model == "E2-TTS":
@@ -321,8 +243,44 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
321
  generated_waves.append(generated_wave)
322
  spectrograms.append(generated_mel_spec[0].cpu().numpy())
323
 
324
- # Combine all generated waves
325
- final_wave = np.concatenate(generated_waves)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
  with open(wave_path, "wb") as f:
328
  sf.write(f.name, final_wave, target_sample_rate)
@@ -343,11 +301,7 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
343
  print(spectrogram_path)
344
 
345
 
346
- def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
347
- if not custom_split_words.strip():
348
- custom_words = [word.strip() for word in custom_split_words.split(',')]
349
- global SPLIT_WORDS
350
- SPLIT_WORDS = custom_words
351
 
352
  print(gen_text)
353
 
@@ -355,7 +309,7 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
355
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
356
  aseg = AudioSegment.from_file(ref_audio_orig)
357
 
358
- non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
359
  non_silent_wave = AudioSegment.silent(duration=0)
360
  for non_silent_seg in non_silent_segs:
361
  non_silent_wave += non_silent_seg
@@ -387,16 +341,23 @@ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_spli
387
  else:
388
  print("Using custom reference text...")
389
 
 
 
 
 
 
 
 
390
  # Split the input text into batches
391
  audio, sr = torchaudio.load(ref_audio)
392
- max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
393
- gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
394
  print('ref_text', ref_text)
395
  for i, gen_text in enumerate(gen_text_batches):
396
  print(f'gen_text {i}', gen_text)
397
 
398
  print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
399
- return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
400
 
401
 
402
- infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
 
93
  spectrogram_path = Path(output_dir)/"out.png"
94
  vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
95
 
 
 
 
 
 
 
 
 
 
 
 
96
  device = (
97
  "cuda"
98
  if torch.cuda.is_available()
 
156
  )
157
  E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
158
 
159
+
160
def chunk_text(text, max_chars=135):
    """
    Break ``text`` into chunks whose UTF-8 encoded size fits within ``max_chars``.

    The text is first cut at punctuation marks, and the resulting fragments are
    then packed greedily, in order, into chunks. A fragment that is larger than
    the budget on its own is still emitted, as a single oversized chunk.

    Args:
        text (str): The text to be split.
        max_chars (int): Size budget per chunk, measured in UTF-8 encoded bytes
            (so a multi-byte character counts for more than one).

    Returns:
        List[str]: The chunks, each stripped of leading/trailing whitespace.
    """
    def _byte_len(s):
        # Sizes are compared in encoded bytes, not code points.
        return len(s.encode('utf-8'))

    # Cut after punctuation: either consume the whitespace that follows it, or
    # split with a zero-width match directly behind the mark.
    pieces = re.split(r'(?<=[;:,.!?])\s+|(?<=[;:,。!?])', text)

    result = []
    buffer = ""
    for piece in pieces:
        # Fragments ending in a single-byte character get a separating space
        # re-attached; fragments ending in a multi-byte character do not.
        if piece and _byte_len(piece[-1]) == 1:
            padded = piece + " "
        else:
            padded = piece

        if _byte_len(buffer) + _byte_len(piece) > max_chars:
            # The fragment does not fit: flush what has accumulated (if
            # anything) and start a new chunk with this fragment.
            if buffer:
                result.append(buffer.strip())
            buffer = padded
        else:
            buffer += padded

    # Flush the trailing, partially filled chunk.
    if buffer:
        result.append(buffer.strip())

    return result
186
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
+ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
189
  if model == "F5-TTS":
190
  ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
191
  elif model == "E2-TTS":
 
243
  generated_waves.append(generated_wave)
244
  spectrograms.append(generated_mel_spec[0].cpu().numpy())
245
 
246
+ # Combine all generated waves with cross-fading
247
+ if cross_fade_duration <= 0:
248
+ # Simply concatenate
249
+ final_wave = np.concatenate(generated_waves)
250
+ else:
251
+ final_wave = generated_waves[0]
252
+ for i in range(1, len(generated_waves)):
253
+ prev_wave = final_wave
254
+ next_wave = generated_waves[i]
255
+
256
+ # Calculate cross-fade samples, ensuring it does not exceed wave lengths
257
+ cross_fade_samples = int(cross_fade_duration * target_sample_rate)
258
+ cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
259
+
260
+ if cross_fade_samples <= 0:
261
+ # No overlap possible, concatenate
262
+ final_wave = np.concatenate([prev_wave, next_wave])
263
+ continue
264
+
265
+ # Overlapping parts
266
+ prev_overlap = prev_wave[-cross_fade_samples:]
267
+ next_overlap = next_wave[:cross_fade_samples]
268
+
269
+ # Fade out and fade in
270
+ fade_out = np.linspace(1, 0, cross_fade_samples)
271
+ fade_in = np.linspace(0, 1, cross_fade_samples)
272
+
273
+ # Cross-faded overlap
274
+ cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
275
+
276
+ # Combine
277
+ new_wave = np.concatenate([
278
+ prev_wave[:-cross_fade_samples],
279
+ cross_faded_overlap,
280
+ next_wave[cross_fade_samples:]
281
+ ])
282
+
283
+ final_wave = new_wave
284
 
285
  with open(wave_path, "wb") as f:
286
  sf.write(f.name, final_wave, target_sample_rate)
 
301
  print(spectrogram_path)
302
 
303
 
304
+ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
 
 
 
 
305
 
306
  print(gen_text)
307
 
 
309
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
310
  aseg = AudioSegment.from_file(ref_audio_orig)
311
 
312
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=1000)
313
  non_silent_wave = AudioSegment.silent(duration=0)
314
  for non_silent_seg in non_silent_segs:
315
  non_silent_wave += non_silent_seg
 
341
  else:
342
  print("Using custom reference text...")
343
 
344
+ # Add the functionality to ensure it ends with ". "
345
+ if not ref_text.endswith(". ") and not ref_text.endswith("。"):
346
+ if ref_text.endswith("."):
347
+ ref_text += " "
348
+ else:
349
+ ref_text += ". "
350
+
351
  # Split the input text into batches
352
  audio, sr = torchaudio.load(ref_audio)
353
+ max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (25 - audio.shape[-1] / sr))
354
+ gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
355
  print('ref_text', ref_text)
356
  for i, gen_text in enumerate(gen_text_batches):
357
  print(f'gen_text {i}', gen_text)
358
 
359
  print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
360
+ return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
361
 
362
 
363
+ infer(ref_audio, ref_text, gen_text, model, remove_silence)
inference-cli.toml CHANGED
@@ -6,5 +6,5 @@ ref_text = "Some call me nature, others call me mother nature."
6
  gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
7
  # File with text to generate. Ignores the text above.
8
  gen_file = ""
9
- remove_silence = true
10
  output_dir = "tests"
 
6
  gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
7
  # File with text to generate. Ignores the text above.
8
  gen_file = ""
9
+ remove_silence = false
10
  output_dir = "tests"