arnavmehta7 committed on
Commit 3a0629c · verified · 1 Parent(s): e1be390

Update app.py

Files changed (1):
  1. app.py +73 -19
app.py CHANGED
@@ -4,17 +4,40 @@ import torch
 import librosa
 from pathlib import Path
 import tempfile, torchaudio
-
+# from faster_whisper import WhisperModel
+from transformers import pipeline
+from uuid import uuid4
 
 # Load the MARS5 model
 mars5, config_class = torch.hub.load('Camb-ai/mars5-tts', 'mars5_english', trust_repo=True)
+# asr_model = WhisperModel("small", device="cpu", compute_type="int8")
+asr_model = pipeline(
+    "automatic-speech-recognition",
+    model="openai/whisper-medium",
+    chunk_length_s=30,
+    device=torch.device("cuda"),
+)
 
-# Default reference audio and transcript
-# default_audio_path = "example.wav"
-# default_transcript = "We actually haven't managed to meet demand."
+def transcribe_file(f: str) -> str:
+    predictions = asr_model(f, return_timestamps=True)["chunks"]
+    print(f">>>>> predictions: {predictions}")
+    return " ".join([prediction["text"] for prediction in predictions])
 
 # Function to process the text and audio input and generate the synthesized output
 def synthesize(text, audio_file, transcript):
+    audio_file = Path(audio_file)
+    temp_file = f"{uuid4()}{audio_file.suffix}"  # Path.suffix already includes the dot
+
+    # Copy the uploaded reference audio to a uniquely named local file
+    with open(audio_file, 'rb') as src, open(temp_file, 'wb') as dst:
+        dst.write(src.read())
+
+    audio_file = temp_file
+
+    print(f">>>>> synthesizing! audio_file: {audio_file}")
+    if not transcript:
+        transcript = transcribe_file(audio_file)
+
     # Load the reference audio
     wav, sr = librosa.load(audio_file, sr=mars5.sr, mono=True)
     wav = torch.from_numpy(wav)
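
Note: the hunk above adds automatic transcription. When the transcript box is left empty, synthesize falls back to transcribe_file, which runs the uploaded clip through a transformers Whisper pipeline. A minimal standalone sketch of that step (the checkpoint mirrors app.py; "reference.wav" is a placeholder path, and the device is left at the CPU default rather than the Space's CUDA pin):

from transformers import pipeline

# Same ASR pipeline as app.py; chunk_length_s splits long clips into
# 30-second windows so Whisper's context limit is respected.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-medium",
    chunk_length_s=30,
)

# return_timestamps=True yields per-chunk segments; join them into one transcript.
result = asr("reference.wav", return_timestamps=True)  # placeholder file
transcript = " ".join(chunk["text"] for chunk in result["chunks"])
print(transcript)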
@@ -29,21 +52,52 @@ def synthesize(text, audio_file, transcript):
     # Save the synthesized audio to a temporary file
     output_path = Path(tempfile.mktemp(suffix=".wav"))
     torchaudio.save(output_path, wav_out.unsqueeze(0), mars5.sr)
-
     return str(output_path)
 
-# Create the Gradio interface
-interface = gr.Interface(
-    fn=synthesize,
-    inputs=[
-        gr.Textbox(label="Text to synthesize"),
-        gr.Audio(label="Audio file to clone from", type="filepath"),
-        gr.Textbox(label="Uploaded audio file transcript"),
-    ],
-    outputs=gr.Audio(label="Synthesized Audio"),
-    title="MARS5 TTS Demo",
-    description="Enter text and upload an audio file to clone the voice and generate synthesized speech using MARS5 TTS."
-)
+defaults = {
+    'temperature': 0.8,
+    'top_k': -1,
+    'top_p': 0.2,
+    'typical_p': 1.0,
+    'freq_penalty': 2.6,
+    'presence_penalty': 0.4,
+    'rep_penalty_window': 100,
+    'max_prompt_phones': 360,
+    'deep_clone': True,
+    'nar_guidance_w': 3
+}
 
-# Launch the Gradio app
-interface.launch()
+
+with gr.Blocks() as demo:
+    gr.Markdown("## MARS5 TTS Demo\nEnter text and upload an audio file to clone the voice and generate synthesized speech using MARS5 TTS.")
+    text = gr.Textbox(label="Text to synthesize")
+    audio_file = gr.Audio(label="Audio file to clone from", type="filepath")
+
+    generate_btn = gr.Button("Generate Synthesized Audio")
+
+    with gr.Accordion("Advanced Settings", open=False):
+        gr.Markdown("Additional inference settings.\nWARNING: changing these incorrectly may degrade quality.")
+        prompt_text = gr.Textbox(label="Transcript of voice reference")
+        temperature = gr.Slider(minimum=0.01, maximum=3, step=0.01, label="temperature", value=defaults['temperature'])
+        top_k = gr.Slider(minimum=-1, maximum=2000, step=1, label="top_k", value=defaults['top_k'])
+        top_p = gr.Slider(minimum=0.01, maximum=1.0, step=0.01, label="top_p", value=defaults['top_p'])
+        typical_p = gr.Slider(minimum=0.01, maximum=1, step=0.01, label="typical_p", value=defaults['typical_p'])
+        freq_penalty = gr.Slider(minimum=0, maximum=5, step=0.05, label="freq_penalty", value=defaults['freq_penalty'])
+        presence_penalty = gr.Slider(minimum=0, maximum=5, step=0.05, label="presence_penalty", value=defaults['presence_penalty'])
+        rep_penalty_window = gr.Slider(minimum=1, maximum=500, step=1, label="rep_penalty_window", value=defaults['rep_penalty_window'])
+        nar_guidance_w = gr.Slider(minimum=1, maximum=8, step=0.1, label="nar_guidance_w", value=defaults['nar_guidance_w'])
+        meta_n = gr.Slider(minimum=1, maximum=10, step=1, label="meta_n", value=2, interactive=False)
+        deep_clone = gr.Checkbox(value=defaults['deep_clone'], label='deep_clone')
+
+    dummy = gr.Number(label='Example number', visible=False)
+
+    output = gr.Audio(label="Synthesized Audio", type="filepath")
+    def on_click(text, audio_file, prompt_text):
+        print(f">>>> transcript: {prompt_text}; audio_file = {audio_file}")
+        of = synthesize(text, audio_file, prompt_text)
+        print(f">>>> output file: {of}")
+        return of
+
+    generate_btn.click(on_click, inputs=[text, audio_file, prompt_text], outputs=[output])
+
+demo.launch(share=False)
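
Note: the context lines above cut off before the actual mars5.tts(...) call inside synthesize, but the defaults dict added in this hunk matches the inference settings documented in the Camb-ai/mars5-tts README. A sketch of how such settings would plausibly reach the model through the config_class handle returned by torch.hub.load; the call shape follows the README example rather than this diff, and the exact keyword names accepted by config_class are assumptions to verify against the repo:

# Hypothetical wiring of the advanced settings into a MARS5 inference config,
# placed inside synthesize after the reference wav and transcript are ready.
cfg = config_class(
    temperature=defaults['temperature'],
    top_k=defaults['top_k'],
    top_p=defaults['top_p'],
    typical_p=defaults['typical_p'],
    freq_penalty=defaults['freq_penalty'],
    presence_penalty=defaults['presence_penalty'],
    rep_penalty_window=defaults['rep_penalty_window'],
    deep_clone=defaults['deep_clone'],
)
# Per the README pattern: target text, reference wav, reference transcript.
ar_codes, wav_out = mars5.tts(text, wav, transcript, cfg=cfg)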
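
Note: in this revision the advanced controls (temperature, top_k, and the rest) are created but never passed to on_click, so moving the sliders has no effect on synthesis; only text, audio_file, and prompt_text are wired. A hypothetical extension (not part of this commit) forwarding them through the click event:

# Sketch: register the advanced controls as extra inputs on the click event.
# synthesize would need matching parameters to actually consume the values.
def on_click(text, audio_file, prompt_text, temperature, top_k, top_p,
             typical_p, freq_penalty, presence_penalty,
             rep_penalty_window, nar_guidance_w, deep_clone):
    return synthesize(text, audio_file, prompt_text)

generate_btn.click(
    on_click,
    inputs=[text, audio_file, prompt_text, temperature, top_k, top_p,
            typical_p, freq_penalty, presence_penalty,
            rep_penalty_window, nar_guidance_w, deep_clone],
    outputs=[output],
)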