hexgrad commited on
Commit
d989475
1 Parent(s): 8c17d76

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +82 -85
  2. models.py +13 -7
app.py CHANGED
@@ -12,22 +12,22 @@ import spaces
12
  import torch
13
  import yaml
14
 
15
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
16
 
17
  snapshot = snapshot_download(repo_id='hexgrad/kokoro', allow_patterns=['*.pt', '*.pth', '*.yml'], use_auth_token=os.environ['TOKEN'])
18
  config = yaml.safe_load(open(os.path.join(snapshot, 'config.yml')))
19
- model = build_model(config['model_params'])
20
- _ = [model[key].eval() for key in model]
21
- _ = [model[key].to(device) for key in model]
22
  for key, state_dict in torch.load(os.path.join(snapshot, 'net.pth'), map_location='cpu', weights_only=True)['net'].items():
23
- assert key in model, key
24
- try:
25
- model[key].load_state_dict(state_dict)
26
- except:
27
- state_dict = {k[7:]: v for k, v in state_dict.items()}
28
- model[key].load_state_dict(state_dict, strict=False)
 
29
 
30
- PARAM_COUNT = sum(p.numel() for value in model.values() for p in value.parameters())
31
  assert PARAM_COUNT < 82_000_000, PARAM_COUNT
32
 
33
  random_texts = {}
@@ -118,7 +118,7 @@ def phonemize(text, voice, norm=True):
118
  ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
119
  ps = ''.join(filter(lambda p: p in VOCAB, ps))
