"""Gradio leaderboard app: ASR model benchmark (WER / RTFx) with plots and model requests."""

import gradio as gr
import pandas as pd
import json
from datetime import datetime, timezone

import plotly.graph_objects as go

from constants import BANNER, INTRODUCTION_TEXT, CITATION_TEXT, METRICS_TAB_TEXT, DIR_OUTPUT_REQUESTS
from init import is_model_on_hub, upload_file, load_all_info_from_dataset_hub
from utils_display import (
    AutoEvalColumn,
    fields,
    make_best_bold,
    make_clickable_model,
    styled_error,
    styled_message,
)

LAST_UPDATED = "Sep 11th 2024"

# Maps raw data.csv headers to the display names shown in the leaderboard.
# Applied after every read_csv so both raw and pre-renamed CSVs work
# (rename is a no-op for columns already carrying the display name).
column_names = {
    "MODEL": "Model",
    "Avg. WER": "Average WER ⬇️ ",
    "Avg. RTFx": "RTFx ⬆️ ",
    "AMI WER": "AMI",
    "Earnings22 WER": "Earnings22",
    "Gigaspeech WER": "Gigaspeech",
    "LS Clean WER": "LS Clean",
    "LS Other WER": "LS Other",
    "SPGISpeech WER": "SPGISpeech",
}

# Dataset columns averaged for each detailed-subset radio choice.
_CATEGORY_COLUMNS = {
    "Narrated": ["LS Clean", "LS Other", "Gigaspeech"],
    "Oratory": ["Tedlium", "SPGISpeech", "Earnings22"],
    "Spontaneous": ["Gigaspeech", "SPGISpeech", "Earnings22", "AMI"],
}

original_df = pd.read_csv("data.csv")
requested_models = []


def formatter(x):
    """Round numeric cells to 2 decimal places; leave strings untouched."""
    return x if isinstance(x, str) else round(x, 2)


def format_df(df: pd.DataFrame) -> pd.DataFrame:
    """Format a leaderboard frame in place: linkify model names, bold the best scores.

    Note: make_best_bold turns numeric columns into markdown strings, so any
    numeric sorting must happen BEFORE this function is applied.
    """
    for col in df.columns:
        if col == "model":
            df[col] = df[col].apply(make_clickable_model)
        else:
            df[col] = make_best_bold(df[col], col)
    return df


# BUG FIX: rename and sort on the still-numeric frame BEFORE formatting.
# The previous order (format -> rename -> sort) sorted markdown-bolded
# strings lexicographically instead of WER values numerically.
original_df.rename(columns=column_names, inplace=True)
original_df.sort_values(by="Average WER ⬇️ ", inplace=True)
original_df = format_df(original_df)

COLS = [c.name for c in fields(AutoEvalColumn)]
TYPES = [c.type for c in fields(AutoEvalColumn)]


def request_model(model_text, chbcoco2017):
    # ... (keep the existing request_model function as is)
    pass


def update_table(column_selection, search: str):
    """Return the leaderboard filtered by a model-name substring and column subset.

    Parameters
    ----------
    column_selection : str
        One of "All Columns", "Main Metrics", "Narrated", "Oratory",
        "Spontaneous". Category subsets recompute "Average WER ⬇️ " as the
        mean of only that category's datasets. Unknown values fall back to
        the full table.
    search : str
        Case-insensitive substring matched against the model column.
    """
    df = pd.read_csv("data.csv")
    # BUG FIX: normalize headers to display names; the raw CSV may carry
    # "Avg. WER"-style headers that the selections below would KeyError on.
    df.rename(columns=column_names, inplace=True)
    df = df[df["model"].str.contains(search, case=False, na=False)]

    base_cols = ["model", "Average WER ⬇️ ", "RTFx ⬆️ "]
    if column_selection == "Main Metrics":
        new_df = df[base_cols].copy()
    elif column_selection in _CATEGORY_COLUMNS:
        subset = _CATEGORY_COLUMNS[column_selection]
        # .copy() so the recomputed average writes into an owned frame,
        # not a view of the original (SettingWithCopyWarning otherwise).
        new_df = df[base_cols + subset].copy()
        new_df["Average WER ⬇️ "] = new_df[subset].mean(axis=1).round(2)
    else:
        # "All Columns" and any unexpected selection show everything
        # (previously an unknown value raised NameError on new_df).
        new_df = df.copy()

    new_df = new_df.sort_values(by="Average WER ⬇️ ", ascending=True)
    return format_df(new_df)


