"""Gradio app showing a searchable, sortable leaderboard of timm model results."""
import fnmatch
import re

import gradio as gr
import pandas as pd
import plotly.express as px
from rapidfuzz import fuzz


def load_leaderboard():
    """Download the timm result CSVs and build the merged leaderboard table.

    The per-dataset result files are outer-merged on the model / eval-config
    columns, the main inference benchmark is left-joined on architecture name
    and image size, and the mean top-1 / top-5 over all datasets is added.

    Returns:
        A DataFrame whose first columns are ``model``, ``img_size``,
        ``avg_top1`` and ``avg_top5``, followed by the per-dataset scores and
        the benchmark columns. Numeric values are rounded to 2 decimals.
    """
    base_url = 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/'

    # Validation / test CSV files
    results_csv_files = {
        'imagenet': base_url + 'results-imagenet.csv',
        'real': base_url + 'results-imagenet-real.csv',
        'v2': base_url + 'results-imagenetv2-matched-frequency.csv',
        'sketch': base_url + 'results-sketch.csv',
        'a': base_url + 'results-imagenet-a.csv',
        'r': base_url + 'results-imagenet-r.csv',
    }

    # Benchmark CSV files
    main_bench = 'amp-nhwc-pt210-cu121-rtx3090'
    benchmark_csv_files = {
        'amp-nhwc-pt210-cu121-rtx3090':
            base_url + 'benchmark-infer-amp-nhwc-pt210-cu121-rtx3090.csv',
        'fp32-nchw-pt221-cpu-i9_10940x-dynamo':
            base_url + 'benchmark-infer-fp32-nchw-pt221-cpu-i9_10940x-dynamo.csv',
    }
    # FIXME support selecting benchmark 'infer_samples_per_sec' / 'infer_step_time' from different benchmark files.

    dataframes = {name: pd.read_csv(url) for name, url in results_csv_files.items()}
    # Only the main benchmark is merged below, so only that file is fetched.
    main_bench_dataframe = pd.read_csv(benchmark_csv_files[main_bench])

    # Clean up, then rename / process the results columns
    remove_column_names = ["top1_err", "top5_err", "top1_diff", "top5_diff", "rank_diff", "param_count"]
    for name, df in dataframes.items():
        # errors='ignore': not every results file has every one of these columns.
        df.drop(columns=remove_column_names, errors='ignore', inplace=True)
        df.rename(columns={"top1": f"{name}_top1", "top5": f"{name}_top5"}, inplace=True)
        # Text before the first '.' of the model name is used as the join key
        # against the benchmark file.
        df['arch_name'] = df['model'].apply(lambda model: model.split('.')[0])

    # Process benchmark dataframe: its 'model' column is used directly as the
    # architecture name.
    main_bench_dataframe['arch_name'] = main_bench_dataframe['model']
    main_bench_dataframe.rename(columns={'infer_img_size': 'img_size'}, inplace=True)

    # Merge all result dataframes
    merge_keys = ['arch_name', 'model', 'img_size', 'crop_pct', 'interpolation']
    result = dataframes['imagenet']
    for name, df in dataframes.items():
        if name != 'imagenet':
            result = pd.merge(result, df, on=merge_keys, how='outer')

    # Merge with benchmark data; the benchmark's own 'model' column comes out
    # as 'model_benchmark' and is discarded below.
    result = pd.merge(
        result,
        main_bench_dataframe,
        on=['arch_name', 'img_size'],
        how='left',
        suffixes=('', '_benchmark'),
    )

    # Calculate average scores (row-wise mean over the per-dataset columns)
    top1_columns = [col for col in result.columns if col.endswith('_top1')]
    top5_columns = [col for col in result.columns if col.endswith('_top5')]
    result['avg_top1'] = result[top1_columns].mean(axis=1)
    result['avg_top5'] = result[top5_columns].mean(axis=1)

    # Reorder columns, leaving out the ones that are no longer needed / add
    # too much noise.
    first_columns = ['model', 'img_size', 'avg_top1', 'avg_top5']
    excluded_columns = {'model_benchmark', 'arch_name', 'crop_pct', 'interpolation'}
    other_columns = [
        col for col in result.columns
        if col not in first_columns and col not in excluded_columns
    ]
    result = result[first_columns + other_columns]

    # Round numerical values
    return result.round(2)


REGEX_PREFIX = "re:"


