Update app.py
Browse files
app.py
CHANGED
@@ -47,7 +47,17 @@ custom_css = """
|
|
47 |
"""
|
48 |
|
49 |
# Pattern: 0 Default, 1 Grid, 2 Chain, 3 Circle, 4 Square, 5 Cross, 6 Two_Rows, 7 Field, 8 Random
|
50 |
-
pattern_map = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
hivex_envs = [
|
53 |
{
|
@@ -86,7 +96,7 @@ def restart():
|
|
86 |
def download_leaderboard_dataset():
|
87 |
path = snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")
|
88 |
return path
|
89 |
-
|
90 |
|
91 |
def get_total_models():
|
92 |
total_models = 0
|
@@ -94,7 +104,7 @@ def get_total_models():
|
|
94 |
model_ids = get_model_ids(hivex_env["hivex_env"])
|
95 |
total_models += len(model_ids)
|
96 |
return total_models
|
97 |
-
|
98 |
|
99 |
def get_model_ids(hivex_env):
|
100 |
api = HfApi()
|
@@ -130,7 +140,11 @@ def update_leaderboard_dataset_parallel(hivex_env, path):
|
|
130 |
row["Task"] = results["task"]["name"]
|
131 |
if "pattern-id" in results["task"] or "difficulty-id" in results["task"]:
|
132 |
key = "Pattern" if "pattern-id" in results["task"] else "Difficulty"
|
133 |
-
row[key] =
|
|
|
|
|
|
|
|
|
134 |
|
135 |
results_metrics = results["metrics"]
|
136 |
|
@@ -146,7 +160,7 @@ def update_leaderboard_dataset_parallel(hivex_env, path):
|
|
146 |
|
147 |
# ranked_dataframe = rank_dataframe(pd.DataFrame.from_records(data))
|
148 |
ranked_dataframe = pd.DataFrame.from_records(data)
|
149 |
-
|
150 |
new_history = ranked_dataframe
|
151 |
file_path = path + "/" + hivex_env + ".csv"
|
152 |
new_history.to_csv(file_path, index=False)
|
@@ -187,7 +201,7 @@ def get_data(rl_env, task_id, path) -> pd.DataFrame:
|
|
187 |
filtered_data = filtered_data.drop(columns=["Task"])
|
188 |
|
189 |
# Drop columns that have no data (all values are NaN)
|
190 |
-
filtered_data = filtered_data.dropna(axis=1, how=
|
191 |
|
192 |
# Drop columns where all values are 0.0
|
193 |
filtered_data = filtered_data.loc[:, (filtered_data != 0.0).any(axis=0)]
|
@@ -201,6 +215,7 @@ def get_data(rl_env, task_id, path) -> pd.DataFrame:
|
|
201 |
|
202 |
return filtered_data
|
203 |
|
|
|
204 |
def get_task(rl_env, task_id, path) -> str:
|
205 |
"""
|
206 |
Get the task name from the leaderboard dataset based on the rl_env and task_id.
|
@@ -223,13 +238,24 @@ def get_task(rl_env, task_id, path) -> str:
|
|
223 |
def convert_to_title_case(text: str) -> str:
|
224 |
# Replace underscores with spaces
|
225 |
text = text.replace("_", " ")
|
226 |
-
|
227 |
# Convert each word to title case (capitalize the first letter)
|
228 |
title_case_text = text.title()
|
229 |
-
|
230 |
return title_case_text
|
231 |
|
232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
run_update_dataset()
|
234 |
|
235 |
block = gr.Blocks(css=custom_css) # Attach the custom CSS here
|
@@ -247,34 +273,47 @@ with block:
|
|
247 |
<h1 style="font-weight: bold;">HIVEX Leaderboard</h1>
|
248 |
</div>
|
249 |
"""
|
250 |
-
)
|
251 |
with gr.Row(elem_id="header-row"):
|
252 |
-
gr.HTML(
|
|
|
|
|
253 |
with gr.Row(elem_id="header-row"):
|
254 |
-
gr.HTML(
|
|
|
|
|
255 |
|
256 |
path_ = download_leaderboard_dataset()
|
257 |
# gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
|
258 |
# ENVIRONMENT TABS
|
259 |
-
with gr.Tabs() as tabs:
|
260 |
for env_index in range(0, len(hivex_envs)):
|
261 |
hivex_env = hivex_envs[env_index]
|
262 |
with gr.Tab(f"{hivex_env['title']}") as env_tabs:
|
263 |
# ADD CHECK BOX GROUP TO SELECT DIFFICULTY / PATTERN IDs
|
264 |
-
|
|
|
|
|
|
|
|
|
265 |
# TASK TABS
|
266 |
for task_id in range(0, hivex_env["task_count"]):
|
267 |
-
task_title = convert_to_title_case(
|
|
|
|
|
268 |
with gr.TabItem(f"Task {task_id}: {task_title}"):
|
269 |
with gr.Row():
|
270 |
data = get_data(hivex_env["hivex_env"], task_id, path_)
|
271 |
row_count = len(data) # Number of rows in the data
|
272 |
-
|
273 |
gr_dataframe = gr.components.Dataframe(
|
274 |
value=data,
|
275 |
headers=["User", "Model"],
|
276 |
datatype=["markdown", "markdown"],
|
277 |
-
row_count=(
|
|
|
|
|
|
|
278 |
)
|
279 |
|
280 |
|
|
|
47 |
"""
|
48 |
|
49 |
# Pattern: 0 Default, 1 Grid, 2 Chain, 3 Circle, 4 Square, 5 Cross, 6 Two_Rows, 7 Field, 8 Random
|
50 |
+
pattern_map = {
|
51 |
+
0: "0: Default",
|
52 |
+
1: "1: Grid",
|
53 |
+
2: "2: Chain",
|
54 |
+
3: "3: Circle",
|
55 |
+
4: "4: Square",
|
56 |
+
5: "5: Cross",
|
57 |
+
6: "6: Two Rows",
|
58 |
+
7: "7: Field",
|
59 |
+
8: "8: Random",
|
60 |
+
}
|
61 |
|
62 |
hivex_envs = [
|
63 |
{
|
|
|
96 |
def download_leaderboard_dataset():
|
97 |
path = snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")
|
98 |
return path
|
99 |
+
|
100 |
|
101 |
def get_total_models():
|
102 |
total_models = 0
|
|
|
104 |
model_ids = get_model_ids(hivex_env["hivex_env"])
|
105 |
total_models += len(model_ids)
|
106 |
return total_models
|
107 |
+
|
108 |
|
109 |
def get_model_ids(hivex_env):
|
110 |
api = HfApi()
|
|
|
140 |
row["Task"] = results["task"]["name"]
|
141 |
if "pattern-id" in results["task"] or "difficulty-id" in results["task"]:
|
142 |
key = "Pattern" if "pattern-id" in results["task"] else "Difficulty"
|
143 |
+
row[key] = (
|
144 |
+
pattern_map[results["task"]["pattern-id"]]
|
145 |
+
if "pattern-id" in results["task"]
|
146 |
+
else results["task"]["difficulty-id"]
|
147 |
+
)
|
148 |
|
149 |
results_metrics = results["metrics"]
|
150 |
|
|
|
160 |
|
161 |
# ranked_dataframe = rank_dataframe(pd.DataFrame.from_records(data))
|
162 |
ranked_dataframe = pd.DataFrame.from_records(data)
|
163 |
+
|
164 |
new_history = ranked_dataframe
|
165 |
file_path = path + "/" + hivex_env + ".csv"
|
166 |
new_history.to_csv(file_path, index=False)
|
|
|
201 |
filtered_data = filtered_data.drop(columns=["Task"])
|
202 |
|
203 |
# Drop columns that have no data (all values are NaN)
|
204 |
+
filtered_data = filtered_data.dropna(axis=1, how="all")
|
205 |
|
206 |
# Drop columns where all values are 0.0
|
207 |
filtered_data = filtered_data.loc[:, (filtered_data != 0.0).any(axis=0)]
|
|
|
215 |
|
216 |
return filtered_data
|
217 |
|
218 |
+
|
219 |
def get_task(rl_env, task_id, path) -> str:
|
220 |
"""
|
221 |
Get the task name from the leaderboard dataset based on the rl_env and task_id.
|
|
|
238 |
def convert_to_title_case(text: str) -> str:
|
239 |
# Replace underscores with spaces
|
240 |
text = text.replace("_", " ")
|
241 |
+
|
242 |
# Convert each word to title case (capitalize the first letter)
|
243 |
title_case_text = text.title()
|
244 |
+
|
245 |
return title_case_text
|
246 |
|
247 |
|
248 |
+
def get_difficulty_pattern_ids_and_key(rl_env, path):
|
249 |
+
csv_path = path + "/" + rl_env + ".csv"
|
250 |
+
data = pd.read_csv(csv_path)
|
251 |
+
|
252 |
+
key = "Pattern" if "Pattern" in data.columns else "Difficulty"
|
253 |
+
# Get the unique values in the "Difficulty" column
|
254 |
+
difficulty_pattern_ids = data[key].unique()
|
255 |
+
|
256 |
+
key, difficulty_pattern_ids
|
257 |
+
|
258 |
+
|
259 |
run_update_dataset()
|
260 |
|
261 |
block = gr.Blocks(css=custom_css) # Attach the custom CSS here
|
|
|
273 |
<h1 style="font-weight: bold;">HIVEX Leaderboard</h1>
|
274 |
</div>
|
275 |
"""
|
276 |
+
)
|
277 |
with gr.Row(elem_id="header-row"):
|
278 |
+
gr.HTML(
|
279 |
+
f"<p style='text-align: center;'>Total models: {get_total_models()}</p>"
|
280 |
+
)
|
281 |
with gr.Row(elem_id="header-row"):
|
282 |
+
gr.HTML(
|
283 |
+
f"<p style='text-align: center;'>Get started π on our <a href='https://github.com/hivex-research/hivex'>GitHub repository</a>!</p>"
|
284 |
+
)
|
285 |
|
286 |
path_ = download_leaderboard_dataset()
|
287 |
# gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
|
288 |
# ENVIRONMENT TABS
|
289 |
+
with gr.Tabs() as tabs: # elem_classes="tab-buttons"
|
290 |
for env_index in range(0, len(hivex_envs)):
|
291 |
hivex_env = hivex_envs[env_index]
|
292 |
with gr.Tab(f"{hivex_env['title']}") as env_tabs:
|
293 |
# ADD CHECK BOX GROUP TO SELECT DIFFICULTY / PATTERN IDs
|
294 |
+
dp_key, difficulty_pattern_ids = get_difficulty_pattern_ids_and_key(
|
295 |
+
hivex_env["hivex_env"], path_
|
296 |
+
)
|
297 |
+
gr.CheckboxGroup(difficulty_pattern_ids, label=dp_key)
|
298 |
+
|
299 |
# TASK TABS
|
300 |
for task_id in range(0, hivex_env["task_count"]):
|
301 |
+
task_title = convert_to_title_case(
|
302 |
+
get_task(hivex_env["hivex_env"], task_id, path_)
|
303 |
+
)
|
304 |
with gr.TabItem(f"Task {task_id}: {task_title}"):
|
305 |
with gr.Row():
|
306 |
data = get_data(hivex_env["hivex_env"], task_id, path_)
|
307 |
row_count = len(data) # Number of rows in the data
|
308 |
+
|
309 |
gr_dataframe = gr.components.Dataframe(
|
310 |
value=data,
|
311 |
headers=["User", "Model"],
|
312 |
datatype=["markdown", "markdown"],
|
313 |
+
row_count=(
|
314 |
+
row_count,
|
315 |
+
"fixed",
|
316 |
+
), # Set to the exact number of rows in the data
|
317 |
)
|
318 |
|
319 |
|