Spaces:
Sleeping
Sleeping
File size: 6,696 Bytes
946726e e4732e1 946726e e4732e1 946726e e4732e1 946726e e4732e1 946726e 2ff44a9 e4732e1 2ff44a9 e4732e1 946726e e4732e1 946726e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import gradio as gr
import pandas as pd
import json
from constants import BANNER, INTRODUCTION_TEXT, CITATION_TEXT, METRICS_TAB_TEXT, DIR_OUTPUT_REQUESTS
from init import is_model_on_hub, upload_file, load_all_info_from_dataset_hub
from utils_display import AutoEvalColumn, fields, make_clickable_model, styled_error, styled_message
from datetime import datetime, timezone
from utils_display import make_best_bold
import plotly.graph_objects as go
# Date string interpolated into the "Last updated on" footer of the page.
LAST_UPDATED = "Sep 11th 2024"

# Not referenced by the code visible in this module.
# NOTE(review): presumably meant to track already-submitted requests — confirm.
requested_models = []

# Maps column headers to the labels displayed in the leaderboard; used with
# DataFrame.rename, which leaves headers that are not listed here untouched.
column_names = {
    "MODEL": "Model",
    "Avg. WER": "Average WER ⬇️ ",
    "Avg. RTFx": "RTFx ⬆️ ",
    "AMI WER": "AMI",
    "Earnings22 WER": "Earnings22",
    "Gigaspeech WER": "Gigaspeech",
    "LS Clean WER": "LS Clean",
    "LS Other WER": "LS Other",
    "SPGISpeech WER": "SPGISpeech",
}

# Results table, loaded once at import time from the working directory.
original_df = pd.read_csv("data.csv")
# Formats the columns
def formatter(x):
    """Return *x* rounded to two decimals, leaving strings untouched.

    Args:
        x: A cell value: either text or a number accepted by ``round``.

    Returns:
        ``x`` itself when it is a string (``str`` subclasses included),
        otherwise ``round(x, 2)``.
    """
    # isinstance instead of an exact type comparison: str subclasses take the
    # pass-through path rather than failing inside round().
    if isinstance(x, str):
        return x
    return round(x, 2)
def format_df(df: pd.DataFrame):
    """Apply display formatting to every column of *df*, in place.

    The ``"model"`` column is passed value by value through
    ``make_clickable_model``; every other column is replaced by the result of
    ``make_best_bold(column, column_name)``.

    Args:
        df: Table to format. It is modified in place.

    Returns:
        The same ``df`` object, so callers can write ``df = format_df(df)``.
    """
    for col in df.columns:
        if col == "model":
            # Previously written as x.replace(x, make_clickable_model(x)),
            # which for a string is simply the helper's result; call the
            # helper directly instead of going through str.replace.
            df[col] = df[col].apply(make_clickable_model)
        else:
            df[col] = make_best_bold(df[col], col)
    return df
# Table shown on first page load: format the cells, switch to the display
# headers, then order the rows by average WER (ascending, the pandas default).
original_df = (
    format_df(original_df)
    .rename(columns=column_names)
    .sort_values(by='Average WER ⬇️ ')
)

# Names and types of the fields declared on AutoEvalColumn.
COLS = [column.name for column in fields(AutoEvalColumn)]
TYPES = [column.type for column in fields(AutoEvalColumn)]
def request_model(model_text, chbcoco2017):
    """Placeholder handler for a model request; currently does nothing.

    Args:
        model_text: Unused by the current body.
        chbcoco2017: Unused by the current body.

    Returns:
        None.
    """
    # NOTE(review): the comment that stood here read "... (keep the existing
    # request_model function as is)", which suggests the real body was elided
    # from this file — TODO restore it.
    return None
def update_table(column_selection, search: str):
    """Rebuild the leaderboard table for a category and a model-name filter.

    ``data.csv`` is re-read on every call. Rows are kept when their ``model``
    value contains *search* (case-insensitive, plain substring). For the
    dataset categories the average-WER column is recomputed as the row-wise
    mean of that category's dataset columns, rounded to two decimals.

    Args:
        column_selection: ``"All Columns"``, ``"Main Metrics"``,
            ``"Narrated"``, ``"Oratory"`` or ``"Spontaneous"``. Any other
            value (including ``None``) is treated like ``"All Columns"``.
        search: Text to look for in the ``model`` column; empty or ``None``
            keeps every row.

    Returns:
        A new DataFrame sorted by average WER in ascending order and passed
        through ``format_df``.
    """
    avg_col = "Average WER ⬇️ "
    main_cols = ["model", avg_col, "RTFx ⬆️ "]
    # Dataset columns that make up each category's average.
    dataset_groups = {
        "Narrated": ["LS Clean", "LS Other", "Gigaspeech"],
        "Oratory": ["Tedlium", "SPGISpeech", "Earnings22"],
        "Spontaneous": ["Gigaspeech", "SPGISpeech", "Earnings22", "AMI"],
    }

    # Apply the same header mapping the module uses for the initial table so
    # the display-name lookups below resolve. rename() ignores mapping keys
    # that are absent, so this is a no-op when the CSV already uses the
    # display headers.
    table = pd.read_csv("data.csv").rename(columns=column_names)

    # regex=False: the search box holds a model name, not a pattern. Without
    # it, input such as "(" or "+" raises re.error inside str.contains.
    matches = table["model"].str.contains(
        search or "", case=False, na=False, regex=False
    )
    table = table[matches]

    # .copy() in every branch: the frame is mutated below (and by format_df),
    # which must not happen on a view of `table`.
    if column_selection == "Main Metrics":
        new_df = table[main_cols].copy()
    elif column_selection in dataset_groups:
        datasets = dataset_groups[column_selection]
        new_df = table[main_cols + datasets].copy()
        new_df[avg_col] = new_df[datasets].mean(axis=1).round(2)
    else:
        # "All Columns", plus a fallback for unexpected selections, which
        # previously left the result variable unbound.
        new_df = table.copy()

    new_df = new_df.sort_values(by=avg_col, ascending=True)
    return format_df(new_df)
