Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -68,6 +68,8 @@ def load_leaderboard():
|
|
68 |
result.drop('arch_name', axis=1, inplace=True)
|
69 |
result.drop('crop_pct', axis=1, inplace=True)
|
70 |
result.drop('interpolation', axis=1, inplace=True)
|
|
|
|
|
71 |
|
72 |
# Round numerical values
|
73 |
result = result.round(2)
|
@@ -106,6 +108,7 @@ def filter_leaderboard(df, model_name, sort_by):
|
|
106 |
|
107 |
return filtered_df
|
108 |
|
|
|
109 |
def create_scatter_plot(df, x_axis, y_axis):
|
110 |
fig = px.scatter(
|
111 |
df,
|
@@ -116,10 +119,15 @@ def create_scatter_plot(df, x_axis, y_axis):
|
|
116 |
hover_data=['model'],
|
117 |
trendline='ols',
|
118 |
trendline_options=dict(log_x=True, log_y=True),
|
|
|
|
|
119 |
title=f'{y_axis} vs {x_axis}'
|
120 |
)
|
|
|
|
|
121 |
return fig
|
122 |
|
|
|
123 |
# Load the leaderboard data
|
124 |
full_df = load_leaderboard()
|
125 |
|
@@ -132,14 +140,30 @@ DEFAULT_SORT = "avg_top1"
|
|
132 |
DEFAULT_X = "infer_samples_per_sec"
|
133 |
DEFAULT_Y = "avg_top1"
|
134 |
|
135 |
-
def update_leaderboard_and_plot(
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
|
145 |
with gr.Blocks(title="The timm Leaderboard") as app:
|
@@ -148,8 +172,11 @@ with gr.Blocks(title="The timm Leaderboard") as app:
|
|
148 |
gr.HTML("<p>Search tips:<br>- Use wildcards (* or ?) for pattern matching<br>- Use 're:' prefix for regex search<br>- Otherwise, fuzzy matching will be used</p>")
|
149 |
|
150 |
with gr.Row():
|
151 |
-
search_bar = gr.Textbox(lines=1, label="
|
152 |
sort_dropdown = gr.Dropdown(choices=sort_columns, label="Sort by", value=DEFAULT_SORT, scale=1)
|
|
|
|
|
|
|
153 |
|
154 |
with gr.Row():
|
155 |
x_axis = gr.Dropdown(choices=plot_columns, label="X-axis", value=DEFAULT_X)
|
@@ -164,13 +191,17 @@ with gr.Blocks(title="The timm Leaderboard") as app:
|
|
164 |
|
165 |
search_bar.submit(
|
166 |
update_leaderboard_and_plot,
|
167 |
-
inputs=[search_bar, sort_dropdown, x_axis, y_axis],
|
|
|
|
|
|
|
|
|
|
|
168 |
outputs=[leaderboard, plot]
|
169 |
)
|
170 |
-
|
171 |
update_btn.click(
|
172 |
update_leaderboard_and_plot,
|
173 |
-
inputs=[search_bar, sort_dropdown, x_axis, y_axis],
|
174 |
outputs=[leaderboard, plot]
|
175 |
)
|
176 |
|
|
|
68 |
result.drop('arch_name', axis=1, inplace=True)
|
69 |
result.drop('crop_pct', axis=1, inplace=True)
|
70 |
result.drop('interpolation', axis=1, inplace=True)
|
71 |
+
|
72 |
+
result['highlighted'] = False
|
73 |
|
74 |
# Round numerical values
|
75 |
result = result.round(2)
|
|
|
108 |
|
109 |
return filtered_df
|
110 |
|
111 |
+
|
112 |
def create_scatter_plot(df, x_axis, y_axis):
|
113 |
fig = px.scatter(
|
114 |
df,
|
|
|
119 |
hover_data=['model'],
|
120 |
trendline='ols',
|
121 |
trendline_options=dict(log_x=True, log_y=True),
|
122 |
+
color='highlighted',
|
123 |
+
color_discrete_map={True: 'red', False: 'blue'},
|
124 |
title=f'{y_axis} vs {x_axis}'
|
125 |
)
|
126 |
+
fig.update_layout(showlegend=False)
|
127 |
+
|
128 |
return fig
|
129 |
|
130 |
+
|
131 |
# Load the leaderboard data
|
132 |
full_df = load_leaderboard()
|
133 |
|
|
|
140 |
DEFAULT_X = "infer_samples_per_sec"
|
141 |
DEFAULT_Y = "avg_top1"
|
142 |
|
143 |
+
def update_leaderboard_and_plot(
|
144 |
+
model_name=DEFAULT_SEARCH,
|
145 |
+
highlight_name=None,
|
146 |
+
sort_by=DEFAULT_SORT,
|
147 |
+
x_axis=DEFAULT_X,
|
148 |
+
y_axis=DEFAULT_Y,
|
149 |
+
):
|
150 |
+
filtered_df = filter_leaderboard(full_df, model_name, sort_by)
|
151 |
+
|
152 |
+
# Apply the highlight filter to the entire dataset so the output will be union (comparison) if the filters are disjoint
|
153 |
+
highlight_df = filter_leaderboard(full_df, highlight_name, sort_by) if highlight_name else None
|
154 |
+
|
155 |
+
# Combine filtered_df and highlight_df, removing duplicates
|
156 |
+
if highlight_df is not None:
|
157 |
+
combined_df = pd.concat([filtered_df, highlight_df]).drop_duplicates().reset_index(drop=True)
|
158 |
+
combined_df = combined_df.sort_values(by=sort_by, ascending=False)
|
159 |
+
combined_df['highlighted'] = combined_df['model'].isin(highlight_df['model'])
|
160 |
+
else:
|
161 |
+
combined_df = filtered_df
|
162 |
+
|
163 |
+
fig = create_scatter_plot(combined_df, x_axis, y_axis)
|
164 |
+
highlighted_df = combined_df.style.apply(lambda x: ['background-color: #ffcccc' if x['highlighted'] else '' for _ in x], axis=1)
|
165 |
+
|
166 |
+
return highlighted_df, fig
|
167 |
|
168 |
|
169 |
with gr.Blocks(title="The timm Leaderboard") as app:
|
|
|
172 |
gr.HTML("<p>Search tips:<br>- Use wildcards (* or ?) for pattern matching<br>- Use 're:' prefix for regex search<br>- Otherwise, fuzzy matching will be used</p>")
|
173 |
|
174 |
with gr.Row():
|
175 |
+
search_bar = gr.Textbox(lines=1, label="Model Filter", placeholder="e.g. resnet*, re:^vit, efficientnet", scale=3)
|
176 |
sort_dropdown = gr.Dropdown(choices=sort_columns, label="Sort by", value=DEFAULT_SORT, scale=1)
|
177 |
+
|
178 |
+
with gr.Row():
|
179 |
+
highlight_bar = gr.Textbox(lines=1, label="Model Highlight/Compare Filter", placeholder="e.g. convnext*, re:^efficient")
|
180 |
|
181 |
with gr.Row():
|
182 |
x_axis = gr.Dropdown(choices=plot_columns, label="X-axis", value=DEFAULT_X)
|
|
|
191 |
|
192 |
search_bar.submit(
|
193 |
update_leaderboard_and_plot,
|
194 |
+
inputs=[search_bar, highlight_bar, sort_dropdown, x_axis, y_axis],
|
195 |
+
outputs=[leaderboard, plot]
|
196 |
+
)
|
197 |
+
highlight_bar.submit(
|
198 |
+
update_leaderboard_and_plot,
|
199 |
+
inputs=[search_bar, highlight_bar, sort_dropdown, x_axis, y_axis],
|
200 |
outputs=[leaderboard, plot]
|
201 |
)
|
|
|
202 |
update_btn.click(
|
203 |
update_leaderboard_and_plot,
|
204 |
+
inputs=[search_bar, highlight_bar, sort_dropdown, x_axis, y_axis],
|
205 |
outputs=[leaderboard, plot]
|
206 |
)
|
207 |
|