# j-tobias
# added search bar + plot
# e4732e1
import gradio as gr
import pandas as pd
import json
from constants import BANNER, INTRODUCTION_TEXT, CITATION_TEXT, METRICS_TAB_TEXT, DIR_OUTPUT_REQUESTS
from init import is_model_on_hub, upload_file, load_all_info_from_dataset_hub
from utils_display import AutoEvalColumn, fields, make_clickable_model, styled_error, styled_message
from datetime import datetime, timezone
from utils_display import make_best_bold
import plotly.graph_objects as go
# Date string interpolated into the "Last updated on" footer of the page.
LAST_UPDATED = "Sep 11th 2024"

# Datasets whose "<name> WER" column is displayed under the bare dataset name.
_WER_DATASETS = (
    "AMI",
    "Earnings22",
    "Gigaspeech",
    "LS Clean",
    "LS Other",
    "SPGISpeech",
)

# Mapping passed to DataFrame.rename(): result-file header -> displayed header.
column_names = {
    "MODEL": "Model",
    "Avg. WER": "Average WER ⬇️ ",
    "Avg. RTFx": "RTFx ⬆️ ",
    **{f"{dataset} WER": dataset for dataset in _WER_DATASETS},
}

# Results table, read once when the module is imported.
original_df = pd.read_csv("data.csv")

# Starts out empty; nothing in this section appends to it.
requested_models = []
def formatter(x):
    """Round a numeric cell to two decimals; return text cells unchanged.

    Args:
        x: A single cell value, either text or a number.

    Returns:
        ``x`` itself when it is a string (including ``str`` subclasses),
        otherwise ``round(x, 2)``.
    """
    # isinstance() rather than an exact type comparison: a str subclass must
    # be passed through, not handed to round(), which would raise TypeError.
    if isinstance(x, str):
        return x
    return round(x, 2)
def format_df(df: pd.DataFrame) -> pd.DataFrame:
    """Apply display formatting to every column of ``df``, in place.

    The ``"model"`` column is mapped element-wise through
    ``make_clickable_model``; every other column is replaced by
    ``make_best_bold(column, column_name)``.

    Args:
        df: Results table to format. It is mutated, not copied.

    Returns:
        The same ``df`` object that was passed in.
    """
    for col in df.columns:
        if col == "model":
            # Previously written as ``x.replace(x, make_clickable_model(x))``,
            # which replaces the whole string with the link, i.e. it is just
            # the link. Call the helper directly.
            df[col] = df[col].apply(make_clickable_model)
        else:
            df[col] = make_best_bold(df[col], col)
    return df
# Rename and sort while the metric columns still hold the values read from the
# CSV, and only then apply display formatting. This is the same order
# update_table() uses (sort, then format_df). Sorting after format_df() would
# order the rows by whatever make_best_bold() returned instead of by the raw
# "Average WER" numbers.
# NOTE(review): assumes make_best_bold() returns display strings - confirm.
original_df.rename(columns=column_names, inplace=True)
original_df.sort_values(by='Average WER ⬇️ ', inplace=True)
original_df = format_df(original_df)

# Names and types of the fields declared on AutoEvalColumn; TYPES is passed as
# the `datatype` of the leaderboard Dataframe component.
COLS = [c.name for c in fields(AutoEvalColumn)]
TYPES = [c.type for c in fields(AutoEvalColumn)]
def request_model(model_text, chbcoco2017):
    """Click handler for the "Request" button (currently a placeholder).

    NOTE(review): the real implementation is missing from this file; only a
    stub is present, so submitting a request has no effect.

    Args:
        model_text: Value of the model-name textbox.
        chbcoco2017: Value of the dataset checkbox.

    Returns:
        None. Nothing is validated, stored or uploaded.
    """
    return None
def update_table(column_selection, search: str):
    """Rebuild the leaderboard table for a category and a model-name filter.

    The results are re-read from ``data.csv``, filtered to rows whose
    ``model`` contains ``search`` (case-insensitive), reduced to the columns
    of the chosen category, sorted by average WER (ascending) and passed
    through ``format_df``.

    For "Narrated", "Oratory" and "Spontaneous" the average WER column is
    recomputed as the row-wise mean of that category's dataset columns,
    rounded to two decimals.

    Args:
        column_selection: One of "All Columns", "Main Metrics", "Narrated",
            "Oratory" or "Spontaneous". Any other value is treated like
            "All Columns".
        search: Text that must occur in the model name. Empty or ``None``
            keeps every model.

    Returns:
        The formatted ``pandas.DataFrame`` to display.
    """
    base_cols = ["model", "Average WER ⬇️ ", "RTFx ⬆️ "]
    # Dataset columns that make up each category's average.
    # NOTE(review): "Tedlium" has no entry in column_names, so the "Oratory"
    # view relies on data.csv already containing a column with exactly that
    # name - confirm.
    category_datasets = {
        "Narrated": ["LS Clean", "LS Other", "Gigaspeech"],
        "Oratory": ["Tedlium", "SPGISpeech", "Earnings22"],
        "Spontaneous": ["Gigaspeech", "SPGISpeech", "Earnings22", "AMI"],
    }

    # Apply the same header mapping as the module-level load, so the display
    # names used below exist whether or not the file already carries them
    # (labels missing from the frame are ignored by rename()).
    results = pd.read_csv("data.csv").rename(columns=column_names)

    # Literal substring match: the text comes straight from the search box,
    # and interpreting it as a regular expression makes input such as "("
    # raise instead of filtering.
    query = search or ""
    matches = results["model"].str.contains(query, case=False, na=False, regex=False)
    results = results[matches]

    # .copy() gives each branch an independent frame, so the assignments
    # below (and inside format_df) do not write into a slice of `results`.
    if column_selection == "Main Metrics":
        table = results[base_cols].copy()
    elif column_selection in category_datasets:
        datasets = category_datasets[column_selection]
        table = results[base_cols + datasets].copy()
        table["Average WER ⬇️ "] = table[datasets].mean(axis=1).round(2)
    else:
        # "All Columns", plus a safe fallback for unexpected values.
        table = results.copy()

    # Sort before formatting, while the column still holds numbers.
    table = table.sort_values(by="Average WER ⬇️ ", ascending=True)
    return format_df(table)
