msekoyan committed
Commit a3f1726 · 1 Parent(s): 9fe2ed1

Add application file

Files changed (4)
  1. app.py +184 -0
  2. nemo_align.py +36 -0
  3. packages.txt +3 -0
  4. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,184 @@
+
+ import os
+ import subprocess
+
+ import gradio as gr
+ import pandas as pd
+ import torch
+ import yt_dlp
+ from nemo.collections.asr.models import ASRModel
+
+ from nemo_align import align_tdt_preds
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def process_audio(input_file, output_file):
+     # Downmix to mono and resample to 16 kHz, the rate the ASR model expects
+     command = [
+         'sox', input_file,
+         output_file,
+         'channels', '1',
+         'rate', '16000'
+     ]
+     try:
+         subprocess.run(command, check=True)
+         return output_file
+     except subprocess.CalledProcessError:
+         raise gr.Error("Failed to convert audio to a single channel at a 16000 Hz sampling rate")
+
+ def get_dataframe_segments(segments):
+     df = pd.DataFrame(columns=['start_time', 'end_time', 'text'])
+     if len(segments) == 0:
+         df.loc[0] = 0, 0, ''
+         return df
+
+     for segment in segments:
+         text, start_time, end_time = segment
+         if len(text) > 0:
+             df.loc[len(df)] = round(start_time, 2), round(end_time, 2), text
+
+     return df
+
+
+ def get_video_info(url):
+     ydl_opts = {
+         'quiet': True,
+         'skip_download': True,  # yt-dlp expects an underscore, not 'skip-download'
+     }
+
+     with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+         try:
+             info = ydl.extract_info(url, download=False)
+         except yt_dlp.utils.DownloadError:
+             raise gr.Error("Failed to extract video info from YouTube")
+     return info
+
+ def download_audio(url):
+     ydl_opts = {
+         'format': 'bestaudio/best',  # downmixing to one channel is handled by process_audio
+         'quiet': True,
+         'outtmpl': 'audio_file',
+         'postprocessors': [{
+             'key': 'FFmpegExtractAudio',
+             'preferredcodec': 'flac',
+             'preferredquality': '192',
+         }],
+     }
+
+     with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+         try:
+             ydl.download([url])
+         except yt_dlp.utils.DownloadError as err:
+             raise gr.Error(str(err))
+
+     return process_audio('audio_file.flac', 'processed_file.flac')
+
+
+ def get_audio_from_youtube(url):
+     info = get_video_info(url)
+     duration = info.get('duration', 0)  # duration in seconds
+     video_id = info.get('id', None)
+
+     html = f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
+
+     if duration > 2 * 60 * 60:  # 2-hour cap; revisit based on the available GPU
+         raise gr.Error("Single-pass transcription on this GPU is limited to 2 hours of audio")
+     else:
+         return download_audio(url), html
+
+
+ def get_transcripts(audio_path, model):
+     # autocast is constructed but disabled (enabled=False); inference runs in the model's default precision
+     with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
+         with torch.inference_mode():
+             text = model.transcribe(audio=[audio_path])
+     return text
+
+ def pick_asr_model():
+     model = 'nvidia/parakeet-tdt_ctc-1.1b'
+     asr_model = ASRModel.from_pretrained(model).to(device)
+     asr_model.cfg.decoding.strategy = "greedy_batch"
+     asr_model.change_decoding_strategy(asr_model.cfg.decoding)
+     asr_model.eval()
+     return asr_model
+
+ asr_model = pick_asr_model()
+
+ def run_nemo_models(url, file_path, mic_path, timestamp_type):
+     # Argument order matches the click() wiring below: YouTube link, uploaded file, microphone recording
+     html = None
+     if url is None or len(url) < 2:
+         path1 = file_path if file_path else mic_path
+     else:
+         gr.Info("Downloading and processing audio from YouTube")
+         path1, html = get_audio_from_youtube(url)
+
+     gr.Info("Running NeMo Model")
+
+     timestamps = align_tdt_preds(asr_model, path1, timestamp_type)
+
+     df = get_dataframe_segments(timestamps)
+
+     return df, html
+
+ def clear_youtube_link():
+     # Remove .flac files left in the current directory by previous runs
+     file_list = os.listdir()
+     for file in file_list:
+         if file.endswith(".flac"):
+             os.remove(file)
+
+     return None
+
+
+ # def run_speaker_diarization()
+
+ with gr.Blocks(
+     title="NeMo Parakeet Model",
+     css="""
+     textarea { font-size: 18px;}
+     #model_output_text_box span {
+         font-size: 18px;
+         font-weight: bold;
+     }
+     """,
+     theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)  # make text slightly bigger (default is text_md)
+ ) as demo:
+     gr.HTML("<h1 style='text-align: center'>Transcription with timestamps using Parakeet TDT-CTC</h1>")
+     gr.Markdown('''
+     Choose between different sources of audio (microphone, audio file, YouTube video) to transcribe along with timestamps.
+
+     Parakeet models are fast thanks to their limited-context attention mechanism. The current 1.1B-parameter model can transcribe very long audio, up to 11 hours on an A6000 GPU, in a single pass.
+
+     Model used: [nvidia/parakeet-tdt_ctc-1.1b](https://huggingface.co/nvidia/parakeet-tdt_ctc-1.1b).
+     ''')
+     # Tabs for the three audio sources: YouTube link, uploaded file, microphone recording
+     with gr.Tab('Audio from YouTube'):
+         with gr.Row():
+             yt_link = gr.Textbox(value=None, label='Enter YouTube Link', type='text')
+             yt_render = gr.HTML()
+
+     with gr.Tab('Audio From File'):
+         file_input = gr.Audio(sources='upload', label='Upload Audio', type='filepath')
+
+     with gr.Tab('Audio From Microphone'):
+         mic_input = gr.Audio(sources='microphone', label='Record Audio', type='filepath')
+
+
+     # b1 = gr.Button("Get Transcription with Punctuation and Capitalization")
+
+     gr.Markdown('''Speech Recognition''')
+
+     # text_output = gr.Textbox(label='Transcription', type='text')
+
+     timestamp_type = gr.Radio(["Segments", "Words"], value='Segments', label='Select timestamps granularity', show_label=True)
+
+     b2 = gr.Button("Get timestamps with text")
+
+     time_stamp = gr.DataFrame(wrap=True, label='Speech Recognition with TimeStamps',
+                               row_count=(1, "dynamic"), headers=['start_time', 'end_time', 'text'])
+
+     # b1.click(run_nemo_models, inputs=[file_input, mic_input, yt_link], outputs=[text_output, yt_render])
+
+     b2.click(run_nemo_models, inputs=[yt_link, file_input, mic_input, timestamp_type], outputs=[time_stamp, yt_render]).then(
+         clear_youtube_link, None, yt_link, queue=False)  # clean-up: pass None back to the YouTube link textbox
+
+ demo.queue(True)
+ demo.launch(share=True, debug=True)
+
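
For reference, the sox command that `process_audio` builds can be tried on its own outside the app; a minimal sketch, assuming sox is installed and a local `input.wav` exists (both file names here are hypothetical, not part of the app):

```python
# Minimal sketch of the preprocessing step used by process_audio above:
# downmix to one channel and resample to the 16 kHz the ASR model expects.
# "input.wav" and "output.flac" are hypothetical local paths.
import subprocess

subprocess.run(
    ["sox", "input.wav", "output.flac", "channels", "1", "rate", "16000"],
    check=True,  # raises CalledProcessError if sox fails
)
```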
nemo_align.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ import omegaconf
+
+ def align_tdt_preds(model, audio_path, timestamps_type):
+     # Map the UI label to the key NeMo uses for timestamp granularity
+     timestamps_type = 'segment' if timestamps_type == 'Segments' else 'word'
+
+     cfg = model.cfg.decoding
+     with omegaconf.open_dict(cfg):
+         cfg['compute_timestamps'] = True
+         cfg['rnnt_timestamp_type'] = timestamps_type
+
+     model.change_decoding_strategy(decoding_cfg=cfg)
+
+     with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
+         with torch.inference_mode():
+             hypotheses = model.transcribe(audio=[audio_path], return_hypotheses=True)
+
+     # Seconds per encoder frame: feature hop (window_stride) times the encoder's subsampling factor
+     frame_dur = model.cfg.preprocessor.window_stride * model.cfg.encoder.subsampling_factor
+
+     # transcribe() may return (best_hypotheses, all_hypotheses); keep the best ones
+     if isinstance(hypotheses, tuple) and len(hypotheses) == 2:
+         hypotheses = hypotheses[0]
+
+     offsets = hypotheses[0].timestep[timestamps_type]
+
+     timestamps = []
+     for unit in offsets:
+         # Convert frame offsets to seconds
+         start_s = unit['start_offset'] * frame_dur
+         end_s = unit['end_offset'] * frame_dur
+         timestamps.append((unit[timestamps_type], start_s, end_s))
+
+     return timestamps
+
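
Outside the Gradio UI, `align_tdt_preds` can be exercised directly once a model is loaded; a minimal sketch, assuming the NeMo branch pinned in requirements.txt and a local mono 16 kHz `sample.flac` (a hypothetical path):

```python
# Minimal sketch: load the Parakeet model and print segment-level timestamps.
# "sample.flac" is a hypothetical mono 16 kHz audio file.
import torch
from nemo.collections.asr.models import ASRModel
from nemo_align import align_tdt_preds

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ASRModel.from_pretrained("nvidia/parakeet-tdt_ctc-1.1b").to(device)
model.eval()

# align_tdt_preds returns a list of (text, start_seconds, end_seconds) tuples
for text, start_s, end_s in align_tdt_preds(model, "sample.flac", "Segments"):
    print(f"[{start_s:7.2f}s - {end_s:7.2f}s] {text}")
```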
packages.txt ADDED
@@ -0,0 +1,3 @@
+ ffmpeg
+ libsndfile1
+ sox
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ Cython
+ packaging
+ torch==2.2.0
+ IPython
+ numpy>=1.22,<2.0.0
+ git+https://github.com/NVIDIA/NeMo.git@msekoyan/tdt_compute_timestamps#egg=nemo_toolkit[asr]
+ yt_dlp