rwightman HF staff commited on
Commit
fcec301
1 Parent(s): 02c5d0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -12
app.py CHANGED
@@ -68,6 +68,8 @@ def load_leaderboard():
68
  result.drop('arch_name', axis=1, inplace=True)
69
  result.drop('crop_pct', axis=1, inplace=True)
70
  result.drop('interpolation', axis=1, inplace=True)
 
 
71
 
72
  # Round numerical values
73
  result = result.round(2)
@@ -106,6 +108,7 @@ def filter_leaderboard(df, model_name, sort_by):
106
 
107
  return filtered_df
108
 
 
109
  def create_scatter_plot(df, x_axis, y_axis):
110
  fig = px.scatter(
111
  df,
@@ -116,10 +119,15 @@ def create_scatter_plot(df, x_axis, y_axis):
116
  hover_data=['model'],
117
  trendline='ols',
118
  trendline_options=dict(log_x=True, log_y=True),
 
 
119
  title=f'{y_axis} vs {x_axis}'
120
  )
 
 
121
  return fig
122
 
 
123
  # Load the leaderboard data
124
  full_df = load_leaderboard()
125
 
@@ -132,14 +140,30 @@ DEFAULT_SORT = "avg_top1"
132
  DEFAULT_X = "infer_samples_per_sec"
133
  DEFAULT_Y = "avg_top1"
134
 
135
- def update_leaderboard_and_plot(model_name=DEFAULT_SEARCH, sort_by=DEFAULT_SORT, x_axis=DEFAULT_X, y_axis=DEFAULT_Y):
136
- filtered_df = filter_leaderboard(
137
- full_df, # in outer scope
138
- model_name,
139
- sort_by,
140
- )
141
- fig = create_scatter_plot(filtered_df, x_axis, y_axis)
142
- return filtered_df, fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
 
145
  with gr.Blocks(title="The timm Leaderboard") as app:
@@ -148,8 +172,11 @@ with gr.Blocks(title="The timm Leaderboard") as app:
148
  gr.HTML("<p>Search tips:<br>- Use wildcards (* or ?) for pattern matching<br>- Use 're:' prefix for regex search<br>- Otherwise, fuzzy matching will be used</p>")
149
 
150
  with gr.Row():
151
- search_bar = gr.Textbox(lines=1, label="Search Model", placeholder="e.g. resnet*, re:^vit, efficientnet", scale=3)
152
  sort_dropdown = gr.Dropdown(choices=sort_columns, label="Sort by", value=DEFAULT_SORT, scale=1)
 
 
 
153
 
154
  with gr.Row():
155
  x_axis = gr.Dropdown(choices=plot_columns, label="X-axis", value=DEFAULT_X)
@@ -164,13 +191,17 @@ with gr.Blocks(title="The timm Leaderboard") as app:
164
 
165
  search_bar.submit(
166
  update_leaderboard_and_plot,
167
- inputs=[search_bar, sort_dropdown, x_axis, y_axis],
 
 
 
 
 
168
  outputs=[leaderboard, plot]
169
  )
170
-
171
  update_btn.click(
172
  update_leaderboard_and_plot,
173
- inputs=[search_bar, sort_dropdown, x_axis, y_axis],
174
  outputs=[leaderboard, plot]
175
  )
176
 
 
68
  result.drop('arch_name', axis=1, inplace=True)
69
  result.drop('crop_pct', axis=1, inplace=True)
70
  result.drop('interpolation', axis=1, inplace=True)
71
+
72
+ result['highlighted'] = False
73
 
74
  # Round numerical values
75
  result = result.round(2)
 
108
 
109
  return filtered_df
110
 
111
+
112
  def create_scatter_plot(df, x_axis, y_axis):
113
  fig = px.scatter(
114
  df,
 
119
  hover_data=['model'],
120
  trendline='ols',
121
  trendline_options=dict(log_x=True, log_y=True),
122
+ color='highlighted',
123
+ color_discrete_map={True: 'red', False: 'blue'},
124
  title=f'{y_axis} vs {x_axis}'
125
  )
126
+ fig.update_layout(showlegend=False)
127
+
128
  return fig
129
 
130
+
131
  # Load the leaderboard data
132
  full_df = load_leaderboard()
133
 
 
140
  DEFAULT_X = "infer_samples_per_sec"
141
  DEFAULT_Y = "avg_top1"
142
 
143
+ def update_leaderboard_and_plot(
144
+ model_name=DEFAULT_SEARCH,
145
+ highlight_name=None,
146
+ sort_by=DEFAULT_SORT,
147
+ x_axis=DEFAULT_X,
148
+ y_axis=DEFAULT_Y,
149
+ ):
150
+ filtered_df = filter_leaderboard(full_df, model_name, sort_by)
151
+
152
+ # Apply the highlight filter to the entire dataset so the output will be union (comparison) if the filters are disjoint
153
+ highlight_df = filter_leaderboard(full_df, highlight_name, sort_by) if highlight_name else None
154
+
155
+ # Combine filtered_df and highlight_df, removing duplicates
156
+ if highlight_df is not None:
157
+ combined_df = pd.concat([filtered_df, highlight_df]).drop_duplicates().reset_index(drop=True)
158
+ combined_df = combined_df.sort_values(by=sort_by, ascending=False)
159
+ combined_df['highlighted'] = combined_df['model'].isin(highlight_df['model'])
160
+ else:
161
+ combined_df = filtered_df
162
+
163
+ fig = create_scatter_plot(combined_df, x_axis, y_axis)
164
+ highlighted_df = combined_df.style.apply(lambda x: ['background-color: #ffcccc' if x['highlighted'] else '' for _ in x], axis=1)
165
+
166
+ return highlighted_df, fig
167
 
168
 
169
  with gr.Blocks(title="The timm Leaderboard") as app:
 
172
  gr.HTML("<p>Search tips:<br>- Use wildcards (* or ?) for pattern matching<br>- Use 're:' prefix for regex search<br>- Otherwise, fuzzy matching will be used</p>")
173
 
174
  with gr.Row():
175
+ search_bar = gr.Textbox(lines=1, label="Model Filter", placeholder="e.g. resnet*, re:^vit, efficientnet", scale=3)
176
  sort_dropdown = gr.Dropdown(choices=sort_columns, label="Sort by", value=DEFAULT_SORT, scale=1)
177
+
178
+ with gr.Row():
179
+ highlight_bar = gr.Textbox(lines=1, label="Model Highlight/Compare Filter", placeholder="e.g. convnext*, re:^efficient")
180
 
181
  with gr.Row():
182
  x_axis = gr.Dropdown(choices=plot_columns, label="X-axis", value=DEFAULT_X)
 
191
 
192
  search_bar.submit(
193
  update_leaderboard_and_plot,
194
+ inputs=[search_bar, highlight_bar, sort_dropdown, x_axis, y_axis],
195
+ outputs=[leaderboard, plot]
196
+ )
197
+ highlight_bar.submit(
198
+ update_leaderboard_and_plot,
199
+ inputs=[search_bar, highlight_bar, sort_dropdown, x_axis, y_axis],
200
  outputs=[leaderboard, plot]
201
  )
 
202
  update_btn.click(
203
  update_leaderboard_and_plot,
204
+ inputs=[search_bar, highlight_bar, sort_dropdown, x_axis, y_axis],
205
  outputs=[leaderboard, plot]
206
  )
207