import re
import gc
import shutil
from pathlib import Path
from datetime import timedelta, datetime
from urllib.parse import parse_qs, urlparse

import torch
import numpy as np
import pandas as pd
import gradio as gr
from pydub import AudioSegment

from dev_utils import DEV_LOGGER
from yt_utils import Logger, get_audio_from_youtube
from nemo_utils import get_aligned_transcription, process_audio, get_model_cache, get_offsets_cache
from hf_utils import get_models_list, get_model_description, predownload_models


class NeMoGradioApp:
    """Gradio app that transcribes speech with NeMo models and shows timestamps.

    Audio can be taken from a YouTube link, an uploaded file or the microphone.
    Results are displayed in a table whose rows can be clicked to play / seek to
    the corresponding part of the audio, and can be downloaded as CSV.

    Per-session state (a YouTube ``Logger``, the model/offset caches and the files
    under ``cache_dir/<session_hash>``) is created when a page loads
    (``get_session_starting``) and removed when it unloads (``cleanup``).
    """

    def __init__(self):
        self.cache_dir = '/tmp/gradio'
        # session_hash -> yt_utils.Logger
        self.loggers = {}
        # session_hash -> {'get_model': ..., 'get_offsets': ...} cached callables
        self.caching_funcs = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.default_lang = 'English'
        # Model whose description is shown in the UI (see build_inference)
        self.default_model_bck = 'parakeet-tdt_ctc-1.1b'
        self.default_model = 'msekoyan/parakeet2-0.6b-En'
        self.default_video_id = 'erhqbyvPesY'
        # Processed audio used instead of downloading the default video, and the
        # input of the precomputed demo results
        self.demo_audio_path = 'demo_audio2_processed.flac'
        self.demo_results = self.get_results_for_demo()
        self.available_langs_models, self.available_models_langs = get_models_list(sort_alphabetically=True)
        predownload_models(list(self.available_models_langs.keys()), top=1)
        # Runs every time the YouTube logs change: lines containing a link become
        # clickable and lines mentioning a "code" (device-authentication code) are
        # surfaced in an alert. The same line elements may be re-scanned many
        # times, so each element remembers the text it was last handled for, and
        # ``onclick`` is assigned (not added) so handlers never stack up.
        # Raw string: the regex and the "\n" escape are passed to JS untouched.
        self.cm_js_code = r"""
        function makeCmLinesHyperlinksIfHttps() {
            const lines = document.querySelectorAll(".cm-line");
            lines.forEach((line) => {
                const text = line.textContent;
                if (line.dataset.handledText === text) {
                    return;
                }
                line.dataset.handledText = text;
                // Make lines that contain a link clickable
                if (text.includes("https")) {
                    line.style.cursor = "pointer";
                    line.style.color = "blue";
                    line.onclick = () => {
                        const match = line.textContent.match(/https?:\/\/[^\s]+/);
                        window.open(match ? match[0] : "https://www.google.com/device", "_blank");
                    };
                }
                // Surface the authentication code to the user
                if (text.includes("code")) {
                    alert(`${text.trim()} on https://www.google.com/device. \nSee more in the logs section below.`);
                }
            });
        }
        """
        self.demo = gr.Blocks(
            title="NeMo Speech-to-Text",
            head=""" """,
            css="""
            textarea { font-size: 18px;}
            #model_output_text_box span {
                font-size: 18px;
                font-weight: bold;
            }
            .cm-line {
                color: #2463eb;
            }
            #csv-button {
                background: #2463eb !important;
                border: #2463eb !important;
                color: white !important;
                transition: background-color 0.3s !important;
            }
            """,
            theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg),  # make text slightly bigger (default is text_md)
        )
        self.build_inference()

    @property
    def default_video_url(self):
        """YouTube watch URL of the default (demo) video."""
        return f'https://www.youtube.com/watch?v={self.default_video_id}'

    @property
    def default_video_embed_url(self):
        """YouTube embed URL of the default (demo) video."""
        return f'https://www.youtube.com/embed/{self.default_video_id}'

    @staticmethod
    def _yt_iframe_html(src_url):
        """Return the HTML of an embedded YouTube player pointing at ``src_url``.

        ``on_row_click`` reads the player URL back from this markup with a
        ``src="..."`` regex, so the ``src`` attribute must stay double-quoted.
        """
        return (
            '<center><iframe width="560" height="315" '
            f'src="{src_url}" title="YouTube video player" frameborder="0" '
            'allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" '
            'allowfullscreen></iframe></center>'
        )

    @staticmethod
    def _get_yt_video_id(url):
        """Extract the video id from a YouTube URL.

        ``watch?v=<id>&...`` links yield the ``v`` parameter (other query
        parameters such as ``t=`` are ignored); links without it (e.g.
        ``youtu.be/<id>``) fall back to the last path segment.
        """
        parsed = urlparse(url)
        video_id = parse_qs(parsed.query).get('v', [''])[0]
        if video_id:
            return video_id
        return parsed.path.rstrip('/').split('/')[-1]

    def get_results_for_demo(self):
        """Precompute the default model's timestamps for the demo audio.

        Results are computed for both "Segments" and "Words" and saved as CSV
        under ``cache_dir/master``. The loaded model is kept in
        ``self.preloaded_demo_model`` so later requests with the default model can
        reuse it; the temporary caches are cleared afterwards.

        Returns:
            dict mapping the granularity name to ``(timestamps_df, csv_path)``.
        """
        get_model_func = get_model_cache()
        get_offsets_func = get_offsets_cache()
        self.preloaded_demo_model = get_model_func(self.default_model)
        demo_results = {}
        for timestamp_level in ["Segments", "Words"]:
            timestamps = get_aligned_transcription(
                self.default_model, self.demo_audio_path, timestamp_level,
                get_model_func, get_offsets_func, self.device,
                preloaded_model=self.preloaded_demo_model,
            )
            timestamps_df = self.get_ts_dataframe(timestamps)
            csv_path = self.save_results_to_csv(timestamps_df, self.demo_audio_path, timestamp_level, 'master')
            demo_results[timestamp_level] = (timestamps_df, csv_path)
        get_offsets_func.cache_clear()
        get_model_func.cache_clear()
        gc.collect()
        return demo_results

    def get_session_starting(self, request: gr.Request):
        """Create the logger and model/offset caches of a newly loaded session.

        Returns:
            The session hash (stored in a hidden textbox for later events).
        """
        session_hash = request.session_hash
        self.loggers[session_hash] = Logger(self.get_session_log_file(session_hash))
        self.caching_funcs[session_hash] = {
            'get_model': get_model_cache(),
            'get_offsets': get_offsets_cache(),
        }
        DEV_LOGGER.info(f'STARTING: {session_hash}')
        return session_hash

    def get_logs(self, request: gr.Request):
        """Return the YouTube logs of the requesting session (polled by a timer).

        Returns a placeholder message while no session exists at all, and
        ``None`` when the requesting session has no logger.
        """
        if not self.loggers:
            return "The download has not started yet!"
        logger = self.loggers.get(request.session_hash)
        if logger:
            return logger.get_logs()
        return None

    def get_yt_cache_dir(self, session_hash):
        """Return (creating it if needed) the session's YouTube download directory."""
        yt_cache_dir = Path(self.cache_dir, session_hash, 'youtube')
        yt_cache_dir.mkdir(exist_ok=True, parents=True)
        return yt_cache_dir

    def get_session_log_file(self, session_hash):
        """Return the session's log file path, creating its directory if needed."""
        session_log_dir = Path(self.cache_dir, session_hash, 'logs')
        session_log_dir.mkdir(exist_ok=True, parents=True)
        return Path(session_log_dir, 'output.log')

    @property
    def session_log_file(self):
        """Shared (not per-session) log file path, creating its directory if needed."""
        session_log_dir = Path(self.cache_dir, 'logs')
        session_log_dir.mkdir(exist_ok=True, parents=True)
        return Path(session_log_dir, 'output.log')

    @staticmethod
    def sec_to_hrs(seconds):
        """Format a number of seconds, rounded to whole seconds, as ``H:MM:SS``."""
        seconds = round(seconds)
        return str(timedelta(seconds=seconds))

    @staticmethod
    def hrs_to_sec(hrs):
        """Parse an ``H:MM:SS`` string into a number of seconds."""
        time_obj = datetime.strptime(hrs, "%H:%M:%S")
        return time_obj.hour * 3600 + time_obj.minute * 60 + time_obj.second

    @staticmethod
    def get_ts_dataframe(segments):
        """Build the results table from ``(text, start_time, end_time)`` triples.

        Times are in seconds. ``start_time``/``end_time`` are formatted as
        ``H:MM:SS``; ``start_ms``/``end_ms`` keep the raw time (despite their
        names, in seconds rounded to 2 decimals) for seeking. Entries with empty
        text are skipped. An empty input yields a single all-zero placeholder row.
        """
        columns = ['start_time', 'end_time', 'text', 'start_ms', 'end_ms']
        if len(segments) == 0:
            # The placeholder must provide a value for every column.
            return pd.DataFrame([(0, 0, '', 0, 0)], columns=columns)
        rows = [
            (
                NeMoGradioApp.sec_to_hrs(start_time),
                NeMoGradioApp.sec_to_hrs(end_time),
                text,
                round(start_time, 2),
                round(end_time, 2),
            )
            for text, start_time, end_time in segments
            if len(text) > 0
        ]
        return pd.DataFrame(rows, columns=columns)

    @staticmethod
    def get_max_duration(model_name):
        """Return the maximum input duration in seconds for a model, or ``None`` for no limit."""
        if model_name == 'parakeet-tdt_ctc-1.1b':
            return None
        # Limit input audio duration to 20 mins for fastconformer models
        if 'fastconformer' in model_name or 'parakeet' in model_name:
            return 20 * 60
        # Limit input audio duration to 10 mins for conformer models
        if 'conformer' in model_name:
            return 10 * 60
        return None

    @staticmethod
    def get_audio_segment(audio_path, start_second, end_second):
        """Cut the ``[start_second, end_second)`` range out of an audio file.

        Returns ``(frame_rate, samples)`` for ``gr.Audio`` with multi-channel
        audio averaged down to one channel, or ``None`` when no path is given.
        """
        if audio_path is None:
            return None
        audio = AudioSegment.from_file(audio_path)
        # pydub slices by milliseconds
        clipped_audio = audio[start_second * 1000:end_second * 1000]
        samples = np.array(clipped_audio.get_array_of_samples())
        if clipped_audio.channels > 1:
            # Samples are interleaved per channel; average them into mono and keep
            # the original integer dtype so multi-channel clips are returned the
            # same way as mono ones.
            samples = samples.reshape((-1, clipped_audio.channels)).mean(axis=1).astype(samples.dtype)
        return (clipped_audio.frame_rate, samples)

    def reset_demo_materials(self, selected_tab, selected_model, yt_link, yt_render, results_df, download_button, session_hash):
        """Restore the demo outputs for the default setup, otherwise clear the YouTube inputs.

        When the default model, the default video link and the YouTube tab are
        selected, the precomputed demo segments, player and CSV button are shown.
        Otherwise the session's logs are reset and the link and player cleared.
        """
        is_demo_setup = (
            selected_model == self.default_model
            and yt_link == self.default_video_url
            and selected_tab == "youtube"
        )
        if is_demo_setup:
            results_df = self.demo_results["Segments"][0].drop(['start_ms', 'end_ms'], axis=1)
            yt_render = gr.HTML(self._yt_iframe_html(self.default_video_embed_url))
            download_button = gr.DownloadButton("Download CSV", value=self.demo_results["Segments"][1], visible=True, interactive=True, elem_id='csv-button')
        else:
            self.loggers[session_hash].reset_logs()
            yt_link = ""
            # No link any more, so no player either
            yt_render = gr.HTML('')
        return yt_link, yt_render, results_df, download_button

    def reset_yt_logger(self, results_df, selected_model, session_hash):
        """Remind the user about the demo when the default model shows non-demo results."""
        if selected_model != self.default_model:
            return
        demo_segments_df = self.demo_results['Segments'][0].drop(['start_ms', 'end_ms'], axis=1)
        demo_words_df = self.demo_results['Words'][0].drop(['start_ms', 'end_ms'], axis=1)
        if not results_df.empty and not results_df.equals(demo_segments_df) and not results_df.equals(demo_words_df):
            self.loggers[session_hash].remind_about_demo()

    def change_based_on_model(self, selected_model):
        """Update the description, duration note, duration limit and audio inputs for a model."""
        desc, more_info = self.show_model_description(selected_model)
        max_duration = self.get_max_duration(selected_model)
        if max_duration:
            note_on_max_duration = gr.Markdown(f' ⚠️ NOTE: In the current setup, you can transcribe **up to {max_duration // 60} minutes** of speech. ⚠️ ', visible=True)
        else:
            note_on_max_duration = gr.Markdown(value=None, visible=False)
        file_input = gr.Audio(sources='upload', label='Upload Audio', type='filepath', max_length=max_duration)
        mic_input = gr.Audio(sources='microphone', label='Record Audio', type='filepath', max_length=max_duration)
        return desc, more_info, note_on_max_duration, gr.Number(max_duration, visible=False), file_input, mic_input

    def show_available_models(self, lang=None):
        """Return the models of ``lang``, or all model names when no language is given."""
        if lang:
            return self.available_langs_models[lang]
        return self.available_models_langs.keys()

    def update_dropdown(self, selected_value):
        """Return a model dropdown listing the models of the selected language."""
        lang_models = self.show_available_models(selected_value)
        return gr.Dropdown(choices=lang_models, value=lang_models[0], interactive=True)

    def show_model_description(self, selected_model):
        """Return the description and extra-info Markdown components for a model (hidden if none)."""
        if selected_model:
            description, more_info = get_model_description(selected_model)
            return (gr.Markdown(value=description, visible=True, label='Model Description'),
                    gr.Markdown(value=more_info, visible=True, label='Model Description'))
        return gr.Markdown(visible=False), gr.Markdown(visible=False)

    def save_results_to_csv(self, df, audio_path, timestamp_type, session_hash):
        """Save ``df`` as ``<audio>_<type>_timestamps.csv`` in the session dir and return its path."""
        audio_name = Path(audio_path).stem.replace('_processed', '')
        csv_name = f'{audio_name}_{timestamp_type.lower()}_timestamps'
        csv_dir = Path(self.cache_dir, session_hash)
        csv_dir.mkdir(exist_ok=True, parents=True)
        csv_path = Path(csv_dir, csv_name).with_suffix('.csv').as_posix()
        df.to_csv(csv_path, index=False)
        return csv_path  # Return the path for download

    def on_row_click(self, evt: gr.SelectData, df, selected_src, html, file_path, mic_path):
        """Play / seek to the part of the audio matching a clicked results row.

        Args:
            evt: Gradio selection event; ``evt.index[0]`` is the clicked row.
            df: Hidden table with the ``start_ms``/``end_ms`` columns (seconds).
            selected_src: Source of the displayed results: 'youtube', 'file' or 'mic'.
            html: Current YouTube player HTML.
            file_path: Path of the uploaded audio (may be None).
            mic_path: Path of the recorded audio (may be None).

        Returns:
            (player HTML, file segment component, mic segment component).
        """
        selected_row = df.iloc[evt.index[0]]
        start_seconds = selected_row['start_ms']
        end_seconds = selected_row['end_ms']
        file_selected_segment = gr.Audio(label='Selected Segment', visible=False)
        mic_selected_segment = gr.Audio(label='Selected Segment', visible=False)
        DEV_LOGGER.info(f"Selected Source {selected_src}")
        if selected_src == 'youtube':
            # The player is addressed in whole seconds; keep the range non-empty.
            start_seconds = round(start_seconds)
            end_seconds = round(end_seconds)
            DEV_LOGGER.info(f"Start: {start_seconds} | End: {end_seconds}")
            if start_seconds == end_seconds:
                end_seconds = start_seconds + 1
            # Reuse the current player URL without its query string; leave the
            # player untouched when there is none to seek.
            match = re.search(r'src="([^"?]+)', html or '')
            if match:
                src_url = f'{match.group(1)}?start={start_seconds}&end={end_seconds}&autoplay=1'
                html = gr.HTML(self._yt_iframe_html(src_url))
        elif selected_src == 'file' and file_path:
            segment = self.get_audio_segment(file_path, start_seconds, end_seconds)
            file_selected_segment = gr.Audio(segment, autoplay=True, label=f"{start_seconds}-{end_seconds} Second Segment", visible=True)
        elif selected_src == 'mic' and mic_path:
            segment = self.get_audio_segment(mic_path, start_seconds, end_seconds)
            mic_selected_segment = gr.Audio(segment, autoplay=True, label=f"{start_seconds}-{end_seconds} Second Segment", visible=True)
        return html, file_selected_segment, mic_selected_segment

    def cleanup_yt_cache(self, yt_link, selected_model, session_hash):
        """Reset the session's logs and delete its downloaded YouTube files when the link changes."""
        if yt_link == self.default_video_url and selected_model == self.default_model:
            self.loggers[session_hash].reset_logs_for_demo()
        else:
            self.loggers[session_hash].reset_logs()
        for file_path in self.get_yt_cache_dir(session_hash).glob('*'):
            if file_path.is_file():
                file_path.unlink()

    def cleanup(self, request: gr.Request):
        """Delete all files, the logger and the model/offset caches of an unloaded session."""
        session_hash = request.session_hash
        DEV_LOGGER.info(f'DELETING EVERYTHING FOR SESSION: {session_hash}')
        session_cache_dir = Path(self.cache_dir, session_hash)
        if session_cache_dir.exists():
            shutil.rmtree(session_cache_dir)
        self.loggers.pop(session_hash, None)
        session_caches = self.caching_funcs.pop(session_hash, None)
        if session_caches is not None:
            DEV_LOGGER.info(f'DELETING SESSION CACHE FOR: {session_hash}')
            session_caches['get_offsets'].cache_clear()
            session_caches['get_model'].cache_clear()
        gc.collect()

    def get_processed_audio(self, model_name, timestamp_type, channel_to_use, url, file_path, microphone, html, session_hash, max_audio_length):
        """Resolve and pre-process the input audio of the selected source.

        YouTube audio is downloaded at most once per video and session (the
        processed file is looked up in the session's YouTube cache); the default
        video uses the bundled demo audio.

        Returns:
            ``("<processed_path>%<timestamp_type>%<model_name>", html, channel_to_use)``.
            The joined key changes whenever any part changes, which triggers
            ``get_timestamps`` through the hidden ``user_inputs`` textbox.

        Raises:
            gr.Error: no model selected, or no audio provided for the selected source.
        """
        if not model_name:
            raise gr.Error('Please, select a model to transcribe with!')
        processed_path = None
        if channel_to_use == 'youtube' and url == self.default_video_url:
            audio_path = processed_path = self.demo_audio_path
        elif channel_to_use == 'youtube':
            if not url:
                raise gr.Error('Please, enter a YouTube link!')
            yt_video_id = self._get_yt_video_id(url)
            yt_cache_dir = self.get_yt_cache_dir(session_hash)
            possible_files = list(yt_cache_dir.glob(f'{yt_video_id}_*_processed.flac'))
            if possible_files:
                processed_path = possible_files[0].as_posix()
                audio_path = processed_path
            else:
                gr.Info("Downloading and processing audio from Youtube", duration=None)
                audio_path, html = get_audio_from_youtube(url, yt_cache_dir, self.loggers[session_hash], max_audio_length)
        elif channel_to_use == 'file':
            audio_path = file_path
        else:
            audio_path = microphone
        if not audio_path:
            raise gr.Error('Please, provide an audio to transcribe!')
        DEV_LOGGER.info(f'SESSION ID: {session_hash} | USING CHANNEL: {channel_to_use}')
        DEV_LOGGER.info(f'SESSION ID: {session_hash} | USING PATH: {audio_path}')
        if not processed_path:
            processed_path = process_audio(audio_path)
        return "%".join([processed_path, timestamp_type, model_name]), html, channel_to_use

    def get_timestamps(self, model_name, processed_path, timestamp_type, session_hash):
        """Transcribe the processed audio and return the results for display and download.

        ``processed_path`` is the key built by ``get_processed_audio``
        (``<path>%<timestamp_type>%<model_name>``); its last two fields are dropped.
        The precomputed demo results are returned for the demo audio with the
        default model.

        Returns:
            (visible results table, hidden ``start_ms``/``end_ms`` table, CSV download button).
        """
        processed_path = "%".join(processed_path.split('%')[:-2])
        if model_name == self.default_model and processed_path == self.demo_audio_path:
            results_df, csv_path = self.demo_results[timestamp_type]
            gr.Info("Results are ready!", duration=2)
            return (results_df.drop(['start_ms', 'end_ms'], axis=1),
                    results_df[['start_ms', 'end_ms']],
                    gr.DownloadButton(value=csv_path, visible=True, interactive=True))
        gr.Info("Running NeMo Model", duration=None)
        preloaded_model = self.preloaded_demo_model if model_name == self.default_model else None
        timestamps = get_aligned_transcription(
            model_name, processed_path, timestamp_type,
            self.caching_funcs[session_hash]['get_model'],
            self.caching_funcs[session_hash]['get_offsets'],
            self.device, preloaded_model=preloaded_model,
        )
        df = self.get_ts_dataframe(timestamps)
        csv_path = self.save_results_to_csv(df, processed_path, timestamp_type, session_hash)
        gr.Info("Results are ready!", duration=2)
        return (df.drop(['start_ms', 'end_ms'], axis=1),
                df[['start_ms', 'end_ms']],
                gr.DownloadButton(value=csv_path, visible=True, interactive=True))

    def build_inference(self):
        """Lay out the UI inside ``self.demo`` and wire up its events."""
        with self.demo:
            gr.HTML('<h1 style="text-align: center;">Transcription with Timestamps using NeMo STT Models 🤗</h1>')
            gr.Markdown(f'<div style="text-align: center;">Transcribe speech in {round(len(self.available_langs_models) / 5) * 5}+ languages!</div>')

            session_hash = gr.Textbox(visible=False)
            # Start with the default model's limit: both dropdowns are
            # non-interactive, so change_based_on_model would otherwise never set it.
            max_audio_length = gr.Number(value=self.get_max_duration(self.default_model), visible=False)
            self.demo.load(self.get_session_starting, outputs=session_hash)

            # User selection section
            with gr.Row():
                lang_dropdown = gr.Dropdown(choices=list(self.available_langs_models.keys()), value=self.default_lang, label="Select a Language", interactive=False)
                model_dropdown = gr.Dropdown(choices=self.show_available_models(self.default_lang), value=self.default_model, label="Select a Model", interactive=False)
            default_desc, default_more_info = get_model_description(self.default_model_bck)
            model_desc = gr.Markdown(visible=True, value=default_desc)
            model_more_info = gr.Markdown(visible=True, value=default_more_info)
            note_on_max_duration = gr.Markdown(' ⚠️ NOTE: In the current setup, you can transcribe **up to 20 minutes** of speech. ⚠️ ', visible=True)
            lang_dropdown.select(fn=self.update_dropdown, inputs=[lang_dropdown], outputs=[model_dropdown])
            gr.Markdown(' ⚠️ This experimental space is for showcasing the new Parakeet2 model. That is why most of the features are not available. ⚠️')
            selected_tab = gr.State('youtube')

            # YouTube block
            with gr.Tab('Audio from Youtube') as yt_tab:
                gr.Markdown(' ⚠️ You may be required to authenticate on [https://www.google.com/device](https://www.google.com/device) using the code provided in the logs to download a video from YouTube. ⚠️')
                yt_logs = gr.Code(value=None, language='markdown', lines=2, label='YouTube Logs')
                with gr.Row():
                    yt_link = gr.Textbox(value=self.default_video_url, label='Enter Youtube Link', type='text')
                yt_link.change(self.cleanup_yt_cache, inputs=[yt_link, model_dropdown, session_hash])
                yt_render = gr.HTML(self._yt_iframe_html(self.default_video_embed_url))
            yt_tab.select(lambda: 'youtube', outputs=selected_tab)
            yt_logs.change(fn=None, inputs=None, outputs=None, js=self.cm_js_code)
            timer = gr.Timer(value=1)
            timer.tick(self.get_logs, outputs=yt_logs)

            # File block
            with gr.Tab('Audio from File') as file_tab:
                file_input = gr.Audio(sources='upload', label='Upload Audio', type='filepath', max_length=1200)
                file_selected_segment = gr.Audio(label='Selected Segment', visible=False)
            file_input.change(lambda: gr.Audio(label='Selected Segment', visible=False), outputs=file_selected_segment)
            file_tab.select(lambda: 'file', outputs=selected_tab)

            # Microphone block
            with gr.Tab('Audio from Microphone') as mic_tab:
                mic_input = gr.Audio(sources='microphone', label='Record Audio', type='filepath', max_length=1200)
                mic_selected_segment = gr.Audio(label='Selected Segment', visible=False)
            mic_input.change(lambda: gr.Audio(label='Selected Segment', visible=False), outputs=mic_selected_segment)
            mic_tab.select(lambda: 'mic', outputs=selected_tab)

            with gr.Row():
                timestamp_type = gr.Radio(["Segments", "Words"], value='Segments', label='Select timestamps granularity', show_label=True)
                gr.Markdown('Currently segments are formed based on the following punctuation marks: `. ? !`. \nIf the selected model does not support these punctuation marks, the segments will be formed based on silence duration between words.', line_breaks=True)

            with gr.Row():
                timestamps_button = gr.Button("Get timestamps with text", variant='primary')
                download_button = gr.DownloadButton("Download CSV", value=self.demo_results['Segments'][1], visible=True, interactive=True, elem_id='csv-button')

            # Hidden seconds table used to seek when a results row is clicked
            ms_df = gr.DataFrame(value=self.demo_results['Segments'][0][['start_ms', 'end_ms']], visible=False)
            gr.Markdown('<div style="text-align: center;">Ready to dive in? Just click on the text to jump to the part you need!</div>')
            # Key that triggers get_timestamps when it changes (see get_processed_audio)
            user_inputs = gr.Textbox(value=None, visible=False)
            displayed_results_source = gr.Textbox(value="youtube", visible=False)
            timestamps_df = gr.DataFrame(
                value=self.demo_results['Segments'][0].drop(['start_ms', 'end_ms'], axis=1),
                wrap=True, label='Click on the text to jump to that part of the speech.', show_label=False,
                row_count=(1, "dynamic"), col_count=(3, 'fixed'), headers=['start_time', 'end_time', 'text'],
                elem_id="target-table", interactive=False,
            )

            model_dropdown.change(
                fn=self.change_based_on_model,
                inputs=[model_dropdown],
                outputs=[model_desc, model_more_info, note_on_max_duration, max_audio_length, file_input, mic_input],
            ).then(
                fn=self.reset_demo_materials,
                inputs=[selected_tab, model_dropdown, yt_link, yt_render, timestamps_df, download_button, session_hash],
                outputs=[yt_link, yt_render, timestamps_df, download_button],
            )
            timestamps_df.select(self.on_row_click, inputs=[ms_df, displayed_results_source, yt_render, file_input, mic_input], outputs=[yt_render, file_selected_segment, mic_selected_segment])
            timestamps_df.change(self.reset_yt_logger, inputs=[timestamps_df, model_dropdown, session_hash])
            timestamps_button.click(
                self.get_processed_audio,
                inputs=[model_dropdown, timestamp_type, selected_tab, yt_link, file_input, mic_input, yt_render, session_hash, max_audio_length],
                outputs=[user_inputs, yt_render, displayed_results_source],
                concurrency_limit=8,
            )
            user_inputs.change(
                self.get_timestamps,
                inputs=[model_dropdown, user_inputs, timestamp_type, session_hash],
                outputs=[timestamps_df, ms_df, download_button],
                concurrency_limit=2,
            )
            self.demo.unload(self.cleanup)

    def launch(self):
        """Enable the request queue and launch the app (public share link, debug mode)."""
        self.demo.queue(True)
        self.demo.launch(share=True, debug=True)


nemo_app = NeMoGradioApp()
demo = nemo_app.demo

if __name__ == '__main__':
    nemo_app.launch()