def generate_plot():
    """Build the scatter plot of average WER against RTFx.

    ``data.csv`` is read on every call. Each row becomes one marker whose
    hover label shows the model name and both metrics with two decimals.

    Returns:
        The ``plotly.graph_objects.Figure`` holding the scatter trace.
    """
    results = pd.read_csv("data.csv")

    hover_parts = [
        '<b>%{text}</b><br>',
        'Average WER: %{x:.2f}<br>',
        'RTFx: %{y:.2f}<br>',
        '<extra></extra>',
    ]
    # No per-point `color` is supplied, so the colorscale is only carried along
    # as a marker setting (see the Plotly scatter marker reference).
    marker_style = {'size': 10, 'colorscale': 'Viridis'}

    points = go.Scatter(
        x=results['Average WER ⬇️ '],
        y=results['RTFx ⬆️ '],
        mode='markers',
        text=results['model'],
        hovertemplate=''.join(hover_parts),
        marker=marker_style,
    )

    figure = go.Figure()
    figure.add_trace(points)

    # Titles and hover behaviour.
    figure.update_layout(
        title='ASR Model Performance: Average WER vs RTFx',
        xaxis_title='Average WER (lower is better)',
        yaxis_title='RTFx (higher is better)',
        hovermode='closest',
    )
    return figure
# Page definition: banner and introduction, four tabs (leaderboard, metrics,
# model request, plots), a "last updated" line and a citation box.
# NOTE(review): leading indentation is missing in this section, so the nesting
# of the `with` blocks cannot be read off the text - restore it before running.
with gr.Blocks() as demo:
# Static header content taken from constants.
gr.HTML(BANNER, elem_id="banner")
gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
with gr.Tabs(elem_classes="tab-buttons") as tabs:
# Tab 0: read-only results table, initially showing the pre-formatted original_df.
with gr.TabItem("🏅 Leaderboard", elem_id="od-benchmark-tab-table", id=0):
leaderboard_table = gr.components.Dataframe(
value=original_df,
datatype=TYPES,
elem_id="leaderboard-table",
interactive=False,
visible=True,
height=500,
)
# Controls for the table: category selector and model-name search box.
with gr.Accordion("📌 Select a more detailed subset",open=False):
column_radio = gr.Radio(
["All Columns", "Main Metrics", "Narrated", "Oratory", "Spontaneous"],
label="Categories",
value="All Columns"
)
search_bar = gr.Textbox(label="Search models", placeholder="Enter model name...")
# Both events call update_table(category, search text) and write the result
# into the table. The search box only triggers on submit, not on every edit.
column_radio.change(update_table, inputs=[column_radio, search_bar], outputs=[leaderboard_table])
search_bar.submit(update_table, inputs=[column_radio, search_bar], outputs=[leaderboard_table])
# Tab 1: static description of the metrics.
with gr.TabItem("📈 Metrics", elem_id="od-benchmark-tab-table", id=1):
gr.Markdown(METRICS_TAB_TEXT, elem_classes="markdown-text")
# Tab 2: form for requesting evaluation of a new model.
with gr.TabItem("✉️✨ Request a model here!", elem_id="od-benchmark-tab-table", id=2):
with gr.Column():
gr.Markdown("# ✉️✨ Request results for a new model here!", elem_classes="markdown-text")
with gr.Column():
gr.Markdown("Select a dataset:", elem_classes="markdown-text")
with gr.Column():
model_name_textbox = gr.Textbox(label="Model name (user_name/model_name)")
# Hidden, non-interactive checkbox that is always True; it is still passed
# to request_model as its second argument.
chb_coco2017 = gr.Checkbox(label="COCO validation 2017 dataset", visible=False, value=True, interactive=False)
with gr.Column():
mdw_submission_result = gr.Markdown()
btn_submitt = gr.Button(value="🚀 Request")
# Clicking sends (model name, checkbox) to request_model and shows its
# return value in mdw_submission_result.
btn_submitt.click(request_model,
[model_name_textbox, chb_coco2017],
mdw_submission_result)
# Tab 3: figure produced by generate_plot (passed as a callable, not a value).
with gr.TabItem("📊 Plots", elem_id="od-benchmark-tab-table", id=3):
plot = gr.Plot(generate_plot)
# Footer: last-update date and a copyable BibTeX citation.
gr.Markdown(f"Last updated on **{LAST_UPDATED}**", elem_classes="markdown-text")
with gr.Row():
with gr.Accordion("📙 Citation", open=False):
gr.Textbox(
value=CITATION_TEXT, lines=7,
label="Copy the BibTeX snippet to cite this source",
elem_id="citation-button",
show_copy_button=True,
)
# Start the app; runs at import time because there is no __main__ guard.
demo.launch()