def generate_plot():
    """Build a scatter plot of average WER (x) against RTFx (y).

    ``data.csv`` is re-read on every call; each row becomes one marker whose
    hover label shows the ``model`` value and both metrics to two decimals.

    Returns:
        The ``plotly.graph_objects.Figure`` holding the scatter trace.
    """
    # Apply the module's header mapping so the display-name lookups below
    # resolve. rename() ignores mapping keys that are absent, so this is a
    # no-op when the CSV already uses the display headers.
    df = pd.read_csv("data.csv").rename(columns=column_names)

    hover_template = (
        '<b>%{text}</b><br>'
        'Average WER: %{x:.2f}<br>'
        'RTFx: %{y:.2f}<br>'
        '<extra></extra>'
    )

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=df['Average WER ⬇️ '],
        y=df['RTFx ⬆️ '],
        mode='markers',
        text=df['model'],
        hovertemplate=hover_template,
        # No marker color array is supplied here. NOTE(review): per the Plotly
        # docs a colorscale only takes effect once marker.color is numeric,
        # so this setting is presumably inert for now — confirm.
        marker=dict(size=10, colorscale='Viridis'),
    ))

    # Update the layout
    fig.update_layout(
        title='ASR Model Performance: Average WER vs RTFx',
        xaxis_title='Average WER (lower is better)',
        yaxis_title='RTFx (higher is better)',
        hovermode='closest',
    )
    return fig
# Page layout: banner and introduction, four tabs (leaderboard, metrics,
# request form, plots), then the last-updated footer and citation box.
with gr.Blocks() as demo:
    gr.HTML(BANNER, elem_id="banner")
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    # NOTE(review): every TabItem below reuses the same elem_id — confirm
    # that this is intentional.
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard", elem_id="od-benchmark-tab-table", id=0):
            leaderboard_table = gr.components.Dataframe(
                value=original_df,
                datatype=TYPES,
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
                height=500,
            )

            # NOTE(review): the file's original indentation was lost; the
            # search box is assumed to live inside this accordion next to
            # the category selector — confirm against the intended layout.
            with gr.Accordion("📌 Select a more detailed subset",open=False):
                column_radio = gr.Radio(
                    ["All Columns", "Main Metrics", "Narrated", "Oratory", "Spontaneous"],
                    label="Categories",
                    value="All Columns"
                )
                search_bar = gr.Textbox(label="Search models", placeholder="Enter model name...")

            # Both controls rebuild the table from the current category and
            # search text; the search box fires on submit.
            column_radio.change(update_table, inputs=[column_radio, search_bar], outputs=[leaderboard_table])
            search_bar.submit(update_table, inputs=[column_radio, search_bar], outputs=[leaderboard_table])

        with gr.TabItem("📈 Metrics", elem_id="od-benchmark-tab-table", id=1):
            gr.Markdown(METRICS_TAB_TEXT, elem_classes="markdown-text")

        with gr.TabItem("✉️✨ Request a model here!", elem_id="od-benchmark-tab-table", id=2):
            with gr.Column():
                gr.Markdown("# ✉️✨ Request results for a new model here!", elem_classes="markdown-text")
            with gr.Column():
                gr.Markdown("Select a dataset:", elem_classes="markdown-text")
            with gr.Column():
                model_name_textbox = gr.Textbox(label="Model name (user_name/model_name)")
                # Hidden, non-interactive checkbox that is always True; it is
                # passed to request_model as its second argument.
                chb_coco2017 = gr.Checkbox(label="COCO validation 2017 dataset", visible=False, value=True, interactive=False)
            with gr.Column():
                mdw_submission_result = gr.Markdown()
                btn_submitt = gr.Button(value="🚀 Request")
                btn_submitt.click(request_model,
                                  [model_name_textbox, chb_coco2017],
                                  mdw_submission_result)

        with gr.TabItem("📊 Plots", elem_id="od-benchmark-tab-table", id=3):
            # The callable itself is handed to gr.Plot as its value.
            plot = gr.Plot(generate_plot)

    gr.Markdown(f"Last updated on **{LAST_UPDATED}**", elem_classes="markdown-text")

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Textbox(
                value=CITATION_TEXT, lines=7,
                label="Copy the BibTeX snippet to cite this source",
                elem_id="citation-button",
                show_copy_button=True,
            )

# The stray trailing "|" that followed this call was a syntax error and has
# been removed.
demo.launch()