def generate_plot():
    """Build a WER-vs-RTFx scatter plot (one marker per model) from data.csv."""
    df = pd.read_csv("data.csv")
    # Same header normalization as update_table — see the note there.
    df.rename(columns=column_names, inplace=True)

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=df["Average WER ⬇️ "],
            y=df["RTFx ⬆️ "],
            mode="markers",
            text=df["model"],
            # <br> is Plotly's hover line break; <extra></extra> suppresses
            # the secondary trace-name box. NOTE(review): the original
            # template was mangled by extraction (literal newlines where
            # tags belong) — confirm against the rendered hover.
            hovertemplate="%{text}<br>"
            + "Average WER: %{x:.2f}<br>"
            + "RTFx: %{y:.2f}"
            + "<extra></extra>",
            marker=dict(size=10),
        )
    )
    fig.update_layout(
        title="ASR Model Performance: Average WER vs RTFx",
        xaxis_title="Average WER (lower is better)",
        yaxis_title="RTFx (higher is better)",
        hovermode="closest",
    )
    return fig


with gr.Blocks() as demo:
    gr.HTML(BANNER, elem_id="banner")
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 Leaderboard", elem_id="od-benchmark-tab-table", id=0):
            leaderboard_table = gr.components.Dataframe(
                value=original_df,
                datatype=TYPES,
                elem_id="leaderboard-table",
                interactive=False,
                visible=True,
                height=500,
            )
            with gr.Accordion("📌 Select a more detailed subset", open=False):
                column_radio = gr.Radio(
                    ["All Columns", "Main Metrics", "Narrated", "Oratory", "Spontaneous"],
                    label="Categories",
                    value="All Columns",
                )
            search_bar = gr.Textbox(label="Search models", placeholder="Enter model name...")
            column_radio.change(
                update_table, inputs=[column_radio, search_bar], outputs=[leaderboard_table]
            )
            search_bar.submit(
                update_table, inputs=[column_radio, search_bar], outputs=[leaderboard_table]
            )

        with gr.TabItem("📈 Metrics", elem_id="od-benchmark-tab-table", id=1):
            gr.Markdown(METRICS_TAB_TEXT, elem_classes="markdown-text")

        with gr.TabItem("✉️✨ Request a model here!", elem_id="od-benchmark-tab-table", id=2):
            with gr.Column():
                gr.Markdown("# ✉️✨ Request results for a new model here!", elem_classes="markdown-text")
            with gr.Column():
                gr.Markdown("Select a dataset:", elem_classes="markdown-text")
                with gr.Column():
                    model_name_textbox = gr.Textbox(label="Model name (user_name/model_name)")
                    # Hidden, fixed checkbox kept only to satisfy request_model's signature.
                    chb_coco2017 = gr.Checkbox(
                        label="COCO validation 2017 dataset",
                        visible=False,
                        value=True,
                        interactive=False,
                    )
                with gr.Column():
                    mdw_submission_result = gr.Markdown()
                    btn_submitt = gr.Button(value="🚀 Request")
                    btn_submitt.click(
                        request_model,
                        [model_name_textbox, chb_coco2017],
                        mdw_submission_result,
                    )

        with gr.TabItem("📊 Plots", elem_id="od-benchmark-tab-table", id=3):
            plot = gr.Plot(generate_plot)

    gr.Markdown(f"Last updated on **{LAST_UPDATED}**", elem_classes="markdown-text")

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Textbox(
                value=CITATION_TEXT,
                lines=7,
                label="Copy the BibTeX snippet to cite this source",
                elem_id="citation-button",
                show_copy_button=True,
            )

demo.launch()