Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import json | |
from constants import BANNER, INTRODUCTION_TEXT, CITATION_TEXT, METRICS_TAB_TEXT, DIR_OUTPUT_REQUESTS | |
from init import is_model_on_hub, upload_file, load_all_info_from_dataset_hub | |
from utils_display import AutoEvalColumn, fields, make_clickable_model, styled_error, styled_message | |
from datetime import datetime, timezone | |
from utils_display import make_best_bold | |
import plotly.graph_objects as go | |
# Date rendered in the app footer ("Last updated on **...**").
LAST_UPDATED = "Sep 11th 2024"

# Display names passed to DataFrame.rename for the leaderboard columns.
# Each per-dataset "<dataset> WER" column is shown under the bare dataset name.
column_names = {
    "MODEL": "Model",
    "Avg. WER": "Average WER ⬇️ ",
    "Avg. RTFx": "RTFx ⬆️ ",
    **{
        f"{dataset} WER": dataset
        for dataset in (
            "AMI",
            "Earnings22",
            "Gigaspeech",
            "LS Clean",
            "LS Other",
            "SPGISpeech",
        )
    },
}
# Raw leaderboard results; prepared further down and used as the initial
# value of the leaderboard table.
original_df = pd.read_csv(filepath_or_buffer="data.csv")
# NOTE(review): not read or updated by any code visible in this module —
# presumably meant to remember requests already submitted; confirm.
requested_models: list = list()
# Formats a single cell value for display.
def formatter(x):
    """Round a numeric cell to two decimals; pass text through unchanged.

    NOTE(review): not called anywhere in the visible code.

    Args:
        x: A cell value: a string, ``None``, or a number accepted by ``round``.

    Returns:
        ``x`` itself for strings (including ``str`` subclasses) and ``None``;
        otherwise ``round(x, 2)``.
    """
    # isinstance rather than ``type(x) is str`` so str subclasses are not
    # handed to round(), which raises TypeError for them.
    if x is None or isinstance(x, str):
        return x
    return round(x, 2)
def format_df(df: pd.DataFrame):
    """Format a leaderboard DataFrame for display, modifying it in place.

    The ``model`` column is passed through ``make_clickable_model``
    (presumably turning model ids into links — confirm in utils_display).
    Every other column is replaced by ``make_best_bold(column, column_name)``.

    Args:
        df: Leaderboard rows; expected to contain a ``model`` column.

    Returns:
        The same ``df`` object, after its columns have been reassigned.
    """
    for col in df.columns:
        if col == "model":
            # Only real strings are converted. Missing cells (e.g. NaN from a
            # blank CSV field) are kept as-is instead of raising AttributeError
            # as the previous ``x.replace(x, ...)`` did.
            df[col] = df[col].apply(
                lambda name: make_clickable_model(name) if isinstance(name, str) else name
            )
        else:
            df[col] = make_best_bold(df[col], col)
    return df
# Sort on the raw numeric WER values *before* formatting. format_df passes
# every metric column through make_best_bold, which presumably replaces the
# values with display markup (TODO confirm). Sorting afterwards could then
# compare strings instead of numbers. update_table also sorts before
# formatting, and reads this column straight from data.csv, so it exists
# before the rename.
original_df = original_df.sort_values(by="Average WER ⬇️ ")
original_df = format_df(original_df)
original_df.rename(columns=column_names, inplace=True)
# Column names and Gradio datatypes declared by AutoEvalColumn.
COLS = [c.name for c in fields(AutoEvalColumn)]
TYPES = [c.type for c in fields(AutoEvalColumn)]
def request_model(model_text, chbcoco2017):
    """Click handler for the "🚀 Request" button.

    The submission logic is currently a placeholder (its body had been
    replaced by ``pass``). Returning ``None`` left the result area blank and
    gave the user no feedback, so the handler now reports explicitly that the
    request was not processed.

    Args:
        model_text: Contents of the "Model name (user_name/model_name)" textbox.
        chbcoco2017: Value of the hidden, pre-checked, non-interactive dataset
            checkbox.

    Returns:
        Markup for the ``mdw_submission_result`` Markdown component
        (presumably HTML produced by ``styled_error`` — confirm).
    """
    # TODO: restore the real submission logic. The module still imports
    # is_model_on_hub, upload_file, DIR_OUTPUT_REQUESTS, json and datetime,
    # which nothing else here uses and which presumably belonged to it.
    return styled_error("Model requests are temporarily unavailable. Please try again later.")
# Dataset columns averaged into "Average WER ⬇️ " for each category offered by
# the "Categories" radio. "All Columns" and "Main Metrics" keep the average as
# read from data.csv.
_CATEGORY_DATASETS = {
    "Narrated": ["LS Clean", "LS Other", "Gigaspeech"],
    "Oratory": ["Tedlium", "SPGISpeech", "Earnings22"],
    "Spontaneous": ["Gigaspeech", "SPGISpeech", "Earnings22", "AMI"],
}


