hexgrad committed on
Commit
6459fb3
·
verified ·
1 Parent(s): eff6dd2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -20
app.py CHANGED
@@ -9,6 +9,7 @@ import pypdf
9
  import random
10
  import re
11
  import spaces
 
12
  import torch
13
  import yaml
14
 
@@ -253,7 +254,7 @@ def generate(text, voice='af', ps=None, speed=1, trim=3000, use_gpu='auto'):
253
  except gr.exceptions.Error as e:
254
  if use_gpu:
255
  gr.Warning(str(e))
256
- gr.Info('GPU failover to CPU')
257
  out = forward(tokens, voices, speed)
258
  else:
259
  raise gr.Error(e)
@@ -323,10 +324,12 @@ with gr.Blocks() as basic_tts:
323
  generate_btn.click(generate, inputs=[text, voice, in_ps, speed, trim, use_gpu], outputs=[audio, out_ps])
324
 
325
  @torch.no_grad()
326
- def lf_forward(token_lists, voices, speed, device='cpu'):
327
  voicepack = torch.mean(torch.stack([VOICES[device][v] for v in voices]), dim=0)
328
  outs = []
329
  for tokens in token_lists:
 
 
330
  ref_s = voicepack[len(tokens)]
331
  s = ref_s[:, 128:]
332
  tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
@@ -352,8 +355,8 @@ def lf_forward(token_lists, voices, speed, device='cpu'):
352
  return outs
353
 
354
  @spaces.GPU
355
- def lf_forward_gpu(token_lists, voices, speed):
356
- return lf_forward(token_lists, voices, speed, device='cuda')
357
 
358
  def resplit_strings(arr):
359
  # Handle edge cases