def auto_match(pattern, text):
    """Return True if ``text`` matches ``pattern``, picking the match mode from the pattern.

    * ``re:<regex>`` -- case-insensitive ``re.match`` (anchored at the start
      of ``text``); an invalid regex matches nothing.
    * contains ``*`` or ``?`` -- case-insensitive ``fnmatch`` wildcard match.
    * anything else -- fuzzy substring match (rapidfuzz ``partial_ratio``
      with a score cutoff of 90).
    """
    if pattern.startswith(REGEX_PREFIX):
        regex_pattern = pattern[len(REGEX_PREFIX):].strip()
        try:
            return bool(re.match(regex_pattern, text, re.IGNORECASE))
        except re.error:
            # If it's an invalid regex, treat it as "no match" instead of failing.
            return False

    if '*' in pattern or '?' in pattern:
        return fnmatch.fnmatch(text.lower(), pattern.lower())

    # partial_ratio returns 0 when the score is below score_cutoff.
    return fuzz.partial_ratio(pattern.lower(), text.lower(), score_cutoff=90) > 0


def filter_leaderboard(df, model_name, sort_by):
    """Filter ``df`` by model name and sort it by ``sort_by``, highest first.

    Args:
        df: Leaderboard table with a ``model`` column.
        model_name: Search text (see ``auto_match``). Empty, ``None`` or
            whitespace-only means "no filter".
        sort_by: Column to sort by, descending.

    Returns:
        A new, sorted DataFrame; ``df`` itself is not modified.
    """
    # Surrounding whitespace is never meaningful in a query and would only
    # degrade the fuzzy / wildcard match.
    query = (model_name or "").strip()
    if query:
        mask = df['model'].apply(lambda name: auto_match(query, name))
        df = df[mask]
    return df.sort_values(by=sort_by, ascending=False)


def create_scatter_plot(df, x_axis, y_axis):
    """Build a log-log scatter plot of ``y_axis`` vs ``x_axis`` with an OLS trendline."""
    fig = px.scatter(
        df,
        x=x_axis,
        y=y_axis,
        log_x=True,
        log_y=True,
        hover_data=['model'],
        trendline='ols',
        trendline_options=dict(log_x=True, log_y=True),
        title=f'{y_axis} vs {x_axis}',
    )
    return fig


# Load the leaderboard data
full_df = load_leaderboard()

# Define the available columns for sorting and plotting
sort_columns = ['avg_top1', 'avg_top5', 'infer_samples_per_sec', 'param_count', 'infer_gmacs', 'infer_macts']
plot_columns = ['infer_samples_per_sec', 'infer_gmacs', 'infer_macts', 'param_count', 'avg_top1', 'avg_top5']

DEFAULT_SEARCH = ""
DEFAULT_SORT = "avg_top1"
DEFAULT_X = "infer_samples_per_sec"
DEFAULT_Y = "avg_top1"


def update_leaderboard_and_plot(
        model_name=DEFAULT_SEARCH,
        sort_by=DEFAULT_SORT,
        x_axis=DEFAULT_X,
        y_axis=DEFAULT_Y,
):
    """Gradio callback: return the filtered table and its scatter plot.

    The defaults let ``app.load`` call this with no inputs to populate the
    initial view.
    """
    filtered_df = filter_leaderboard(
        full_df,  # in outer scope
        model_name,
        sort_by,
    )
    fig = create_scatter_plot(filtered_df, x_axis, y_axis)
    return filtered_df, fig


with gr.Blocks(title="The timm Leaderboard") as app:
    gr.HTML("<h1>The timm (PyTorch Image Models) Leaderboard</h1>")
    gr.HTML("<p>This leaderboard is based on the results of the models from timm.</p>")
    # Explicit <br> tags: HTML collapses plain newlines, which would run the
    # tips together on one line.
    gr.HTML(
        "<p>Search tips:<br>"
        "- Use wildcards (* or ?) for pattern matching<br>"
        "- Use 're:' prefix for regex search<br>"
        "- Otherwise, fuzzy matching will be used</p>"
    )

    with gr.Row():
        search_bar = gr.Textbox(
            lines=1,
            label="Search Model",
            placeholder="e.g. resnet*, re:^vit, efficientnet",
            scale=3,
        )
        sort_dropdown = gr.Dropdown(choices=sort_columns, label="Sort by", value=DEFAULT_SORT, scale=1)

    with gr.Row():
        x_axis = gr.Dropdown(choices=plot_columns, label="X-axis", value=DEFAULT_X)
        y_axis = gr.Dropdown(choices=plot_columns, label="Y-axis", value=DEFAULT_Y)

    update_btn = gr.Button(value="Update", variant="primary")

    leaderboard = gr.Dataframe()
    plot = gr.Plot()

    # Populate the table and plot with the defaults on page load.
    app.load(update_leaderboard_and_plot, outputs=[leaderboard, plot])

    search_bar.submit(
        update_leaderboard_and_plot,
        inputs=[search_bar, sort_dropdown, x_axis, y_axis],
        outputs=[leaderboard, plot]
    )
    update_btn.click(
        update_leaderboard_and_plot,
        inputs=[search_bar, sort_dropdown, x_axis, y_axis],
        outputs=[leaderboard, plot]
    )

app.launch()