rwightman HF staff committed on
Commit
d4ff2d1
1 Parent(s): adfbf11

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -0
app.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fnmatch
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import plotly.express as px
5
+ from rapidfuzz import fuzz
6
+ import re
7
+
8
def load_leaderboard():
    """Fetch timm result/benchmark CSVs and assemble one leaderboard DataFrame.

    Downloads the per-dataset ImageNet validation result CSVs and an inference
    benchmark CSV from the pytorch-image-models GitHub repo, merges them on
    model / arch / img_size, adds `avg_top1` / `avg_top5` columns, and returns
    a rounded, column-ordered DataFrame.

    Returns:
        pd.DataFrame: one row per (model, img_size) with per-dataset top1/top5
        columns, averaged accuracy, and inference benchmark columns.
    """
    # All CSVs live under the same raw GitHub path — keep it in one place.
    base_url = 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results'

    # Load validation / test CSV files
    results_csv_files = {
        'imagenet': f'{base_url}/results-imagenet.csv',
        'real': f'{base_url}/results-imagenet-real.csv',
        'v2': f'{base_url}/results-imagenetv2-matched-frequency.csv',
        'sketch': f'{base_url}/results-sketch.csv',
        'a': f'{base_url}/results-imagenet-a.csv',
        'r': f'{base_url}/results-imagenet-r.csv',
    }

    # Load benchmark CSV files
    main_bench = 'amp-nhwc-pt210-cu121-rtx3090'
    benchmark_csv_files = {
        'amp-nhwc-pt210-cu121-rtx3090': f'{base_url}/benchmark-infer-amp-nhwc-pt210-cu121-rtx3090.csv',
        'fp32-nchw-pt221-cpu-i9_10940x-dynamo': f'{base_url}/benchmark-infer-fp32-nchw-pt221-cpu-i9_10940x-dynamo.csv',
    }
    # FIXME support selecting benchmark 'infer_samples_per_sec' / 'infer_step_time' from different benchmark files.

    dataframes = {name: pd.read_csv(url) for name, url in results_csv_files.items()}
    bench_dataframes = {name: pd.read_csv(url) for name, url in benchmark_csv_files.items()}
    main_bench_dataframe = bench_dataframes[main_bench]

    # Clean up dataframes: drop derived / noisy columns. errors='ignore'
    # silently skips columns a given CSV doesn't have, replacing the manual
    # per-column membership check.
    remove_column_names = ["top1_err", "top5_err", "top1_diff", "top5_diff", "rank_diff", "param_count"]
    for df in dataframes.values():
        df.drop(columns=remove_column_names, errors='ignore', inplace=True)

    # Rename / process results columns: prefix the metric columns with the
    # dataset name and derive the architecture name (model name without the
    # '.'-separated pretrained tag) as a merge key.
    for name, df in dataframes.items():
        df.rename(columns={"top1": f"{name}_top1", "top5": f"{name}_top5"}, inplace=True)
        df['arch_name'] = df['model'].apply(lambda x: x.split('.')[0])

    # Process benchmark dataframe (benchmark 'model' is already the arch name)
    main_bench_dataframe['arch_name'] = main_bench_dataframe['model']
    main_bench_dataframe.rename(columns={'infer_img_size': 'img_size'}, inplace=True)

    # Merge all result dataframes
    result = dataframes['imagenet']
    for name, df in dataframes.items():
        if name != 'imagenet':
            result = pd.merge(result, df, on=['arch_name', 'model', 'img_size', 'crop_pct', 'interpolation'], how='outer')

    # Merge with benchmark data
    result = pd.merge(result, main_bench_dataframe, on=['arch_name', 'img_size'], how='left', suffixes=('', '_benchmark'))

    # Calculate average scores across datasets (mean skips NaN by default, so
    # models missing from some datasets still get an average).
    top1_columns = [col for col in result.columns if col.endswith('_top1')]
    top5_columns = [col for col in result.columns if col.endswith('_top5')]
    result['avg_top1'] = result[top1_columns].mean(axis=1)
    result['avg_top5'] = result[top5_columns].mean(axis=1)

    # Reorder columns
    first_columns = ['model', 'img_size', 'avg_top1', 'avg_top5']
    other_columns = [col for col in result.columns if col not in first_columns and col != 'model_benchmark']
    result = result[first_columns + other_columns]

    # Drop columns that are no longer needed / add too much noise
    result.drop(columns=['arch_name', 'crop_pct', 'interpolation'], inplace=True)

    # Round numerical values
    return result.round(2)
76
+
77
+
78
REGEX_PREFIX = "re:"

