import re
import gc
import torch
import shutil
import numpy as np
import pandas as pd
import gradio as gr
from pathlib import Path
from pydub import AudioSegment
from datetime import timedelta, datetime
from dev_utils import DEV_LOGGER
from yt_utils import Logger, get_audio_from_youtube
from nemo_utils import get_aligned_transcription, process_audio, get_model_cache, get_offsets_cache
from hf_utils import get_models_list, get_model_description, predownload_models
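# Gradio front-end for NVIDIA NeMo speech-to-text with segment/word timestamps.
# Audio can come from a YouTube link, an uploaded file, or the microphone; results
# are shown in a clickable table and can be downloaded as CSV.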
class NeMoGradioApp:
def __init__(self):
self.cache_dir = '/tmp/gradio'
self.loggers = {}
self.caching_funcs = {}
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.default_lang = 'English'
self.default_model_bck = 'parakeet-tdt_ctc-1.1b'
self.default_model = 'msekoyan/parakeet2-0.6b-En'
        self.default_video_id = 'erhqbyvPesY'
self.demo_audio_path = 'demo_audio2_processed.flac'
self.demo_results = self.get_results_for_demo()
self.available_langs_models, self.available_models_langs = get_models_list(sort_alphabetically=True)
predownload_models(list(self.available_models_langs.keys()), top=1)
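        # JS run on every update of the YouTube log box (see yt_logs.change below):
        # log lines containing an https URL become clickable and open the Google
        # device-auth page; lines mentioning a "code" raise an alert pointing there.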
self.cm_js_code = """
function makeCmLinesHyperlinksIfHttps() {
const lines = document.querySelectorAll(".cm-line");
lines.forEach((line, index) => {
// Check if line includes "https" and make it a clickable hyperlink
if (line.textContent.includes("https")) {
line.style.cursor = "pointer";
line.style.color = "blue";
line.addEventListener("click", () => {
                    window.open("https://www.google.com/device", "_blank");
});
}
// Check if line includes the word "code" and raise an alert
if (line.textContent.includes("code")) {
alert(`${line.textContent.trim()} on https://www.google.com/device. See more in the logs section below.`);
}
});
}
"""
self.demo = gr.Blocks(
title="NeMo Speech-to-Text",
head="""
<script>
const new_ce = document.getElementsByTagName("gradio-app");
if (new_ce[0]) {
// Create a MutationObserver to detect changes
const observer = new MutationObserver((mutationsList) => {
mutationsList.forEach((mutation) => {
if (mutation.type === "childList") {
mutation.addedNodes.forEach((node) => {
// Block for role="alert" and class="error"
if (
node.nodeType === 1 &&
node.getAttribute("role") === "alert" &&
node.classList.contains("error")
) {
// console.log("An alert element with 'error' class was added:", node);
const toastTextDiv = node.querySelector('.toast-text');
if (toastTextDiv) {
const textContent = toastTextDiv.textContent.trim();
// console.log("Alert with error class detected. Toast text:", textContent);
if (textContent === "Terminating session due to inactivity...") {
// console.log("Text is 'Terminating session due to inactivity...', waiting 5 seconds before reloading the page...");
setTimeout(() => {
location.reload(); // Reload the page after 5 seconds
}, 5000);
}
} else {
// console.log("Alert with error class detected, but no 'toast-text' div found.");
}
}
// Block for role="alert" and class="info"
if (
node.nodeType === 1 &&
node.getAttribute("role") === "alert" &&
node.classList.contains("info")
) {
// console.log("An alert element with 'info' class was added:", node);
// Find and close any existing "info" alerts (excluding the newly added node)
const existingInfoAlerts = document.querySelectorAll('[role="alert"].info');
existingInfoAlerts.forEach((existingAlert) => {
const closeButton = existingAlert.querySelector('.toast-close');
if (closeButton && existingAlert !== node) { // Exclude the newly added node
// console.log("Closing existing 'info' alert:", existingAlert);
closeButton.click();
}
});
}
});
}
});
});
// Start observing the selected element for changes in its children
observer.observe(new_ce[0], { childList: true, subtree: true });
}
// Example to trigger a message event for testing
window.postMessage("Hello from the Gradio app!", "*");
</script>
""",
css="""
textarea { font-size: 18px;}
#model_output_text_box span {
font-size: 18px;
font-weight: bold;
}
.cm-line {
color: #2463eb;
}
#csv-button {
background: #2463eb !important;
                border-color: #2463eb !important;
color: white !important;
transition: background-color 0.3s !important;
}
""",
            theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)  # make text slightly bigger (default is text_md)
)
self.build_inference()
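    # Pre-compute demo results (Segments and Words) for the bundled demo audio so the
    # default YouTube example can be served instantly without re-running the model;
    # the loaded default model is kept around as self.preloaded_demo_model.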
def get_results_for_demo(self):
get_model_func = get_model_cache()
get_offsets_func = get_offsets_cache()
self.preloaded_demo_model = get_model_func(self.default_model)
demo_results = {}
for timestamp_level in ["Segments", "Words"]:
timestamps = get_aligned_transcription(self.default_model,
self.demo_audio_path,
timestamp_level,
get_model_func,
get_offsets_func,
self.device,
preloaded_model=self.preloaded_demo_model)
timestamps_df = self.get_ts_dataframe(timestamps)
csv_path = self.save_results_to_csv(timestamps_df, self.demo_audio_path, timestamp_level, 'master')
demo_results[timestamp_level] = (timestamps_df, csv_path)
get_offsets_func.cache_clear()
get_model_func.cache_clear()
gc.collect()
return demo_results
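    # Called on demo.load: register a per-session yt-dlp logger and per-session
    # cached model/offset loaders, keyed by the Gradio session hash.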
def get_session_starting(self, request: gr.Request):
session_hash = request.session_hash
yt_logger = Logger(self.get_session_log_file(session_hash))
self.loggers[session_hash] = yt_logger
self.caching_funcs[session_hash] = {
'get_model': get_model_cache(),
'get_offsets': get_offsets_cache()
}
DEV_LOGGER.info(f'STARTING: {session_hash}')
return session_hash
def get_logs(self, request: gr.Request):
if not self.loggers:
return "The download has not started yet!"
logger = self.loggers.get(request.session_hash)
if logger:
return logger.get_logs()
return None
def get_yt_cache_dir(self, session_hash):
yt_cache_dir = Path(self.cache_dir, session_hash, 'youtube')
yt_cache_dir.mkdir(exist_ok=True, parents=True)
return yt_cache_dir
def get_session_log_file(self, session_hash):
session_log_dir = Path(self.cache_dir, session_hash, 'logs')
session_log_dir.mkdir(exist_ok=True, parents=True)
return Path(session_log_dir, 'output.log')
@property
def session_log_file(self):
session_log_dir = Path(self.cache_dir, 'logs')
session_log_dir.mkdir(exist_ok=True, parents=True)
return Path(session_log_dir, 'output.log')
@staticmethod
def sec_to_hrs(seconds):
seconds = round(seconds)
return str(timedelta(seconds=seconds))
@staticmethod
def hrs_to_sec(hrs):
time_obj = datetime.strptime(hrs, "%H:%M:%S")
seconds = time_obj.hour * 3600 + time_obj.minute * 60 + time_obj.second
return seconds
@staticmethod
def get_ts_dataframe(segments):
df = pd.DataFrame(columns=['start_time', 'end_time', 'text', 'start_ms', 'end_ms'])
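        # Note: despite their names, 'start_ms' / 'end_ms' hold offsets in seconds;
        # they are kept as hidden columns and used later for seeking in on_row_click.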
if len(segments) == 0:
            df.loc[0] = 0, 0, '', 0, 0
return df
for segment in segments:
text, start_time, end_time = segment
            if len(text) > 0:
df.loc[len(df)] = (NeMoGradioApp.sec_to_hrs(start_time),
NeMoGradioApp.sec_to_hrs(end_time),
text,
round(start_time, 2),
round(end_time, 2))
return df
@staticmethod
def get_max_duration(model_name):
if model_name == 'parakeet-tdt_ctc-1.1b':
return None
        # Limit input audio duration to 20 mins for fastconformer and parakeet models
elif 'fastconformer' in model_name or 'parakeet' in model_name:
return 20 * 60
# Limit input audio duration to 10 mins for conformer models
elif 'conformer' in model_name:
return 10 * 60
@staticmethod
def get_audio_segment(audio_path, start_second, end_second):
if audio_path is None:
return None
start_second = start_second * 1000
end_second = end_second * 1000
audio = AudioSegment.from_file(audio_path)
clipped_audio = audio[start_second:end_second]
samples = np.array(clipped_audio.get_array_of_samples())
if clipped_audio.channels == 2:
samples = samples.reshape((-1, 2))
samples = samples.mean(axis=1)
return (clipped_audio.frame_rate, samples)
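    # Restore the precomputed demo table, iframe and CSV button when the default model
    # and demo YouTube link are selected; otherwise clear the YouTube widgets and logs.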
def reset_demo_materials(self, selected_tab, selected_model, yt_link, yt_render, results_df, download_button, session_hash):
if selected_model == self.default_model and yt_link == f'https://www.youtube.com/watch?v={self.default_video_id}' and selected_tab == "youtube":
results_df = self.demo_results["Segments"][0].drop(['start_ms', 'end_ms'], axis=1)
yt_render = gr.HTML(f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{self.default_video_id}" allow="autoplay; encrypted-media"> </iframe>')
download_button = gr.DownloadButton("Download CSV", value=self.demo_results["Segments"][1], visible=True, interactive=True, elem_id='csv-button')
else:
self.loggers[session_hash].reset_logs()
yt_link = ""
yt_render = gr.HTML(f'<center> <iframe width="500" height="50" src="" allow="autoplay; encrypted-media"> </iframe>')
return yt_link, yt_render, results_df, download_button
def reset_yt_logger(self, results_df, selected_model, session_hash):
if selected_model != self.default_model:
return
demo_segments_df = self.demo_results['Segments'][0].drop(['start_ms', 'end_ms'], axis=1)
demo_words_df = self.demo_results['Words'][0].drop(['start_ms', 'end_ms'], axis=1)
if not results_df.empty and not results_df.equals(demo_segments_df) and not results_df.equals(demo_words_df):
self.loggers[session_hash].remind_about_demo()
def change_based_on_model(self, selected_model):
desc, more_info = self.show_model_description(selected_model)
max_duration = self.get_max_duration(selected_model)
if max_duration:
note_on_max_duration = gr.Markdown(f'<span style="color: orange; font-weight: bold;"> ⚠️ NOTE: In the current setup, you can transcribe **up to {max_duration // 60} minutes** of speech. ⚠️ </span>', visible=True)
else:
note_on_max_duration = gr.Markdown(value=None, visible=False)
file_input = gr.Audio(sources='upload', label='Upload Audio', type='filepath', max_length=max_duration)
mic_input = gr.Audio(sources='microphone', label='Record Audio', type='filepath', max_length=max_duration)
return desc, more_info, note_on_max_duration, gr.Number(max_duration, visible=False), file_input, mic_input
def show_available_models(self, lang=None):
if lang:
models_list = self.available_langs_models[lang]
else:
models_list = self.available_models_langs.keys()
return models_list
def update_dropdown(self, selected_value):
lang_models = self.show_available_models(selected_value)
return gr.Dropdown(choices=lang_models, value=lang_models[0], interactive=True)
def show_model_description(self, selected_model):
if selected_model:
description, more_info = get_model_description(selected_model)
return gr.Markdown(value=description, visible=True, label='Model Description'), gr.Markdown(value=more_info, visible=True, label='Model Description')
return gr.Markdown(visible=False), gr.Markdown(visible=False)
def save_results_to_csv(self, df, audio_path, timestamp_type, session_hash):
audio_name = Path(audio_path).stem.replace('_processed', '')
csv_name = f'{audio_name}_{timestamp_type.lower()}_timestamps'
csv_dir = Path(self.cache_dir, session_hash)
csv_dir.mkdir(exist_ok=True, parents=True)
csv_path = Path(csv_dir, csv_name).with_suffix('.csv').as_posix()
df.to_csv(csv_path, index=False)
return csv_path # Return the path for download
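    # Row click in the results table: for YouTube, re-embed the iframe with
    # ?start=..&end=..&autoplay=1; for file/mic input, cut the matching clip with
    # pydub and surface it in a previously hidden Audio player.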
def on_row_click(self, evt: gr.SelectData, df, selected_src, html, file_path, mic_path):
selected_row = df.iloc[evt.index[0]]
start_seconds = selected_row['start_ms']
end_seconds = selected_row['end_ms']
file_selected_segment = gr.Audio(label='Selected Segment', visible=False)
mic_selected_segment = gr.Audio(label='Selected Segment', visible=False)
DEV_LOGGER.info(f"Selected Source {selected_src}")
if selected_src == 'youtube':
start_seconds = round(start_seconds)
end_seconds = round(end_seconds)
DEV_LOGGER.info(f"Start: {start_seconds} | End: {end_seconds}")
if start_seconds == end_seconds:
end_seconds = start_seconds + 1
match = re.search(r'src="([^"?]+)', html)
src_url = match.group(1) if match else None
html = gr.HTML(f'<center> <iframe width="500" height="320" src="{src_url}?start={start_seconds}&end={end_seconds}&autoplay=1" allow="autoplay; encrypted-media"> </iframe>')
elif selected_src == 'file' and file_path:
segment = self.get_audio_segment(file_path, start_seconds, end_seconds)
file_selected_segment = gr.Audio(segment, autoplay=True, label=f"{start_seconds}-{end_seconds} Second Segment", visible=True)
elif selected_src == 'mic' and mic_path:
segment = self.get_audio_segment(mic_path, start_seconds, end_seconds)
mic_selected_segment = gr.Audio(segment, autoplay=True, label=f"{start_seconds}-{end_seconds} Second Segment", visible=True)
return html, file_selected_segment, mic_selected_segment
def cleanup_yt_cache(self, yt_link, selected_model, session_hash):
if yt_link == f"https://www.youtube.com/watch?v={self.default_video_id}" and selected_model == self.default_model:
self.loggers[session_hash].reset_logs_for_demo()
else:
self.loggers[session_hash].reset_logs()
for file_path in self.get_yt_cache_dir(session_hash).glob('*'):
if file_path.is_file():
file_path.unlink()
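    # Called on demo.unload: drop the session's cache directory, logger and cached
    # model/offset loaders, then force a GC pass to release memory.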
def cleanup(self, request: gr.Request):
DEV_LOGGER.info(f'DELETING EVERYTHING FOR SESSION: {request.session_hash}')
session_cache_dir = Path(self.cache_dir, request.session_hash)
if session_cache_dir.exists():
shutil.rmtree(session_cache_dir)
if request.session_hash in self.loggers:
del self.loggers[request.session_hash]
if request.session_hash in self.caching_funcs:
DEV_LOGGER.info(f'DELETING SESSION CACHE FOR: {request.session_hash}')
self.caching_funcs[request.session_hash]['get_offsets'].cache_clear()
self.caching_funcs[request.session_hash]['get_model'].cache_clear()
del self.caching_funcs[request.session_hash]
gc.collect()
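    # Resolve the audio for the selected tab (demo shortcut, cached or fresh YouTube
    # download, uploaded file, or mic recording), convert it with process_audio if
    # needed, and pack 'path%timestamp_type%model_name' into a hidden textbox whose
    # change event triggers get_timestamps.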
def get_processed_audio(self, model_name, timestamp_type, channel_to_use, url, file_path, microphone, html, session_hash, max_audio_length):
if not model_name:
            raise gr.Error('Please select a model to transcribe with!')
processed_path = None
if channel_to_use == 'youtube' and url == f"https://www.youtube.com/watch?v={self.default_video_id}":
audio_path = self.demo_audio_path
processed_path = self.demo_audio_path
elif channel_to_use == 'youtube':
yt_video_id = url.split('v=')[-1]
yt_cache_dir = self.get_yt_cache_dir(session_hash)
possible_files = list(yt_cache_dir.glob(f'{yt_video_id}_*_processed.flac'))
if possible_files:
processed_path = possible_files[0].as_posix()
audio_path = processed_path
else:
gr.Info("Downloading and processing audio from Youtube", duration=None)
audio_path, html = get_audio_from_youtube(url,
yt_cache_dir,
self.loggers[session_hash],
max_audio_length)
elif channel_to_use == 'file':
audio_path = file_path
else:
audio_path = microphone
DEV_LOGGER.info(f'SESSION ID: {session_hash} | USING CHANNEL: {channel_to_use}')
DEV_LOGGER.info(f'SESSION ID: {session_hash} | USING PATH: {audio_path}')
if not processed_path:
processed_path = process_audio(audio_path)
return "%".join([processed_path, timestamp_type, model_name]), html, channel_to_use
def get_timestamps(self, model_name, processed_path, timestamp_type, session_hash):
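        # `processed_path` arrives as 'path%timestamp_type%model_name' from
        # get_processed_audio; strip the suffix to recover the audio path.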
processed_path = "%".join(processed_path.split('%')[:-2])
        if model_name == self.default_model and processed_path == self.demo_audio_path:
results_df = self.demo_results[timestamp_type][0]
csv_path = self.demo_results[timestamp_type][1]
gr.Info("Results are ready!", duration=2)
return (results_df.drop(['start_ms', 'end_ms'], axis=1),
results_df[['start_ms', 'end_ms']],
gr.DownloadButton(value=csv_path, visible=True, interactive=True)
)
gr.Info("Running NeMo Model", duration=None)
preloaded_model = self.preloaded_demo_model if model_name == self.default_model else None
timestamps = get_aligned_transcription(model_name,
processed_path,
timestamp_type,
self.caching_funcs[session_hash]['get_model'],
self.caching_funcs[session_hash]['get_offsets'],
self.device,
preloaded_model=preloaded_model)
df = self.get_ts_dataframe(timestamps)
csv_path = self.save_results_to_csv(df, processed_path, timestamp_type, session_hash)
gr.Info("Results are ready!", duration=2)
return (df.drop(['start_ms', 'end_ms'], axis=1),
df[['start_ms', 'end_ms']],
gr.DownloadButton(value=csv_path, visible=True, interactive=True))
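    # Assemble the UI: language/model selectors, three input tabs (YouTube / file /
    # microphone), timestamp granularity, the clickable results table and CSV download,
    # plus the event wiring between them.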
def build_inference(self):
with self.demo:
gr.HTML("<h1 style='text-align: center'>Transcription with Timestamps using NeMo STT Models &#129303</h1>")
gr.Markdown(f"""
<div style="text-align: center; font-size: 1.2em; font-weight: bold; padding: 10px; margin: 10px 0; border-top: 0px solid #ccc; border-bottom: 1px solid #ccc;">
            Transcribe speech in <span style="color:orange">{(len(self.available_langs_models) // 5) * 5}+ languages!</span>
</div>
""")
# gr.Button("Show Client Host").click(lambda client_host: client_host, inputs=client_host, outputs=output)
session_hash = gr.Textbox(visible=False)
max_audio_length = gr.Number(visible=False)
self.demo.load(self.get_session_starting, outputs=session_hash)
# User selection section
with gr.Row():
lang_dropdown = gr.Dropdown(choices=list(self.available_langs_models.keys()), value=self.default_lang, label="Select a Language", interactive=False)
model_dropdown = gr.Dropdown(choices=self.show_available_models(self.default_lang), value=self.default_model, label="Select a Model", interactive=False)
model_desc = gr.Markdown(visible=True, value=get_model_description(self.default_model_bck)[0])
model_more_info = gr.Markdown(visible=True, value=get_model_description(self.default_model_bck)[1])
# note_on_max_duration = gr.Markdown(visible=False)
note_on_max_duration = gr.Markdown(f'<span style="color: orange; font-weight: bold;"> ⚠️ NOTE: In the current setup, you can transcribe **up to 20 minutes** of speech. ⚠️ </span>', visible=True)
lang_dropdown.select(
fn=self.update_dropdown,
inputs=[lang_dropdown],
outputs=[model_dropdown]
)
gr.Markdown('<span style="color: red; font-weight: bold;"> ⚠️ This experimental space is for showcasing the new Parakeet2 model. That is why most of the features are not available. ⚠️</span>')
selected_tab = gr.State('youtube')
#Youtube Block
            with gr.Tab('Audio from YouTube') as yt_tab:
gr.Markdown('<span style="color: red; font-weight: bold;"> ⚠️ You may be required to authenticate on [https://www.google.com/device](https://www.google.com/device) using the code provided in the logs to download a video from YouTube. ⚠️</span>')
yt_logs = gr.Code(value=None, language='markdown', lines=2, label='YouTube Logs')
with gr.Row():
                    yt_link = gr.Textbox(value=f'https://www.youtube.com/watch?v={self.default_video_id}', label='Enter YouTube Link', type='text')
yt_link.change(self.cleanup_yt_cache, inputs=[yt_link, model_dropdown, session_hash])
yt_render = gr.HTML(f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{self.default_video_id}" allow="autoplay; encrypted-media"> </iframe>')
yt_tab.select(lambda: 'youtube', outputs=selected_tab)
yt_logs.change(fn=None,
inputs=None,
outputs=None,
js=self.cm_js_code)
timer = gr.Timer(value=1)
timer.tick(self.get_logs, outputs=yt_logs)
#File Block
with gr.Tab('Audio from File') as file_tab:
file_input = gr.Audio(sources='upload', label='Upload Audio', type='filepath', max_length=1200)
file_selected_segment = gr.Audio(label='Selected Segment', visible=False)
file_input.change(lambda: gr.Audio(label='Selected Segment', visible=False), outputs=file_selected_segment)
file_tab.select(lambda: 'file', outputs=selected_tab)
#Mic Block
with gr.Tab('Audio from Microphone') as mic_tab:
mic_input = gr.Audio(sources='microphone', label='Record Audio', type='filepath', max_length=1200)
mic_selected_segment = gr.Audio(label='Selected Segment', visible=False)
mic_input.change(lambda: gr.Audio(label='Selected Segment', visible=False), outputs=mic_selected_segment)
mic_tab.select(lambda: 'mic', outputs=selected_tab)
with gr.Row():
timestamp_type = gr.Radio(["Segments", "Words"],
value='Segments',
label='Select timestamps granularity',
show_label=True)
gr.Markdown('Currently segments are formed based on the following punctuation marks: `. ? !`. \nIf the selected model does not support these punctuation marks, the segments will be formed based on silence duration between words.',
line_breaks=True)
with gr.Row():
timestamps_button = gr.Button("Get timestamps with text", variant='primary')
download_button = gr.DownloadButton("Download CSV", value=self.demo_results['Segments'][1], visible=True, interactive=True, elem_id='csv-button')
ms_df = gr.DataFrame(value=self.demo_results['Segments'][0][['start_ms', 'end_ms']], visible=False)
click_message = gr.Markdown(f"""
<div style="text-align: center; font-size: 1.2em; font-weight: bold; padding: 10px; margin: 10px 0; border-top: 0px solid #ccc; border-bottom: 1px solid #ccc;">
Ready to dive in? Just <span style="color:orange">click</span> on the text to jump to the part you need!
</div>
""")
user_inputs = gr.Textbox(value=None, visible=False)
displayed_results_source = gr.Textbox(value="youtube", visible=False)
timestamps_df = gr.DataFrame(value=self.demo_results['Segments'][0].drop(['start_ms', 'end_ms'], axis=1),
wrap=True,
label='Click on the text to jump to that part of the speech.',
show_label=False,
row_count=(1, "dynamic"),
col_count=(3, 'fixed'),
headers=['start_time', 'end_time', 'text'],
elem_id="target-table",
interactive=False)
model_dropdown.change(
fn=self.change_based_on_model,
inputs=[model_dropdown],
outputs=[model_desc, model_more_info, note_on_max_duration, max_audio_length, file_input, mic_input]
).then(fn=self.reset_demo_materials,
inputs=[selected_tab, model_dropdown, yt_link, yt_render, timestamps_df, download_button, session_hash],
outputs=[yt_link, yt_render, timestamps_df, download_button])
timestamps_df.select(self.on_row_click,
inputs=[ms_df, displayed_results_source, yt_render, file_input, mic_input],
outputs=[yt_render, file_selected_segment, mic_selected_segment])
timestamps_df.change(self.reset_yt_logger,
inputs=[timestamps_df, model_dropdown, session_hash])
timestamps_button.click(self.get_processed_audio,
inputs=[model_dropdown, timestamp_type, selected_tab, yt_link, file_input, mic_input, yt_render, session_hash, max_audio_length],
outputs=[user_inputs, yt_render, displayed_results_source],
concurrency_limit=8)
user_inputs.change(self.get_timestamps,
inputs=[model_dropdown, user_inputs, timestamp_type, session_hash],
outputs=[timestamps_df, ms_df, download_button],
concurrency_limit=2)
self.demo.unload(self.cleanup)
def launch(self):
self.demo.queue(True)
self.demo.launch(share=True, debug=True)
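# `demo` is exposed at module level so a hosting runtime (e.g. Hugging Face Spaces)
# can pick it up; launch() is only needed when running this file directly.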
nemo_app = NeMoGradioApp()
demo = nemo_app.demo
if __name__ == '__main__':
nemo_app.launch()