@@ -407,35 +410,45 @@ def segment_and_tokenize(text, voice, skip_square_brackets=True, newline_split=2
407
  segments = [row for t in texts for row in recursive_split(t, voice)]
408
  return [(i, *row) for i, row in enumerate(segments)]
409
 
410
- def lf_generate(segments, voice, speed=1, trim=0, pad_between=0, use_gpu=True):
 
 
411
  token_lists = list(map(tokenize, segments['Tokens']))
412
  voices = resolve_voices(voice)
413
  speed = clamp_speed(speed)
414
- wavs = []
415
  trim = int(trim / speed)
416
  pad_between = int(pad_between / speed)
417
- batch_size = 100
418
- for i in range(0, len(token_lists), batch_size):
 
 
 
419
  try:
420
  if use_gpu:
421
- outs = lf_forward_gpu(token_lists[i:i+batch_size], voices, speed)
422
  else:
423
- outs = lf_forward(token_lists[i:i+batch_size], voices, speed)
424
  except gr.exceptions.Error as e:
425
- if wavs:
426
  gr.Warning(str(e))
 
 
 
427
  else:
428
  raise gr.Error(e)
429
- break
430
  for out in outs:
431
  if trim > 0:
432
  if trim * 2 >= len(out):
433
  continue
434
  out = out[trim:-trim]
435
- if wavs and pad_between > 0:
436
- wavs.append(np.zeros(pad_between))
437
- wavs.append(out)
438
- return (SAMPLE_RATE, np.concatenate(wavs)) if wavs else None
 
 
 
 
439
 
440
  def did_change_segments(segments):
441
  x = len(segments) if segments['Length'].any() else 0
@@ -473,23 +486,27 @@ with gr.Blocks() as lf_tts:
473
  skip_square_brackets = gr.Checkbox(True, label='Skip [Square Brackets]', info='Recommended for academic papers, Wikipedia articles, or texts with citations')
474
  newline_split = gr.Number(2, label='Newline Split', info='Split the input text on this many newlines. Affects how the text is segmented.', precision=0, minimum=0)
475
  with gr.Row():
 
476
  segment_btn = gr.Button('Tokenize', variant='primary')
477
- generate_btn = gr.Button('Generate x0', variant='secondary', interactive=False)
478
  with gr.Column():
479
- audio = gr.Audio(interactive=False, label='Output Audio')
480
  with gr.Accordion('Audio Settings', open=True):
481
  speed = gr.Slider(minimum=0.5, maximum=2, value=1, step=0.1, label='⚡️ Speed', info='Adjust the speaking speed')
482
  trim = gr.Slider(minimum=0, maximum=24000, value=0, step=1000, label='✂️ Trim', info='Cut from both ends')
483
  pad_between = gr.Slider(minimum=0, maximum=24000, value=0, step=1000, label='🔇 Pad Between', info='How much silence to insert between segments')
 
 
 
484
  with gr.Row():
485
  segments = gr.Dataframe(headers=['#', 'Text', 'Tokens', 'Length'], row_count=(1, 'dynamic'), col_count=(4, 'fixed'), label='Segments', interactive=False, wrap=True)
486
  segments.change(fn=did_change_segments, inputs=[segments], outputs=[segment_btn, generate_btn])
487
  segment_btn.click(segment_and_tokenize, inputs=[text, voice, skip_square_brackets, newline_split], outputs=[segments])
488
- generate_btn.click(lf_generate, inputs=[segments, voice, speed, trim, pad_between, use_gpu], outputs=[audio])
 
489
 
490
  with gr.Blocks() as about:
491
  gr.Markdown("""
492
- Kokoro is a frontier TTS model for its size. It has [80 million](https://hf.co/spaces/hexgrad/Kokoro-TTS/blob/main/app.py#L31) parameters, uses a lean [StyleTTS 2](https://github.com/yl4579/StyleTTS2) architecture, and was trained on high-quality data. The weights are currently private, but a free public demo is hosted here, at `https://hf.co/spaces/hexgrad/Kokoro-TTS`. The Community tab is open for feature requests, bug reports, etc. For other inquiries, contact `@rzvzn` on Discord.
493
 
494
  ### FAQ
495
  **Will this be open sourced?**<br/>
 
9
  import random
10
  import re
11
  import spaces
12
+ import threading
13
  import torch
14
  import yaml
15
 
 
254
  except gr.exceptions.Error as e:
255
  if use_gpu:
256
  gr.Warning(str(e))
257
+ gr.Info('Switching to CPU')
258
  out = forward(tokens, voices, speed)
259
  else:
260
  raise gr.Error(e)
 
324
  generate_btn.click(generate, inputs=[text, voice, in_ps, speed, trim, use_gpu], outputs=[audio, out_ps])
325
 
326
  @torch.no_grad()
327
+ def lf_forward(token_lists, voices, speed, stop_event, device='cpu'):
328
  voicepack = torch.mean(torch.stack([VOICES[device][v] for v in voices]), dim=0)
329
  outs = []
330
  for tokens in token_lists:
331
+ if stop_event.is_set():
332
+ break
333
  ref_s = voicepack[len(tokens)]
334
  s = ref_s[:, 128:]
335
  tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
 
355
  return outs
356
 
357
  @spaces.GPU
358
+ def lf_forward_gpu(token_lists, voices, speed, stop_event):
359
+ return lf_forward(token_lists, voices, speed, stop_event, device='cuda')
360
 
361
  def resplit_strings(arr):
362
  # Handle edge cases
 
410
  segments = [row for t in texts for row in recursive_split(t, voice)]
411
  return [(i, *row) for i, row in enumerate(segments)]
412
 
413
+ def lf_generate(segments, voice, speed=1, trim=0, pad_between=0, use_gpu=True, audio_stream=None):
414
+ if audio_stream is not None and len(audio_stream) == 3:
415
+ audio_stream[-1].set()
416
  token_lists = list(map(tokenize, segments['Tokens']))
417
  voices = resolve_voices(voice)
418
  speed = clamp_speed(speed)
 
419
  trim = int(trim / speed)
420
  pad_between = int(pad_between / speed)
421
+ batch_sizes = [89, 55, 34, 21, 13, 8, 5, 3, 2, 1, 1]
422
+ i = 0
423
+ stop_event = threading.Event()
424
+ while i < len(token_lists):
425
+ bs = batch_sizes.pop() if batch_sizes else 100
426
  try:
427
  if use_gpu:
428
+ outs = lf_forward_gpu(token_lists[i:i+bs], voices, speed, stop_event)
429
  else:
430
+ outs = lf_forward(token_lists[i:i+bs], voices, speed, stop_event)
431
  except gr.exceptions.Error as e:
432
+ if use_gpu:
433
  gr.Warning(str(e))
434
+ gr.Info('Switching to CPU')
435
+ outs = lf_forward(token_lists[i:i+bs], voices, speed, stop_event)
436
+ use_gpu = False
437
  else:
438
  raise gr.Error(e)
 
439
  for out in outs:
440
  if trim > 0:
441
  if trim * 2 >= len(out):
442
  continue
443
  out = out[trim:-trim]
444
+ if i > 0 and pad_between > 0:
445
+ yield SAMPLE_RATE, np.zeros(pad_between), stop_event
446
+ yield SAMPLE_RATE, out, stop_event
447
+ i += bs
448
+
449
+ def lf_stop(audio_stream):
450
+ if audio_stream is not None and len(audio_stream) == 3:
451
+ audio_stream[-1].set()
452
 
453
  def did_change_segments(segments):
454
  x = len(segments) if segments['Length'].any() else 0
 
486
  skip_square_brackets = gr.Checkbox(True, label='Skip [Square Brackets]', info='Recommended for academic papers, Wikipedia articles, or texts with citations')
487
  newline_split = gr.Number(2, label='Newline Split', info='Split the input text on this many newlines. Affects how the text is segmented.', precision=0, minimum=0)
488
  with gr.Row():
489
+ upload_btn = gr.Button('Upload', variant='secondary')
490
  segment_btn = gr.Button('Tokenize', variant='primary')
 
491
  with gr.Column():
492
+ audio_stream = gr.Audio(label='Output Audio Stream', interactive=False, streaming=True, autoplay=True)
493
  with gr.Accordion('Audio Settings', open=True):
494
  speed = gr.Slider(minimum=0.5, maximum=2, value=1, step=0.1, label='⚡️ Speed', info='Adjust the speaking speed')
495
  trim = gr.Slider(minimum=0, maximum=24000, value=0, step=1000, label='✂️ Trim', info='Cut from both ends')
496
  pad_between = gr.Slider(minimum=0, maximum=24000, value=0, step=1000, label='🔇 Pad Between', info='How much silence to insert between segments')
497
+ with gr.Row():
498
+ generate_btn = gr.Button('Generate x0', variant='secondary', interactive=False)
499
+ stop_btn = gr.Button('Stop', variant='stop')
500
  with gr.Row():
501
  segments = gr.Dataframe(headers=['#', 'Text', 'Tokens', 'Length'], row_count=(1, 'dynamic'), col_count=(4, 'fixed'), label='Segments', interactive=False, wrap=True)
502
  segments.change(fn=did_change_segments, inputs=[segments], outputs=[segment_btn, generate_btn])
503
  segment_btn.click(segment_and_tokenize, inputs=[text, voice, skip_square_brackets, newline_split], outputs=[segments])
504
+ generate_btn.click(lf_generate, inputs=[segments, voice, speed, trim, pad_between, use_gpu, audio_stream], outputs=[audio_stream])
505
+ stop_btn.click(lf_stop, inputs=[audio_stream], outputs=[audio_stream])
506
 
507
  with gr.Blocks() as about:
508
  gr.Markdown("""
509
+ Kokoro is a frontier TTS model for its size. It has [80 million](https://hf.co/spaces/hexgrad/Kokoro-TTS/blob/main/app.py#L32) parameters, uses a lean [StyleTTS 2](https://github.com/yl4579/StyleTTS2) architecture, and was trained on high-quality data. The weights are currently private, but a free public demo is hosted here, at `https://hf.co/spaces/hexgrad/Kokoro-TTS`. The Community tab is open for feature requests, bug reports, etc. For other inquiries, contact `@rzvzn` on Discord.
510
 
511
  ### FAQ
512
  **Will this be open sourced?**<br/>