def auto_match(pattern, text):
    """Match *text* against *pattern* via regex, wildcard, or fuzzy rules.

    - Patterns prefixed with 're:' are case-insensitive regexes matched from
      the start of *text*; an invalid regex never matches.
    - Patterns containing '*' or '?' are case-insensitive shell wildcards.
    - Anything else falls back to fuzzy partial matching (score >= 90).
    """
    if pattern.startswith(REGEX_PREFIX):
        expr = pattern[len(REGEX_PREFIX):].strip()
        try:
            return re.match(expr, text, re.IGNORECASE) is not None
        except re.error:
            # Invalid regex: treat as no match rather than raising.
            return False

    if '*' in pattern or '?' in pattern:
        return fnmatch.fnmatch(text.lower(), pattern.lower())

    # No regex / wildcard markers: fuzzy substring similarity.
    return fuzz.partial_ratio(pattern.lower(), text.lower(), score_cutoff=90) > 0
98
+
99
+
100
def filter_leaderboard(df, model_name, sort_by):
    """Filter *df* by a model-name pattern and sort descending on *sort_by*.

    An empty/None *model_name* skips filtering and just sorts the full frame.
    Pattern matching is delegated to `auto_match` (regex / wildcard / fuzzy).
    """
    if model_name:
        keep = df['model'].apply(lambda m: auto_match(model_name, m))
        df = df[keep]
    return df.sort_values(by=sort_by, ascending=False)
108
+
109
def create_scatter_plot(df, x_axis, y_axis):
    """Build a log-log scatter of *y_axis* vs *x_axis* with an OLS trendline.

    Hovering a point shows the model name; both axes and the trendline fit
    use log scale.
    """
    plot_kwargs = dict(
        x=x_axis,
        y=y_axis,
        log_x=True,
        log_y=True,
        hover_data=['model'],
        trendline='ols',
        trendline_options=dict(log_x=True, log_y=True),
        title=f'{y_axis} vs {x_axis}',
    )
    return px.scatter(df, **plot_kwargs)
122
+
123
# Load the leaderboard data once at startup; all filtering/sorting reuses this frame.
full_df = load_leaderboard()

# Columns offered in the sort dropdown and the plot-axis dropdowns.
sort_columns = ['avg_top1', 'avg_top5', 'infer_samples_per_sec', 'param_count', 'infer_gmacs', 'infer_macts']
plot_columns = ['infer_samples_per_sec', 'infer_gmacs', 'infer_macts', 'param_count', 'avg_top1', 'avg_top5']

# Initial UI state shared by the first page load and the input widgets.
DEFAULT_SEARCH = ""
DEFAULT_SORT = "avg_top1"
DEFAULT_X = "infer_samples_per_sec"
DEFAULT_Y = "avg_top1"
134
+
135
def update_leaderboard_and_plot(model_name=DEFAULT_SEARCH, sort_by=DEFAULT_SORT, x_axis=DEFAULT_X, y_axis=DEFAULT_Y):
    """Filter/sort the global leaderboard and rebuild the scatter plot.

    Reads `full_df` from module scope; returns the (DataFrame, figure) pair
    consumed by the Dataframe and Plot output components.
    """
    filtered = filter_leaderboard(full_df, model_name, sort_by)
    return filtered, create_scatter_plot(filtered, x_axis, y_axis)
143
+
144
+
145
# Assemble the Gradio UI and launch the app.
with gr.Blocks(title="The timm Leaderboard") as app:
    gr.HTML("<center><h1>PyTorch Image Models Leaderboard</h1></center>")
    gr.HTML("<p>This leaderboard is based on the results of the models from <a href='https://github.com/huggingface/pytorch-image-models'>PyTorch Image Models</a>.</p>")
    gr.HTML("<p>Search tips:<br>- Use wildcards (* or ?) for pattern matching<br>- Use 're:' prefix for regex search<br>- Otherwise, fuzzy matching will be used</p>")

    # Search / sort controls.
    with gr.Row():
        search_bar = gr.Textbox(label="Search Model", placeholder="e.g. resnet*, re:^vit, efficientnet", lines=1, scale=3)
        sort_dropdown = gr.Dropdown(label="Sort by", choices=sort_columns, value=DEFAULT_SORT, scale=1)

    # Scatter-plot axis selectors.
    with gr.Row():
        x_axis = gr.Dropdown(label="X-axis", choices=plot_columns, value=DEFAULT_X)
        y_axis = gr.Dropdown(label="Y-axis", choices=plot_columns, value=DEFAULT_Y)

    update_btn = gr.Button(value="Update", variant="primary")

    # Output components.
    leaderboard = gr.Dataframe()
    plot = gr.Plot()

    # Populate outputs with defaults on first page load.
    app.load(update_leaderboard_and_plot, outputs=[leaderboard, plot])

    # Re-run filtering/plotting on demand.
    update_btn.click(
        update_leaderboard_and_plot,
        inputs=[search_bar, sort_dropdown, x_axis, y_axis],
        outputs=[leaderboard, plot],
    )

app.launch()