120
  if lang == 'j' and any(p in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' for p in ps):
121
- gr.Warning('Japanese tokenizer does not handle English letters.')
122
  return ps.strip()
123
 
124
  def length_to_mask(lengths):
@@ -154,7 +154,7 @@ CHOICES = {
154
  '🇬🇧 🚹 Lewis 🧪': 'bm_lewis',
155
  '🇯🇵 🚺 Japanese Female': 'jf_0',
156
  }
157
- VOICES = {k: torch.load(os.path.join(snapshot, 'voicepacks', f'{k}.pt'), weights_only=True).to(device) for k in CHOICES.values()}
158
 
159
  np_log_99 = np.log(99)
160
  def s_curve(p):
@@ -168,19 +168,18 @@ def s_curve(p):
168
 
169
  SAMPLE_RATE = 24000
170
 
171
- @spaces.GPU(duration=10)
172
  @torch.no_grad()
173
- def forward(tokens, voice, speed):
174
- ref_s = VOICES[voice][len(tokens)]
175
  tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
176
  input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
177
  text_mask = length_to_mask(input_lengths).to(device)
178
- bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
179
- d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
180
  s = ref_s[:, 128:]
181
- d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
182
- x, _ = model.predictor.lstm(d)
183
- duration = model.predictor.duration_proj(x)
184
  duration = torch.sigmoid(duration).sum(axis=-1) / speed
185
  pred_dur = torch.round(duration).clamp(min=1).long()
186
  pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
@@ -189,12 +188,16 @@ def forward(tokens, voice, speed):
189
  pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
190
  c_frame += pred_dur[0,i].item()
191
  en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
192
- F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
193
- t_en = model.text_encoder(tokens, input_lengths, text_mask)
194
  asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
195
- return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
 
 
 
 
196
 
197
- def generate(text, voice, ps=None, speed=1.0, opening_cut=4000, closing_cut=2000, ease_in=3000, ease_out=1000, pad_before=0, pad_after=0):
198
  if voice not in VOICES:
199
  # Ensure stability for https://huggingface.co/spaces/Pendrokar/TTS-Spaces-Arena
200
  voice = 'af'
@@ -206,7 +209,10 @@ def generate(text, voice, ps=None, speed=1.0, opening_cut=4000, closing_cut=2000
206
  tokens = tokens[:510]
207
  ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
208
  try:
209
- out = forward(tokens, voice, speed)
 
 
 
210
  except gr.exceptions.Error as e:
211
  raise gr.Error(e)
212
  return (None, '')
@@ -222,23 +228,15 @@ def generate(text, voice, ps=None, speed=1.0, opening_cut=4000, closing_cut=2000
222
  ease_out = min(int(ease_out / speed), len(out)//2)
223
  for i in range(ease_out):
224
  out[-i-1] *= s_curve(i / ease_out)
225
- pad_before = int(pad_before / speed)
226
- if pad_before > 0:
227
- out = np.concatenate([np.zeros(pad_before), out])
228
- pad_after = int(pad_after / speed)
229
- if pad_after > 0:
230
- out = np.concatenate([out, np.zeros(pad_after)])
231
  return ((SAMPLE_RATE, out), ps)
232
 
233
  def toggle_autoplay(autoplay):
234
  return gr.Audio(interactive=False, label='Output Audio', autoplay=autoplay)
235
 
236
  with gr.Blocks() as basic_tts:
237
- with gr.Row():
238
- gr.Markdown('Generate speech for one segment of text (up to 510 tokens) using Kokoro, a TTS model with 80 million parameters.')
239
  with gr.Row():
240
  with gr.Column():
241
- text = gr.Textbox(label='Input Text')
242
  voice = gr.Dropdown(list(CHOICES.items()), label='Voice', info='⭐ Starred voices are more stable. 🧪 Experimental voices are less stable.')
243
  with gr.Row():
244
  random_btn = gr.Button('Random Text', variant='secondary')
@@ -252,36 +250,36 @@ with gr.Blocks() as basic_tts:
252
  phonemize_btn.click(phonemize, inputs=[text, voice], outputs=[in_ps])
253
  with gr.Column():
254
  audio = gr.Audio(interactive=False, label='Output Audio', autoplay=True)
 
 
255
  with gr.Accordion('Output Tokens', open=True):
256
  out_ps = gr.Textbox(interactive=False, show_label=False, info='Tokens used to generate the audio, up to 510 allowed. Same as input tokens if supplied, excluding unknowns.')
 
 
 
 
 
 
 
257
  with gr.Accordion('Audio Settings', open=False):
258
  with gr.Row():
259
- autoplay = gr.Checkbox(value=True, label='Autoplay')
260
- with gr.Row():
261
- speed = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label='Speed', info='⚡️ Adjust the speed of the audio. The settings below are auto-scaled by speed.')
262
  with gr.Row():
263
  with gr.Column():
264
- opening_cut = gr.Slider(minimum=0, maximum=24000, value=4000, step=1000, label='Opening Cut', info='✂️ Cut this many samples from the start.')
265
  with gr.Column():
266
- closing_cut = gr.Slider(minimum=0, maximum=24000, value=2000, step=1000, label='Closing Cut', info='✂️ Cut this many samples from the end.')
267
  with gr.Row():
268
  with gr.Column():
269
- ease_in = gr.Slider(minimum=0, maximum=24000, value=3000, step=1000, label='Ease In', info='🚀 Ease in for this many samples, after opening cut.')
270
  with gr.Column():
271
- ease_out = gr.Slider(minimum=0, maximum=24000, value=1000, step=1000, label='Ease Out', info='📐 Ease out for this many samples, before closing cut.')
272
- with gr.Row():
273
- with gr.Column():
274
- pad_before = gr.Slider(minimum=0, maximum=24000, value=0, step=1000, label='Pad Before', info='🔇 How many samples of silence to insert before the start.')
275
- with gr.Column():
276
- pad_after = gr.Slider(minimum=0, maximum=24000, value=0, step=1000, label='Pad After', info='🔇 How many samples of silence to append after the end.')
277
- autoplay.change(toggle_autoplay, inputs=[autoplay], outputs=[audio])
278
- text.submit(generate, inputs=[text, voice, in_ps, speed, opening_cut, closing_cut, ease_in, ease_out, pad_before, pad_after], outputs=[audio, out_ps])
279
- generate_btn.click(generate, inputs=[text, voice, in_ps, speed, opening_cut, closing_cut, ease_in, ease_out, pad_before, pad_after], outputs=[audio, out_ps])
280
 
281
- @spaces.GPU
282
  @torch.no_grad()
283
- def lf_forward(token_lists, voice, speed):
284
- voicepack = VOICES[voice]
285
  outs = []
286
  for tokens in token_lists:
287
  ref_s = voicepack[len(tokens)]
@@ -289,11 +287,11 @@ def lf_forward(token_lists, voice, speed):
289
  tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
290
  input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
291
  text_mask = length_to_mask(input_lengths).to(device)
292
- bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
293
- d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
294
- d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
295
- x, _ = model.predictor.lstm(d)
296
- duration = model.predictor.duration_proj(x)
297
  duration = torch.sigmoid(duration).sum(axis=-1) / speed
298
  pred_dur = torch.round(duration).clamp(min=1).long()
299
  pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
@@ -302,12 +300,16 @@ def lf_forward(token_lists, voice, speed):
302
  pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
303
  c_frame += pred_dur[0,i].item()
304
  en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
305
- F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
306
- t_en = model.text_encoder(tokens, input_lengths, text_mask)
307
  asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
308
- outs.append(model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy())
309
  return outs
310
 
 
 
 
 
311
  def resplit_strings(arr):
312
  # Handle edge cases
313
  if not arr:
@@ -360,7 +362,7 @@ def segment_and_tokenize(text, voice, skip_square_brackets=True, newline_split=2
360
  segments = [row for t in texts for row in recursive_split(t, voice)]
361
  return [(i, *row) for i, row in enumerate(segments)]
362
 
363
- def lf_generate(segments, voice, speed=1.0, opening_cut=4000, closing_cut=2000, ease_in=3000, ease_out=1000, pad_before=5000, pad_after=5000, pad_between=10000):
364
  token_lists = list(map(tokenize, segments['Tokens']))
365
  wavs = []
366
  opening_cut = int(opening_cut / speed)
@@ -369,7 +371,10 @@ def lf_generate(segments, voice, speed=1.0, opening_cut=4000, closing_cut=2000,
369
  batch_size = 100
370
  for i in range(0, len(token_lists), batch_size):
371
  try:
372
- outs = lf_forward(token_lists[i:i+batch_size], voice, speed)
 
 
 
373
  except gr.exceptions.Error as e:
374
  if wavs:
375
  gr.Warning(str(e))
@@ -390,12 +395,6 @@ def lf_generate(segments, voice, speed=1.0, opening_cut=4000, closing_cut=2000,
390
  if wavs and pad_between > 0:
391
  wavs.append(np.zeros(pad_between))
392
  wavs.append(out)
393
- pad_before = int(pad_before / speed)
394
- if pad_before > 0:
395
- wavs.insert(0, np.zeros(pad_before))
396
- pad_after = int(pad_after / speed)
397
- if pad_after > 0:
398
- wavs.append(np.zeros(pad_after))
399
  return (SAMPLE_RATE, np.concatenate(wavs)) if wavs else None
400
 
401
  def did_change_segments(segments):
@@ -416,47 +415,45 @@ def extract_text(file):
416
  return None
417
 
418
  with gr.Blocks() as lf_tts:
419
- with gr.Row():
420
- gr.Markdown('Generate speech in batches of 100 text segments and automatically join them together. This may exhaust your ZeroGPU quota.')
421
  with gr.Row():
422
  with gr.Column():
423
  file_input = gr.File(file_types=['.pdf', '.txt'], label='Input File: pdf or txt')
424
- text = gr.Textbox(label='Input Text')
425
  file_input.upload(fn=extract_text, inputs=[file_input], outputs=[text])
426
  voice = gr.Dropdown(list(CHOICES.items()), label='Voice', info='⭐ Starred voices are more stable. 🧪 Experimental voices are less stable.')
427
  with gr.Accordion('Text Settings', open=False):
428
- skip_square_brackets = gr.Checkbox(True, label='Skip [Square Brackets]', info='Recommended for academic papers, Wikipedia articles, or texts with citations.')
429
  newline_split = gr.Number(2, label='Newline Split', info='Split the input text on this many newlines. Affects how the text is segmented.', precision=0, minimum=0)
430
  with gr.Row():
431
  segment_btn = gr.Button('Tokenize', variant='primary')
432
  generate_btn = gr.Button('Generate x0', variant='secondary', interactive=False)
433
  with gr.Column():
434
  audio = gr.Audio(interactive=False, label='Output Audio')
 
 
 
 
 
435
  with gr.Accordion('Audio Settings', open=False):
436
  with gr.Row():
437
- speed = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label='Speed', info='⚡️ Adjust the speed of the audio. The settings below are auto-scaled by speed.')
438
- with gr.Row():
439
- with gr.Column():
440
- opening_cut = gr.Slider(minimum=0, maximum=24000, value=4000, step=1000, label='Opening Cut', info='✂️ Cut this many samples from the start.')
441
- with gr.Column():
442
- closing_cut = gr.Slider(minimum=0, maximum=24000, value=2000, step=1000, label='Closing Cut', info='✂️ Cut this many samples from the end.')
443
  with gr.Row():
444
  with gr.Column():
445
- ease_in = gr.Slider(minimum=0, maximum=24000, value=3000, step=1000, label='Ease In', info='🚀 Ease in for this many samples, after opening cut.')
446
  with gr.Column():
447
- ease_out = gr.Slider(minimum=0, maximum=24000, value=1000, step=1000, label='Ease Out', info='📐 Ease out for this many samples, before closing cut.')
448
  with gr.Row():
449
  with gr.Column():
450
- pad_before = gr.Slider(minimum=0, maximum=24000, value=5000, step=1000, label='Pad Before', info='🔇 How many samples of silence to insert before the start.')
451
  with gr.Column():
452
- pad_after = gr.Slider(minimum=0, maximum=24000, value=5000, step=1000, label='Pad After', info='🔇 How many samples of silence to append after the end.')
453
  with gr.Row():
454
- pad_between = gr.Slider(minimum=0, maximum=24000, value=10000, step=1000, label='Pad Between', info='🔇 How many samples of silence to insert between segments.')
455
  with gr.Row():
456
  segments = gr.Dataframe(headers=['#', 'Text', 'Tokens', 'Length'], row_count=(1, 'dynamic'), col_count=(4, 'fixed'), label='Segments', interactive=False, wrap=True)
457
  segments.change(fn=did_change_segments, inputs=[segments], outputs=[segment_btn, generate_btn])
458
  segment_btn.click(segment_and_tokenize, inputs=[text, voice, skip_square_brackets, newline_split], outputs=[segments])
459
- generate_btn.click(lf_generate, inputs=[segments, voice, speed, opening_cut, closing_cut, ease_in, ease_out, pad_before, pad_after, pad_between], outputs=[audio])
460
 
461
  with gr.Blocks() as about:
462
  gr.Markdown("""
 
12
  import torch
13
  import yaml
14
 
15
+ CUDA_AVAILABLE = torch.cuda.is_available()
16
 
17
  snapshot = snapshot_download(repo_id='hexgrad/kokoro', allow_patterns=['*.pt', '*.pth', '*.yml'], use_auth_token=os.environ['TOKEN'])
18
  config = yaml.safe_load(open(os.path.join(snapshot, 'config.yml')))
19
+
20
+ models = {device: build_model(config['model_params'], device) for device in ['cpu'] + (['cuda'] if CUDA_AVAILABLE else [])}
 
21
  for key, state_dict in torch.load(os.path.join(snapshot, 'net.pth'), map_location='cpu', weights_only=True)['net'].items():
22
+ for device in models:
23
+ assert key in models[device], key
24
+ try:
25
+ models[device][key].load_state_dict(state_dict)
26
+ except:
27
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
28
+ models[device][key].load_state_dict(state_dict, strict=False)
29
 
30
+ PARAM_COUNT = sum(p.numel() for value in models['cpu'].values() for p in value.parameters())
31
  assert PARAM_COUNT < 82_000_000, PARAM_COUNT
32
 
33
  random_texts = {}
 
118
  ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
119
  ps = ''.join(filter(lambda p: p in VOCAB, ps))
120
  if lang == 'j' and any(p in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' for p in ps):
121
+ gr.Warning('Japanese tokenizer does not handle English letters')
122
  return ps.strip()
123
 
124
  def length_to_mask(lengths):
 
154
  '🇬🇧 🚹 Lewis 🧪': 'bm_lewis',
155
  '🇯🇵 🚺 Japanese Female': 'jf_0',
156
  }
157
+ VOICES = {device: {k: torch.load(os.path.join(snapshot, 'voicepacks', f'{k}.pt'), weights_only=True).to(device) for k in CHOICES.values()} for device in models}
158
 
159
  np_log_99 = np.log(99)
160
  def s_curve(p):
 
168
 
169
  SAMPLE_RATE = 24000
170
 
 
171
  @torch.no_grad()
172
+ def forward(tokens, voice, speed, device='cpu'):
173
+ ref_s = VOICES[device][voice][len(tokens)]
174
  tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
175
  input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
176
  text_mask = length_to_mask(input_lengths).to(device)
177
+ bert_dur = models[device].bert(tokens, attention_mask=(~text_mask).int())
178
+ d_en = models[device].bert_encoder(bert_dur).transpose(-1, -2)
179
  s = ref_s[:, 128:]
180
+ d = models[device].predictor.text_encoder(d_en, s, input_lengths, text_mask)
181
+ x, _ = models[device].predictor.lstm(d)
182
+ duration = models[device].predictor.duration_proj(x)
183
  duration = torch.sigmoid(duration).sum(axis=-1) / speed
184
  pred_dur = torch.round(duration).clamp(min=1).long()
185
  pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
 
188
  pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
189
  c_frame += pred_dur[0,i].item()
190
  en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
191
+ F0_pred, N_pred = models[device].predictor.F0Ntrain(en, s)
192
+ t_en = models[device].text_encoder(tokens, input_lengths, text_mask)
193
  asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
194
+ return models[device].decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
195
+
196
+ @spaces.GPU(duration=10)
197
+ def forward_gpu(tokens, voice, speed):
198
+ return forward(tokens, voice, speed, device='cuda')
199
 
200
+ def generate(text, voice, ps=None, speed=1.0, opening_cut=4000, closing_cut=2000, ease_in=3000, ease_out=1000, use_gpu=None):
201
  if voice not in VOICES:
202
  # Ensure stability for https://huggingface.co/spaces/Pendrokar/TTS-Spaces-Arena
203
  voice = 'af'
 
209
  tokens = tokens[:510]
210
  ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
211
  try:
212
+ if use_gpu or (use_gpu is None and len(ps) > 99):
213
+ out = forward_gpu(tokens, voice, speed)
214
+ else:
215
+ out = forward(tokens, voice, speed)
216
  except gr.exceptions.Error as e:
217
  raise gr.Error(e)
218
  return (None, '')
 
228
  ease_out = min(int(ease_out / speed), len(out)//2)
229
  for i in range(ease_out):
230
  out[-i-1] *= s_curve(i / ease_out)
 
 
 
 
 
 
231
  return ((SAMPLE_RATE, out), ps)
232
 
233
  def toggle_autoplay(autoplay):
234
  return gr.Audio(interactive=False, label='Output Audio', autoplay=autoplay)
235
 
236
  with gr.Blocks() as basic_tts:
 
 
237
  with gr.Row():
238
  with gr.Column():
239
+ text = gr.Textbox(label='Input Text', info='Generate speech for one segment of text using Kokoro, a TTS model with 80 million parameters.')
240
  voice = gr.Dropdown(list(CHOICES.items()), label='Voice', info='⭐ Starred voices are more stable. 🧪 Experimental voices are less stable.')
241
  with gr.Row():
242
  random_btn = gr.Button('Random Text', variant='secondary')
 
250
  phonemize_btn.click(phonemize, inputs=[text, voice], outputs=[in_ps])
251
  with gr.Column():
252
  audio = gr.Audio(interactive=False, label='Output Audio', autoplay=True)
253
+ autoplay = gr.Checkbox(value=True, label='Autoplay')
254
+ autoplay.change(toggle_autoplay, inputs=[autoplay], outputs=[audio])
255
  with gr.Accordion('Output Tokens', open=True):
256
  out_ps = gr.Textbox(interactive=False, show_label=False, info='Tokens used to generate the audio, up to 510 allowed. Same as input tokens if supplied, excluding unknowns.')
257
+ with gr.Row():
258
+ use_gpu = gr.Radio(
259
+ [('CPU', False), ('Force GPU', True), ('Dynamic', None)],
260
+ value=None if CUDA_AVAILABLE else False, label='⚙️ Hardware',
261
+ info='CPU: unlimited, ~faster <100 tokens. GPU: limited usage quota, ~faster 100+ tokens. Dynamic: switches based on # of tokens.',
262
+ interactive=CUDA_AVAILABLE
263
+ )
264
  with gr.Accordion('Audio Settings', open=False):
265
  with gr.Row():
266
+ speed = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label='⚡️ Speed', info='Adjust the speed of the audio; the settings below are auto-scaled by speed')
 
 
267
  with gr.Row():
268
  with gr.Column():
269
+ opening_cut = gr.Slider(minimum=0, maximum=24000, value=4000, step=1000, label='✂️ Opening Cut', info='Cut this many samples from the start')
270
  with gr.Column():
271
+ closing_cut = gr.Slider(minimum=0, maximum=24000, value=2000, step=1000, label='🎬 Closing Cut', info='Cut this many samples from the end')
272
  with gr.Row():
273
  with gr.Column():
274
+ ease_in = gr.Slider(minimum=0, maximum=24000, value=3000, step=1000, label='🎢 Ease In', info='Ease in for this many samples, after opening cut')
275
  with gr.Column():
276
+ ease_out = gr.Slider(minimum=0, maximum=24000, value=1000, step=1000, label='🛝 Ease Out', info='Ease out for this many samples, before closing cut')
277
+ text.submit(generate, inputs=[text, voice, in_ps, speed, opening_cut, closing_cut, ease_in, ease_out, use_gpu], outputs=[audio, out_ps])
278
+ generate_btn.click(generate, inputs=[text, voice, in_ps, speed, opening_cut, closing_cut, ease_in, ease_out, use_gpu], outputs=[audio, out_ps])
 
 
 
 
 
 
279
 
 
280
  @torch.no_grad()
281
+ def lf_forward(token_lists, voice, speed, device='cpu'):
282
+ voicepack = VOICES[device][voice]
283
  outs = []
284
  for tokens in token_lists:
285
  ref_s = voicepack[len(tokens)]
 
287
  tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
288
  input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
289
  text_mask = length_to_mask(input_lengths).to(device)
290
+ bert_dur = models[device].bert(tokens, attention_mask=(~text_mask).int())
291
+ d_en = models[device].bert_encoder(bert_dur).transpose(-1, -2)
292
+ d = models[device].predictor.text_encoder(d_en, s, input_lengths, text_mask)
293
+ x, _ = models[device].predictor.lstm(d)
294
+ duration = models[device].predictor.duration_proj(x)
295
  duration = torch.sigmoid(duration).sum(axis=-1) / speed
296
  pred_dur = torch.round(duration).clamp(min=1).long()
297
  pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
 
300
  pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
301
  c_frame += pred_dur[0,i].item()
302
  en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
303
+ F0_pred, N_pred = models[device].predictor.F0Ntrain(en, s)
304
+ t_en = models[device].text_encoder(tokens, input_lengths, text_mask)
305
  asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
306
+ outs.append(models[device].decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy())
307
  return outs
308
 
309
+ @spaces.GPU
310
+ def lf_forward_gpu(token_lists, voice, speed):
311
+ return lf_forward(token_lists, voice, speed, device='cuda')
312
+
313
  def resplit_strings(arr):
314
  # Handle edge cases
315
  if not arr:
 
362
  segments = [row for t in texts for row in recursive_split(t, voice)]
363
  return [(i, *row) for i, row in enumerate(segments)]
364
 
365
+ def lf_generate(segments, voice, speed=1.0, opening_cut=4000, closing_cut=2000, ease_in=3000, ease_out=1000, pad_between=10000, use_gpu=True):
366
  token_lists = list(map(tokenize, segments['Tokens']))
367
  wavs = []
368
  opening_cut = int(opening_cut / speed)
 
371
  batch_size = 100
372
  for i in range(0, len(token_lists), batch_size):
373
  try:
374
+ if use_gpu:
375
+ outs = lf_forward_gpu(token_lists[i:i+batch_size], voice, speed)
376
+ else:
377
+ outs = lf_forward(token_lists[i:i+batch_size], voice, speed)
378
  except gr.exceptions.Error as e:
379
  if wavs:
380
  gr.Warning(str(e))
 
395
  if wavs and pad_between > 0:
396
  wavs.append(np.zeros(pad_between))
397
  wavs.append(out)
 
 
 
 
 
 
398
  return (SAMPLE_RATE, np.concatenate(wavs)) if wavs else None
399
 
400
  def did_change_segments(segments):
 
415
  return None
416
 
417
  with gr.Blocks() as lf_tts:
 
 
418
  with gr.Row():
419
  with gr.Column():
420
  file_input = gr.File(file_types=['.pdf', '.txt'], label='Input File: pdf or txt')
421
+ text = gr.Textbox(label='Input Text', info='Generate speech in batches of 100 text segments and automatically join them together.')
422
  file_input.upload(fn=extract_text, inputs=[file_input], outputs=[text])
423
  voice = gr.Dropdown(list(CHOICES.items()), label='Voice', info='⭐ Starred voices are more stable. 🧪 Experimental voices are less stable.')
424
  with gr.Accordion('Text Settings', open=False):
425
+ skip_square_brackets = gr.Checkbox(True, label='Skip [Square Brackets]', info='Recommended for academic papers, Wikipedia articles, or texts with citations')
426
  newline_split = gr.Number(2, label='Newline Split', info='Split the input text on this many newlines. Affects how the text is segmented.', precision=0, minimum=0)
427
  with gr.Row():
428
  segment_btn = gr.Button('Tokenize', variant='primary')
429
  generate_btn = gr.Button('Generate x0', variant='secondary', interactive=False)
430
  with gr.Column():
431
  audio = gr.Audio(interactive=False, label='Output Audio')
432
+ use_gpu = gr.Checkbox(value=CUDA_AVAILABLE, label='Use ZeroGPU', info='🚀 ZeroGPU is fast but has a limited usage quota', interactive=CUDA_AVAILABLE)
433
+ use_gpu.change(
434
+ fn=lambda v: gr.Checkbox(value=v, label='Use ZeroGPU', info='🚀 ZeroGPU is fast but has a limited usage quota' if v else '🐌 CPU is slow but unlimited'),
435
+ inputs=[use_gpu], outputs=[use_gpu]
436
+ )
437
  with gr.Accordion('Audio Settings', open=False):
438
  with gr.Row():
439
+ speed = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label='⚡️ Speed', info='Adjust the speed of the audio; the settings below are auto-scaled by speed')
 
 
 
 
 
440
  with gr.Row():
441
  with gr.Column():
442
+ opening_cut = gr.Slider(minimum=0, maximum=24000, value=4000, step=1000, label='✂️ Opening Cut', info='Cut this many samples from the start')
443
  with gr.Column():
444
+ closing_cut = gr.Slider(minimum=0, maximum=24000, value=2000, step=1000, label='🎬 Closing Cut', info='Cut this many samples from the end')
445
  with gr.Row():
446
  with gr.Column():
447
+ ease_in = gr.Slider(minimum=0, maximum=24000, value=3000, step=1000, label='🎢 Ease In', info='Ease in for this many samples, after opening cut')
448
  with gr.Column():
449
+ ease_out = gr.Slider(minimum=0, maximum=24000, value=1000, step=1000, label='🛝 Ease Out', info='Ease out for this many samples, before closing cut')
450
  with gr.Row():
451
+ pad_between = gr.Slider(minimum=0, maximum=24000, value=10000, step=1000, label='🔇 Pad Between', info='How many samples of silence to insert between segments')
452
  with gr.Row():
453
  segments = gr.Dataframe(headers=['#', 'Text', 'Tokens', 'Length'], row_count=(1, 'dynamic'), col_count=(4, 'fixed'), label='Segments', interactive=False, wrap=True)
454
  segments.change(fn=did_change_segments, inputs=[segments], outputs=[segment_btn, generate_btn])
455
  segment_btn.click(segment_and_tokenize, inputs=[text, voice, skip_square_brackets, newline_split], outputs=[segments])
456
+ generate_btn.click(lf_generate, inputs=[segments, voice, speed, opening_cut, closing_cut, ease_in, ease_out, pad_between, use_gpu], outputs=[audio])
457
 
458
  with gr.Blocks() as about:
459
  gr.Markdown("""
models.py CHANGED
@@ -549,7 +549,7 @@ def recursive_munch(d):
549
  else:
550
  return d
551
 
552
- def build_model(args):
553
  args = recursive_munch(args)
554
  assert args.decoder.type == 'istftnet', 'Decoder type unknown'
555
  decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
@@ -562,10 +562,16 @@ def build_model(args):
562
  text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
563
  predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
564
  bert = load_plbert()
565
- return Munch(
566
- bert=bert,
567
- bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
568
- predictor=predictor,
569
- decoder=decoder,
570
- text_encoder=text_encoder,
 
 
 
 
 
571
  )
 
 
549
  else:
550
  return d
551
 
552
+ def build_model(args, device):
553
  args = recursive_munch(args)
554
  assert args.decoder.type == 'istftnet', 'Decoder type unknown'
555
  decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
 
562
  text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
563
  predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
564
  bert = load_plbert()
565
+ bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
566
+ for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
567
+ for child in parent.children():
568
+ if isinstance(child, nn.RNNBase):
569
+ child.flatten_parameters()
570
+ model = Munch(
571
+ bert=bert.to(device).eval(),
572
+ bert_encoder=bert_encoder.to(device).eval(),
573
+ predictor=predictor.to(device).eval(),
574
+ decoder=decoder.to(device).eval(),
575
+ text_encoder=text_encoder.to(device).eval(),
576
  )
577
+ return model