OpenSound committed · verified
Commit ce5a339 · Parent(s): 82d6b31

Update app.py

Files changed (1)
  1. app.py +360 -360
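
For reference, the Mandarin checkpoint download that this diff enables follows the same streaming pattern already used for the English model. A minimal standalone sketch, assuming MODELS_PATH points at the checkpoint directory (its actual value is set earlier in app.py, outside this diff):

    import os
    import requests

    MODELS_PATH = "./pretrained_models"  # assumption: the real value comes from app.py
    url = "https://huggingface.co/westbrook/SSR-Speech-Mandarin/resolve/main/Mandarin.pth"
    filename = os.path.join(MODELS_PATH, "Mandarin.pth")

    if not os.path.exists(filename):
        # stream the checkpoint to disk in 8 KiB chunks rather than buffering the whole response
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(filename, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
        print(f"File downloaded to: {filename}")
    else:
        print("mandarin model found")
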
app.py CHANGED
@@ -12,8 +12,8 @@ from data.tokenizer import (
12
  )
13
  from edit_utils_en import parse_edit_en
14
  from edit_utils_en import parse_tts_en
15
- # from edit_utils_zh import parse_edit_zh
16
- # from edit_utils_zh import parse_tts_zh
17
  from inference_scale import inference_one_sample
18
  import librosa
19
  import soundfile as sf
@@ -59,18 +59,18 @@ if not os.path.exists(os.path.join(MODELS_PATH, "English.pth")):
59
  else:
60
  print("english model found")
61
 
62
- # if not os.path.exists(os.path.join(MODELS_PATH, "Mandarin.pth")):
63
- # # download mandarin model
64
- # url = "https://huggingface.co/westbrook/SSR-Speech-Mandarin/resolve/main/Mandarin.pth"
65
- # filename = os.path.join(MODELS_PATH, "Mandarin.pth")
66
- # response = requests.get(url, stream=True)
67
- # response.raise_for_status()
68
- # with open(filename, "wb") as file:
69
- # for chunk in response.iter_content(chunk_size=8192):
70
- # file.write(chunk)
71
- # print(f"File downloaded to: {filename}")
72
- # else:
73
- # print("mandarin model found")
74
 
75
  def get_random_string():
76
  return "".join(str(uuid.uuid4()).split("-"))
@@ -130,7 +130,7 @@ from whisperx import align as align_func
130
 
131
  # Load models
132
  text_tokenizer_en = TextTokenizer(backend="espeak")
133
- # text_tokenizer_zh = TextTokenizer(backend="espeak", language='cmn')
134
 
135
  ssrspeech_fn_en = f"{MODELS_PATH}/English.pth"
136
  ckpt_en = torch.load(ssrspeech_fn_en)
@@ -140,13 +140,13 @@ config_en = model_en.args
140
  phn2num_en = ckpt_en["phn2num"]
141
  model_en.to(device)
142
 
143
- # ssrspeech_fn_zh = f"{MODELS_PATH}/Mandarin.pth"
144
- # ckpt_zh = torch.load(ssrspeech_fn_zh)
145
- # model_zh = ssr.SSR_Speech(ckpt_zh["config"])
146
- # model_zh.load_state_dict(ckpt_zh["model"])
147
- # config_zh = model_zh.args
148
- # phn2num_zh = ckpt_zh["phn2num"]
149
- # model_zh.to(device)
150
 
151
  encodec_fn = f"{MODELS_PATH}/wmencodec.th"
152
 
@@ -158,13 +158,13 @@ ssrspeech_model_en = {
158
  "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
159
  }
160
 
161
- # ssrspeech_model_zh = {
162
- # "config": config_zh,
163
- # "phn2num": phn2num_zh,
164
- # "model": model_zh,
165
- # "text_tokenizer": text_tokenizer_zh,
166
- # "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
167
- # }
168
 
169
 
170
  def get_transcribe_state(segments):
@@ -192,21 +192,21 @@ def transcribe_en(audio_path):
192
  state, success_message
193
  ]
194
 
195
- # @spaces.GPU
196
- # def transcribe_zh(audio_path):
197
- # language = "zh"
198
- # transcribe_model_name = "medium"
199
- # transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language)
200
- # segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
201
- # _, segments = align_zh(segments, audio_path)
202
- # state = get_transcribe_state(segments)
203
- # success_message = "<span style='color:green;'>Success: Transcribe completed successfully!</span>"
204
- # converter = opencc.OpenCC('t2s')
205
- # state["transcript"] = converter.convert(state["transcript"])
206
- # return [
207
- # state["transcript"], state['segments'],
208
- # state, success_message
209
- # ]
210
 
211
  @spaces.GPU
212
  def align_en(segments, audio_path):
@@ -219,15 +219,15 @@ def align_en(segments, audio_path):
219
  return state, segments
220
 
221
 
222
- # @spaces.GPU
223
- # def align_zh(segments, audio_path):
224
- # language = "zh"
225
- # align_model, metadata = load_align_model(language_code=language, device=device)
226
- # audio = load_audio(audio_path)
227
- # segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"]
228
- # state = get_transcribe_state(segments)
229
 
230
- # return state, segments
231
 
232
 
233
  def get_output_audio(audio_tensors, codec_audio_sr):
@@ -442,210 +442,210 @@ def run_tts_en(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
442
  return output_audio, success_message
443
 
444
 
445
- # @spaces.GPU
446
- # def run_edit_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
447
- # audio_path, original_transcript, transcript):
448
 
449
- # codec_audio_sr = 16000
450
- # codec_sr = 50
451
- # top_k = 0
452
- # top_p = 0.8
453
- # temperature = 1
454
- # kvcache = 1
455
- # stop_repetition = 2
456
 
457
- # aug_text = True if aug_text == 1 else False
458
 
459
- # seed_everything(seed)
460
 
461
- # # resample audio
462
- # audio, _ = librosa.load(audio_path, sr=16000)
463
- # sf.write(audio_path, audio, 16000)
464
 
465
- # # text normalization
466
- # target_transcript = transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
467
- # orig_transcript = original_transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
468
 
469
- # [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
470
 
471
- # converter = opencc.OpenCC('t2s')
472
- # orig_transcript = converter.convert(orig_transcript)
473
- # transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
474
- # transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
475
 
476
- # print(orig_transcript)
477
- # print(target_transcript)
478
 
479
- # operations, orig_spans = parse_edit_zh(orig_transcript, target_transcript)
480
- # print(operations)
481
- # print("orig_spans: ", orig_spans)
482
 
483
- # if len(orig_spans) > 3:
484
- # raise gr.Error("Current model only supports maximum 3 editings")
485
 
486
- # starting_intervals = []
487
- # ending_intervals = []
488
- # for orig_span in orig_spans:
489
- # start, end = get_mask_interval(transcribe_state, orig_span)
490
- # starting_intervals.append(start)
491
- # ending_intervals.append(end)
492
-
493
- # print("intervals: ", starting_intervals, ending_intervals)
494
-
495
- # info = torchaudio.info(audio_path)
496
- # audio_dur = info.num_frames / info.sample_rate
497
-
498
- # def combine_spans(spans, threshold=0.2):
499
- # spans.sort(key=lambda x: x[0])
500
- # combined_spans = []
501
- # current_span = spans[0]
502
-
503
- # for i in range(1, len(spans)):
504
- # next_span = spans[i]
505
- # if current_span[1] >= next_span[0] - threshold:
506
- # current_span[1] = max(current_span[1], next_span[1])
507
- # else:
508
- # combined_spans.append(current_span)
509
- # current_span = next_span
510
- # combined_spans.append(current_span)
511
- # return combined_spans
512
-
513
- # morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
514
- # for start, end in zip(starting_intervals, ending_intervals)] # in seconds
515
- # morphed_span = combine_spans(morphed_span, threshold=0.2)
516
- # print("morphed_spans: ", morphed_span)
517
- # mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
518
- # mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
519
-
520
- # decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
521
-
522
- # new_audio = inference_one_sample(
523
- # ssrspeech_model_zh["model"],
524
- # ssrspeech_model_zh["config"],
525
- # ssrspeech_model_zh["phn2num"],
526
- # ssrspeech_model_zh["text_tokenizer"],
527
- # ssrspeech_model_zh["audio_tokenizer"],
528
- # audio_path, orig_transcript, target_transcript, mask_interval,
529
- # cfg_coef, cfg_stride, aug_text, False, True, False,
530
- # device, decode_config
531
- # )
532
- # audio_tensors = []
533
- # # save segments for comparison
534
- # new_audio = new_audio[0].cpu()
535
- # torchaudio.save(audio_path, new_audio, codec_audio_sr)
536
- # audio_tensors.append(new_audio)
537
- # output_audio = get_output_audio(audio_tensors, codec_audio_sr)
538
-
539
- # success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
540
- # return output_audio, success_message
541
-
542
-
543
- # @spaces.GPU
544
- # def run_tts_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
545
- # audio_path, original_transcript, transcript):
546
-
547
- # codec_audio_sr = 16000
548
- # codec_sr = 50
549
- # top_k = 0
550
- # top_p = 0.8
551
- # temperature = 1
552
- # kvcache = 1
553
- # stop_repetition = 2
554
-
555
- # aug_text = True if aug_text == 1 else False
556
-
557
- # seed_everything(seed)
558
-
559
- # # resample audio
560
- # audio, _ = librosa.load(audio_path, sr=16000)
561
- # sf.write(audio_path, audio, 16000)
562
-
563
- # # text normalization
564
- # target_transcript = transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
565
- # orig_transcript = original_transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
566
-
567
- # [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
568
-
569
- # converter = opencc.OpenCC('t2s')
570
- # orig_transcript = converter.convert(orig_transcript)
571
- # transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
572
- # transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
573
-
574
- # print(orig_transcript)
575
- # print(target_transcript)
576
-
577
- # info = torchaudio.info(audio_path)
578
- # duration = info.num_frames / info.sample_rate
579
- # cut_length = duration
580
- # # Cut long audio for tts
581
- # if duration > prompt_length:
582
- # seg_num = len(transcribe_state['segments'])
583
- # for i in range(seg_num):
584
- # words = transcribe_state['segments'][i]['words']
585
- # for item in words:
586
- # if item['end'] >= prompt_length:
587
- # cut_length = min(item['end'], cut_length)
588
-
589
- # audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
590
- # sf.write(audio_path, audio, 16000)
591
- # [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
592
-
593
-
594
- # converter = opencc.OpenCC('t2s')
595
- # orig_transcript = converter.convert(orig_transcript)
596
- # transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
597
- # transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
598
-
599
- # print(orig_transcript)
600
- # target_transcript_copy = target_transcript # for tts cut out
601
- # target_transcript_copy = target_transcript_copy[0]
602
- # target_transcript = orig_transcript + target_transcript
603
- # print(target_transcript)
604
-
605
-
606
- # info = torchaudio.info(audio_path)
607
- # audio_dur = info.num_frames / info.sample_rate
608
-
609
- # morphed_span = [(audio_dur, audio_dur)] # in seconds
610
- # mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
611
- # mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
612
- # print("mask_interval: ", mask_interval)
613
-
614
- # decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
615
-
616
- # new_audio = inference_one_sample(
617
- # ssrspeech_model_zh["model"],
618
- # ssrspeech_model_zh["config"],
619
- # ssrspeech_model_zh["phn2num"],
620
- # ssrspeech_model_zh["text_tokenizer"],
621
- # ssrspeech_model_zh["audio_tokenizer"],
622
- # audio_path, orig_transcript, target_transcript, mask_interval,
623
- # cfg_coef, cfg_stride, aug_text, False, True, True,
624
- # device, decode_config
625
- # )
626
- # audio_tensors = []
627
- # # save segments for comparison
628
- # new_audio = new_audio[0].cpu()
629
- # torchaudio.save(audio_path, new_audio, codec_audio_sr)
630
-
631
- # [new_transcript, new_segments, _,_] = transcribe_zh(audio_path)
632
-
633
- # transcribe_state,_ = align_zh(traditional_to_simplified(new_segments), audio_path)
634
- # transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
635
- # tmp1 = transcribe_state['segments'][0]['words'][0]['word']
636
- # tmp2 = target_transcript_copy
637
 
638
- # if tmp1 == tmp2:
639
- # offset = transcribe_state['segments'][0]['words'][0]['start']
640
- # else:
641
- # offset = transcribe_state['segments'][0]['words'][1]['start']
642
-
643
- # new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset*codec_audio_sr))
644
- # audio_tensors.append(new_audio)
645
- # output_audio = get_output_audio(audio_tensors, codec_audio_sr)
646
-
647
- # success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
648
- # return output_audio, success_message
649
 
650
 
651
  if __name__ == "__main__":
@@ -815,131 +815,131 @@ if __name__ == "__main__":
815
  outputs=[output_audio, success_output]
816
  )
817
 
818
- # with gr.Tab("Mandarin Speech Editing"):
819
 
820
- # with gr.Row():
821
- # with gr.Column(scale=2):
822
- # input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
823
- # with gr.Group():
824
- # original_transcript = gr.Textbox(label="Original transcript", lines=5, value="价格已基本都在三万到六万之间",
825
- # info="Use whisperx model to get the transcript.")
826
- # transcribe_btn = gr.Button(value="Transcribe")
827
-
828
- # with gr.Column(scale=3):
829
- # with gr.Group():
830
- # transcript = gr.Textbox(label="Text", lines=7, value="价格已基本都在一万到两万之间", interactive=True)
831
- # run_btn = gr.Button(value="Run")
832
-
833
- # with gr.Column(scale=2):
834
- # output_audio = gr.Audio(label="Output Audio")
835
 
836
- # with gr.Row():
837
- # with gr.Accordion("Advanced Settings", open=False):
838
- # seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
839
- # aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
840
- # info="set to 1 to use classifer-free guidance, change if you don't like the results")
841
- # cfg_coef = gr.Number(label="cfg_coef", value=1.5,
842
- # info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
843
- # cfg_stride = gr.Number(label="cfg_stride", value=1,
844
- # info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
845
- # prompt_length = gr.Number(label="prompt_length", value=3,
846
- # info="used for tts prompt, will automatically cut the prompt audio to this length")
847
- # sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
848
-
849
- # success_output = gr.HTML()
850
-
851
- # semgents = gr.State() # not used
852
- # state = gr.State() # not used
853
- # audio_state = gr.State(value=f"{DEMO_PATH}/aishell3_test.wav")
854
- # input_audio.change(
855
- # lambda audio: audio,
856
- # inputs=[input_audio],
857
- # outputs=[audio_state]
858
- # )
859
 
860
- # transcribe_btn.click(fn=transcribe_zh,
861
- # inputs=[audio_state],
862
- # outputs=[original_transcript, semgents, state, success_output])
863
 
864
- # run_btn.click(fn=run_edit_zh,
865
- # inputs=[
866
- # seed, sub_amount,
867
- # aug_text, cfg_coef, cfg_stride, prompt_length,
868
- # audio_state, original_transcript, transcript,
869
- # ],
870
- # outputs=[output_audio, success_output])
871
-
872
- # transcript.submit(fn=run_edit_zh,
873
- # inputs=[
874
- # seed, sub_amount,
875
- # aug_text, cfg_coef, cfg_stride, prompt_length,
876
- # audio_state, original_transcript, transcript,
877
- # ],
878
- # outputs=[output_audio, success_output]
879
- # )
880
 
881
- # with gr.Tab("Mandarin TTS"):
882
 
883
- # with gr.Row():
884
- # with gr.Column(scale=2):
885
- # input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
886
- # with gr.Group():
887
- # original_transcript = gr.Textbox(label="Original transcript", lines=5, value="价格已基本都在三万到六万之间",
888
- # info="Use whisperx model to get the transcript.")
889
- # transcribe_btn = gr.Button(value="Transcribe")
890
-
891
- # with gr.Column(scale=3):
892
- # with gr.Group():
893
- # transcript = gr.Textbox(label="Text", lines=7, value="我简直不敢相信同一个模型也可以进行文本到语音的生成", interactive=True)
894
- # run_btn = gr.Button(value="Run")
895
-
896
- # with gr.Column(scale=2):
897
- # output_audio = gr.Audio(label="Output Audio")
898
 
899
- # with gr.Row():
900
- # with gr.Accordion("Advanced Settings", open=False):
901
- # seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
902
- # aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
903
- # info="set to 1 to use classifer-free guidance, change if you don't like the results")
904
- # cfg_coef = gr.Number(label="cfg_coef", value=1.5,
905
- # info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
906
- # cfg_stride = gr.Number(label="cfg_stride", value=1,
907
- # info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
908
- # prompt_length = gr.Number(label="prompt_length", value=3,
909
- # info="used for tts prompt, will automatically cut the prompt audio to this length")
910
- # sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
911
-
912
- # success_output = gr.HTML()
913
-
914
- # semgents = gr.State() # not used
915
- # state = gr.State() # not used
916
- # audio_state = gr.State(value=f"{DEMO_PATH}/aishell3_test.wav")
917
- # input_audio.change(
918
- # lambda audio: audio,
919
- # inputs=[input_audio],
920
- # outputs=[audio_state]
921
- # )
922
 
923
- # transcribe_btn.click(fn=transcribe_zh,
924
- # inputs=[audio_state],
925
- # outputs=[original_transcript, semgents, state, success_output])
926
 
927
- # run_btn.click(fn=run_tts_zh,
928
- # inputs=[
929
- # seed, sub_amount,
930
- # aug_text, cfg_coef, cfg_stride, prompt_length,
931
- # audio_state, original_transcript, transcript,
932
- # ],
933
- # outputs=[output_audio, success_output])
934
-
935
- # transcript.submit(fn=run_tts_zh,
936
- # inputs=[
937
- # seed, sub_amount,
938
- # aug_text, cfg_coef, cfg_stride, prompt_length,
939
- # audio_state, original_transcript, transcript,
940
- # ],
941
- # outputs=[output_audio, success_output]
942
- # )
943
 
944
  # Launch the Gradio demo
945
  demo.launch()
 
12
  )
13
  from edit_utils_en import parse_edit_en
14
  from edit_utils_en import parse_tts_en
15
+ from edit_utils_zh import parse_edit_zh
16
+ from edit_utils_zh import parse_tts_zh
17
  from inference_scale import inference_one_sample
18
  import librosa
19
  import soundfile as sf
 
59
  else:
60
  print("english model found")
61
 
62
+ if not os.path.exists(os.path.join(MODELS_PATH, "Mandarin.pth")):
63
+ # download mandarin model
64
+ url = "https://huggingface.co/westbrook/SSR-Speech-Mandarin/resolve/main/Mandarin.pth"
65
+ filename = os.path.join(MODELS_PATH, "Mandarin.pth")
66
+ response = requests.get(url, stream=True)
67
+ response.raise_for_status()
68
+ with open(filename, "wb") as file:
69
+ for chunk in response.iter_content(chunk_size=8192):
70
+ file.write(chunk)
71
+ print(f"File downloaded to: {filename}")
72
+ else:
73
+ print("mandarin model found")
74
 
75
  def get_random_string():
76
  return "".join(str(uuid.uuid4()).split("-"))
 
130
 
131
  # Load models
132
  text_tokenizer_en = TextTokenizer(backend="espeak")
133
+ text_tokenizer_zh = TextTokenizer(backend="espeak", language='cmn-latn-pinyin')
134
 
135
  ssrspeech_fn_en = f"{MODELS_PATH}/English.pth"
136
  ckpt_en = torch.load(ssrspeech_fn_en)
 
140
  phn2num_en = ckpt_en["phn2num"]
141
  model_en.to(device)
142
 
143
+ ssrspeech_fn_zh = f"{MODELS_PATH}/Mandarin.pth"
144
+ ckpt_zh = torch.load(ssrspeech_fn_zh)
145
+ model_zh = ssr.SSR_Speech(ckpt_zh["config"])
146
+ model_zh.load_state_dict(ckpt_zh["model"])
147
+ config_zh = model_zh.args
148
+ phn2num_zh = ckpt_zh["phn2num"]
149
+ model_zh.to(device)
150
 
151
  encodec_fn = f"{MODELS_PATH}/wmencodec.th"
152
 
 
158
  "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
159
  }
160
 
161
+ ssrspeech_model_zh = {
162
+ "config": config_zh,
163
+ "phn2num": phn2num_zh,
164
+ "model": model_zh,
165
+ "text_tokenizer": text_tokenizer_zh,
166
+ "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
167
+ }
168
 
169
 
170
  def get_transcribe_state(segments):
 
192
  state, success_message
193
  ]
194
 
195
+ @spaces.GPU
196
+ def transcribe_zh(audio_path):
197
+ language = "zh"
198
+ transcribe_model_name = "medium"
199
+ transcribe_model = load_model(transcribe_model_name, device, asr_options={"suppress_numerals": True, "max_new_tokens": None, "clip_timestamps": None, "hallucination_silence_threshold": None}, language=language)
200
+ segments = transcribe_model.transcribe(audio_path, batch_size=8)["segments"]
201
+ _, segments = align_zh(segments, audio_path)
202
+ state = get_transcribe_state(segments)
203
+ success_message = "<span style='color:green;'>Success: Transcribe completed successfully!</span>"
204
+ converter = opencc.OpenCC('t2s')
205
+ state["transcript"] = converter.convert(state["transcript"])
206
+ return [
207
+ state["transcript"], state['segments'],
208
+ state, success_message
209
+ ]
210
 
211
  @spaces.GPU
212
  def align_en(segments, audio_path):
 
219
  return state, segments
220
 
221
 
222
+ @spaces.GPU
223
+ def align_zh(segments, audio_path):
224
+ language = "zh"
225
+ align_model, metadata = load_align_model(language_code=language, device=device)
226
+ audio = load_audio(audio_path)
227
+ segments = align_func(segments, align_model, metadata, audio, device, return_char_alignments=False)["segments"]
228
+ state = get_transcribe_state(segments)
229
 
230
+ return state, segments
231
 
232
 
233
  def get_output_audio(audio_tensors, codec_audio_sr):
 
442
  return output_audio, success_message
443
 
444
 
445
+ @spaces.GPU
446
+ def run_edit_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
447
+ audio_path, original_transcript, transcript):
448
 
449
+ codec_audio_sr = 16000
450
+ codec_sr = 50
451
+ top_k = 0
452
+ top_p = 0.8
453
+ temperature = 1
454
+ kvcache = 1
455
+ stop_repetition = 2
456
 
457
+ aug_text = True if aug_text == 1 else False
458
 
459
+ seed_everything(seed)
460
 
461
+ # resample audio
462
+ audio, _ = librosa.load(audio_path, sr=16000)
463
+ sf.write(audio_path, audio, 16000)
464
 
465
+ # text normalization
466
+ target_transcript = transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
467
+ orig_transcript = original_transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
468
 
469
+ [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
470
 
471
+ converter = opencc.OpenCC('t2s')
472
+ orig_transcript = converter.convert(orig_transcript)
473
+ transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
474
+ transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
475
 
476
+ print(orig_transcript)
477
+ print(target_transcript)
478
 
479
+ operations, orig_spans = parse_edit_zh(orig_transcript, target_transcript)
480
+ print(operations)
481
+ print("orig_spans: ", orig_spans)
482
 
483
+ if len(orig_spans) > 3:
484
+ raise gr.Error("Current model only supports maximum 3 editings")
485
 
486
+ starting_intervals = []
487
+ ending_intervals = []
488
+ for orig_span in orig_spans:
489
+ start, end = get_mask_interval(transcribe_state, orig_span)
490
+ starting_intervals.append(start)
491
+ ending_intervals.append(end)
492
+
493
+ print("intervals: ", starting_intervals, ending_intervals)
494
+
495
+ info = torchaudio.info(audio_path)
496
+ audio_dur = info.num_frames / info.sample_rate
497
 
498
+ def combine_spans(spans, threshold=0.2):
499
+ spans.sort(key=lambda x: x[0])
500
+ combined_spans = []
501
+ current_span = spans[0]
502
+
503
+ for i in range(1, len(spans)):
504
+ next_span = spans[i]
505
+ if current_span[1] >= next_span[0] - threshold:
506
+ current_span[1] = max(current_span[1], next_span[1])
507
+ else:
508
+ combined_spans.append(current_span)
509
+ current_span = next_span
510
+ combined_spans.append(current_span)
511
+ return combined_spans
512
+
513
+ morphed_span = [[max(start - sub_amount, 0), min(end + sub_amount, audio_dur)]
514
+ for start, end in zip(starting_intervals, ending_intervals)] # in seconds
515
+ morphed_span = combine_spans(morphed_span, threshold=0.2)
516
+ print("morphed_spans: ", morphed_span)
517
+ mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
518
+ mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
519
+
520
+ decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
521
+
522
+ new_audio = inference_one_sample(
523
+ ssrspeech_model_zh["model"],
524
+ ssrspeech_model_zh["config"],
525
+ ssrspeech_model_zh["phn2num"],
526
+ ssrspeech_model_zh["text_tokenizer"],
527
+ ssrspeech_model_zh["audio_tokenizer"],
528
+ audio_path, orig_transcript, target_transcript, mask_interval,
529
+ cfg_coef, cfg_stride, aug_text, False, True, False,
530
+ device, decode_config
531
+ )
532
+ audio_tensors = []
533
+ # save segments for comparison
534
+ new_audio = new_audio[0].cpu()
535
+ torchaudio.save(audio_path, new_audio, codec_audio_sr)
536
+ audio_tensors.append(new_audio)
537
+ output_audio = get_output_audio(audio_tensors, codec_audio_sr)
538
+
539
+ success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
540
+ return output_audio, success_message
541
+
542
+
543
+ @spaces.GPU
544
+ def run_tts_zh(seed, sub_amount, aug_text, cfg_coef, cfg_stride, prompt_length,
545
+ audio_path, original_transcript, transcript):
546
+
547
+ codec_audio_sr = 16000
548
+ codec_sr = 50
549
+ top_k = 0
550
+ top_p = 0.8
551
+ temperature = 1
552
+ kvcache = 1
553
+ stop_repetition = 2
554
+
555
+ aug_text = True if aug_text == 1 else False
556
+
557
+ seed_everything(seed)
558
+
559
+ # resample audio
560
+ audio, _ = librosa.load(audio_path, sr=16000)
561
+ sf.write(audio_path, audio, 16000)
562
+
563
+ # text normalization
564
+ target_transcript = transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
565
+ orig_transcript = original_transcript.replace(" ", " ").replace(" ", " ").replace("\n", " ")
566
+
567
+ [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
568
+
569
+ converter = opencc.OpenCC('t2s')
570
+ orig_transcript = converter.convert(orig_transcript)
571
+ transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
572
+ transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
573
+
574
+ print(orig_transcript)
575
+ print(target_transcript)
576
+
577
+ info = torchaudio.info(audio_path)
578
+ duration = info.num_frames / info.sample_rate
579
+ cut_length = duration
580
+ # Cut long audio for tts
581
+ if duration > prompt_length:
582
+ seg_num = len(transcribe_state['segments'])
583
+ for i in range(seg_num):
584
+ words = transcribe_state['segments'][i]['words']
585
+ for item in words:
586
+ if item['end'] >= prompt_length:
587
+ cut_length = min(item['end'], cut_length)
588
+
589
+ audio, _ = librosa.load(audio_path, sr=16000, duration=cut_length)
590
+ sf.write(audio_path, audio, 16000)
591
+ [orig_transcript, segments, _, _] = transcribe_zh(audio_path)
592
+
593
+
594
+ converter = opencc.OpenCC('t2s')
595
+ orig_transcript = converter.convert(orig_transcript)
596
+ transcribe_state,_ = align_zh(traditional_to_simplified(segments), audio_path)
597
+ transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
598
+
599
+ print(orig_transcript)
600
+ target_transcript_copy = target_transcript # for tts cut out
601
+ target_transcript_copy = target_transcript_copy[0]
602
+ target_transcript = orig_transcript + target_transcript
603
+ print(target_transcript)
604
+
605
+
606
+ info = torchaudio.info(audio_path)
607
+ audio_dur = info.num_frames / info.sample_rate
608
+
609
+ morphed_span = [(audio_dur, audio_dur)] # in seconds
610
+ mask_interval = [[round(span[0]*codec_sr), round(span[1]*codec_sr)] for span in morphed_span]
611
+ mask_interval = torch.LongTensor(mask_interval) # [M,2], M==1 for now
612
+ print("mask_interval: ", mask_interval)
613
+
614
+ decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr}
615
+
616
+ new_audio = inference_one_sample(
617
+ ssrspeech_model_zh["model"],
618
+ ssrspeech_model_zh["config"],
619
+ ssrspeech_model_zh["phn2num"],
620
+ ssrspeech_model_zh["text_tokenizer"],
621
+ ssrspeech_model_zh["audio_tokenizer"],
622
+ audio_path, orig_transcript, target_transcript, mask_interval,
623
+ cfg_coef, cfg_stride, aug_text, False, True, True,
624
+ device, decode_config
625
+ )
626
+ audio_tensors = []
627
+ # save segments for comparison
628
+ new_audio = new_audio[0].cpu()
629
+ torchaudio.save(audio_path, new_audio, codec_audio_sr)
630
+
631
+ [new_transcript, new_segments, _,_] = transcribe_zh(audio_path)
632
+
633
+ transcribe_state,_ = align_zh(traditional_to_simplified(new_segments), audio_path)
634
+ transcribe_state['segments'] = traditional_to_simplified(transcribe_state['segments'])
635
+ tmp1 = transcribe_state['segments'][0]['words'][0]['word']
636
+ tmp2 = target_transcript_copy
637
+
638
+ if tmp1 == tmp2:
639
+ offset = transcribe_state['segments'][0]['words'][0]['start']
640
+ else:
641
+ offset = transcribe_state['segments'][0]['words'][1]['start']
642
+
643
+ new_audio, _ = torchaudio.load(audio_path, frame_offset=int(offset*codec_audio_sr))
644
+ audio_tensors.append(new_audio)
645
+ output_audio = get_output_audio(audio_tensors, codec_audio_sr)
646
+
647
+ success_message = "<span style='color:green;'>Success: Inference successfully!</span>"
648
+ return output_audio, success_message
649
 
650
 
651
  if __name__ == "__main__":
 
815
  outputs=[output_audio, success_output]
816
  )
817
 
818
+ with gr.Tab("Mandarin Speech Editing"):
819
 
820
+ with gr.Row():
821
+ with gr.Column(scale=2):
822
+ input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
823
+ with gr.Group():
824
+ original_transcript = gr.Textbox(label="Original transcript", lines=5, value="价格已基本都在三万到六万之间",
825
+ info="Use whisperx model to get the transcript.")
826
+ transcribe_btn = gr.Button(value="Transcribe")
827
+
828
+ with gr.Column(scale=3):
829
+ with gr.Group():
830
+ transcript = gr.Textbox(label="Text", lines=7, value="价格已基本都在一万到两万之间", interactive=True)
831
+ run_btn = gr.Button(value="Run")
832
+
833
+ with gr.Column(scale=2):
834
+ output_audio = gr.Audio(label="Output Audio")
835
 
836
+ with gr.Row():
837
+ with gr.Accordion("Advanced Settings", open=False):
838
+ seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
839
+ aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
840
+ info="set to 1 to use classifer-free guidance, change if you don't like the results")
841
+ cfg_coef = gr.Number(label="cfg_coef", value=1.5,
842
+ info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
843
+ cfg_stride = gr.Number(label="cfg_stride", value=1,
844
+ info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
845
+ prompt_length = gr.Number(label="prompt_length", value=3,
846
+ info="used for tts prompt, will automatically cut the prompt audio to this length")
847
+ sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
848
+
849
+ success_output = gr.HTML()
850
+
851
+ semgents = gr.State() # not used
852
+ state = gr.State() # not used
853
+ audio_state = gr.State(value=f"{DEMO_PATH}/aishell3_test.wav")
854
+ input_audio.change(
855
+ lambda audio: audio,
856
+ inputs=[input_audio],
857
+ outputs=[audio_state]
858
+ )
859
 
860
+ transcribe_btn.click(fn=transcribe_zh,
861
+ inputs=[audio_state],
862
+ outputs=[original_transcript, semgents, state, success_output])
863
 
864
+ run_btn.click(fn=run_edit_zh,
865
+ inputs=[
866
+ seed, sub_amount,
867
+ aug_text, cfg_coef, cfg_stride, prompt_length,
868
+ audio_state, original_transcript, transcript,
869
+ ],
870
+ outputs=[output_audio, success_output])
871
+
872
+ transcript.submit(fn=run_edit_zh,
873
+ inputs=[
874
+ seed, sub_amount,
875
+ aug_text, cfg_coef, cfg_stride, prompt_length,
876
+ audio_state, original_transcript, transcript,
877
+ ],
878
+ outputs=[output_audio, success_output]
879
+ )
880
 
881
+ with gr.Tab("Mandarin TTS"):
882
 
883
+ with gr.Row():
884
+ with gr.Column(scale=2):
885
+ input_audio = gr.Audio(value=f"{DEMO_PATH}/aishell3_test.wav", label="Input Audio", type="filepath", interactive=True)
886
+ with gr.Group():
887
+ original_transcript = gr.Textbox(label="Original transcript", lines=5, value="价格已基本都在三万到六万之间",
888
+ info="Use whisperx model to get the transcript.")
889
+ transcribe_btn = gr.Button(value="Transcribe")
890
+
891
+ with gr.Column(scale=3):
892
+ with gr.Group():
893
+ transcript = gr.Textbox(label="Text", lines=7, value="我简直不敢相信同一个模型也可以进行文本到语音的生成", interactive=True)
894
+ run_btn = gr.Button(value="Run")
895
+
896
+ with gr.Column(scale=2):
897
+ output_audio = gr.Audio(label="Output Audio")
898
 
899
+ with gr.Row():
900
+ with gr.Accordion("Advanced Settings", open=False):
901
+ seed = gr.Number(label="seed", value=-1, precision=0, info="random seeds always works :)")
902
+ aug_text = gr.Radio(label="aug_text", choices=[0, 1], value=1,
903
+ info="set to 1 to use classifer-free guidance, change if you don't like the results")
904
+ cfg_coef = gr.Number(label="cfg_coef", value=1.5,
905
+ info="cfg guidance scale, 1.5 is a good value, change if you don't like the results")
906
+ cfg_stride = gr.Number(label="cfg_stride", value=1,
907
+ info="cfg stride, 1 is a good value for Mandarin, change if you don't like the results")
908
+ prompt_length = gr.Number(label="prompt_length", value=3,
909
+ info="used for tts prompt, will automatically cut the prompt audio to this length")
910
+ sub_amount = gr.Number(label="sub_amount", value=0.12, info="margin to the left and right of the editing segment, change if you don't like the results")
911
+
912
+ success_output = gr.HTML()
913
+
914
+ semgents = gr.State() # not used
915
+ state = gr.State() # not used
916
+ audio_state = gr.State(value=f"{DEMO_PATH}/aishell3_test.wav")
917
+ input_audio.change(
918
+ lambda audio: audio,
919
+ inputs=[input_audio],
920
+ outputs=[audio_state]
921
+ )
922
 
923
+ transcribe_btn.click(fn=transcribe_zh,
924
+ inputs=[audio_state],
925
+ outputs=[original_transcript, semgents, state, success_output])
926
 
927
+ run_btn.click(fn=run_tts_zh,
928
+ inputs=[
929
+ seed, sub_amount,
930
+ aug_text, cfg_coef, cfg_stride, prompt_length,
931
+ audio_state, original_transcript, transcript,
932
+ ],
933
+ outputs=[output_audio, success_output])
934
+
935
+ transcript.submit(fn=run_tts_zh,
936
+ inputs=[
937
+ seed, sub_amount,
938
+ aug_text, cfg_coef, cfg_stride, prompt_length,
939
+ audio_state, original_transcript, transcript,
940
+ ],
941
+ outputs=[output_audio, success_output]
942
+ )
943
 
944
  # Launch the Gradio demo
945
  demo.launch()
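
The Mandarin tabs wire transcribe_zh and run_tts_zh into Gradio callbacks; outside the UI the same functions can be driven directly. A rough sketch using the demo audio and the default advanced settings shown above (note the @spaces.GPU decorators assume a Hugging Face Space with GPU, so a local run may need a GPU or the decorators removed):

    # hypothetical driver script, reusing names and defaults from app.py
    audio_path = "./demo/aishell3_test.wav"   # assumption: DEMO_PATH resolves to ./demo
    orig_transcript, segments, state, msg = transcribe_zh(audio_path)
    output_audio, msg = run_tts_zh(
        seed=-1, sub_amount=0.12,
        aug_text=1, cfg_coef=1.5, cfg_stride=1, prompt_length=3,
        audio_path=audio_path,
        original_transcript=orig_transcript,
        transcript="我简直不敢相信同一个模型也可以进行文本到语音的生成",
    )
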