def update_table(column_selection, search: str):
    """Rebuild the leaderboard table for a category and a model-name filter.

    Args:
        column_selection: One of "All Columns", "Main Metrics", or a key of
            ``_CATEGORY_DATASETS``. Any other value falls back to all columns.
        search: Case-insensitive substring to match against the ``model``
            column. It is taken literally, not as a regex. Empty or ``None``
            keeps every row with a model name.

    Returns:
        A formatted DataFrame (see ``format_df``) sorted by ascending
        "Average WER ⬇️ ". For a category, that column is recomputed as the
        row mean of the category's dataset columns, rounded to 2 decimals.
    """
    df = pd.read_csv("data.csv")
    # regex=False: user text such as "(" or "+" must not raise re.error.
    mask = df["model"].str.contains(search or "", case=False, na=False, regex=False)
    df = df[mask]

    main_cols = ["model", "Average WER ⬇️ ", "RTFx ⬆️ "]
    if column_selection == "Main Metrics":
        new_df = df[main_cols].copy()
    elif column_selection in _CATEGORY_DATASETS:
        datasets = _CATEGORY_DATASETS[column_selection]
        # .copy() so the assignment below modifies an independent frame rather
        # than a slice (avoids pandas' SettingWithCopyWarning).
        new_df = df[main_cols + datasets].copy()
        new_df["Average WER ⬇️ "] = new_df[datasets].mean(axis=1).round(2)
    else:
        # "All Columns", and any unrecognised value (previously this left
        # new_df unbound and raised UnboundLocalError).
        new_df = df

    new_df = new_df.sort_values(by="Average WER ⬇️ ", ascending=True)
    return format_df(new_df)
def generate_plot():
    """Build the scatter plot of average WER against RTFx for every model.

    Reads data.csv afresh on each call and returns a plotly Figure with one
    marker per row, labelled with the model name on hover.
    """
    results = pd.read_csv("data.csv")

    hover_template = (
        "<b>%{text}</b><br>"
        "Average WER: %{x:.2f}<br>"
        "RTFx: %{y:.2f}<br>"
        "<extra></extra>"
    )
    scatter = go.Scatter(
        x=results["Average WER ⬇️ "],
        y=results["RTFx ⬆️ "],
        text=results["model"],
        mode="markers",
        hovertemplate=hover_template,
        # colorscale has no visible effect while no marker color is mapped.
        marker={"size": 10, "colorscale": "Viridis"},
    )

    figure = go.Figure(data=[scatter])
    figure.update_layout(
        title="ASR Model Performance: Average WER vs RTFx",
        xaxis_title="Average WER (lower is better)",
        yaxis_title="RTFx (higher is better)",
        hovermode="closest",
    )
    return figure
# Gradio interface: banner and introduction, four tabs (leaderboard, metrics
# description, model requests, plot), a "last updated" footer and a citation box.
with gr.Blocks() as demo:
    gr.HTML(BANNER, elem_id="banner")
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
    # NOTE(review): all four TabItems share elem_id "od-benchmark-tab-table",
    # which produces duplicate DOM ids; presumably only a CSS hook — confirm.
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard", elem_id="od-benchmark-tab-table", id=0):
            # Initial contents: the module-level original_df prepared at import time.
            # NOTE(review): datatype=TYPES comes from AutoEvalColumn, while
            # update_table can return fewer columns (e.g. "Main Metrics") —
            # confirm Gradio tolerates the length mismatch.
            leaderboard_table = gr.components.Dataframe(
                value=original_df,
                datatype=TYPES,
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
                height=500,
            )
            with gr.Accordion("📌 Select a more detailed subset",open=False):
                column_radio = gr.Radio(
                    ["All Columns", "Main Metrics", "Narrated", "Oratory", "Spontaneous"],
                    label="Categories",
                    value="All Columns"
                )
                search_bar = gr.Textbox(label="Search models", placeholder="Enter model name...")
            # Both controls call update_table with the current category and
            # search text, and replace the table with its result. The search
            # text is applied on submit (Enter), not on every keystroke.
            column_radio.change(update_table, inputs=[column_radio, search_bar], outputs=[leaderboard_table])
            search_bar.submit(update_table, inputs=[column_radio, search_bar], outputs=[leaderboard_table])
        with gr.TabItem("📈 Metrics", elem_id="od-benchmark-tab-table", id=1):
            gr.Markdown(METRICS_TAB_TEXT, elem_classes="markdown-text")
        with gr.TabItem("✉️✨ Request a model here!", elem_id="od-benchmark-tab-table", id=2):
            with gr.Column():
                gr.Markdown("# ✉️✨ Request results for a new model here!", elem_classes="markdown-text")
            with gr.Column():
                gr.Markdown("Select a dataset:", elem_classes="markdown-text")
            with gr.Column():
                model_name_textbox = gr.Textbox(label="Model name (user_name/model_name)")
                # Hidden, pre-checked and non-interactive, so request_model
                # always receives True here.
                # NOTE(review): the "Select a dataset:" heading has no visible
                # control, and the COCO label looks like template leftover on
                # an ASR leaderboard — confirm.
                chb_coco2017 = gr.Checkbox(label="COCO validation 2017 dataset", visible=False, value=True, interactive=False)
            with gr.Column():
                mdw_submission_result = gr.Markdown()
                btn_submitt = gr.Button(value="🚀 Request")
                # request_model's return value is rendered in mdw_submission_result.
                btn_submitt.click(request_model,
                                  [model_name_textbox, chb_coco2017],
                                  mdw_submission_result)
        with gr.TabItem("📊 Plots", elem_id="od-benchmark-tab-table", id=3):
            # generate_plot is passed uncalled; Gradio evaluates a callable
            # value to produce the figure (presumably on each page load — confirm).
            plot = gr.Plot(generate_plot)
    gr.Markdown(f"Last updated on **{LAST_UPDATED}**", elem_classes="markdown-text")
    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Textbox(
                value=CITATION_TEXT, lines=7,
                label="Copy the BibTeX snippet to cite this source",
                elem_id="citation-button",
                show_copy_button=True,
            )
# NOTE(review): launches at import time; there is no `if __name__ == "__main__":` guard.
demo.launch()