philippds commited on
Commit
638e9bd
Β·
verified Β·
1 Parent(s): 02dd1a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -16
app.py CHANGED
@@ -47,7 +47,17 @@ custom_css = """
47
  """
48
 
49
  # Pattern: 0 Default, 1 Grid, 2 Chain, 3 Circle, 4 Square, 5 Cross, 6 Two_Rows, 7 Field, 8 Random
50
- pattern_map = {0: "0: Default", 1: "1: Grid", 2: "2: Chain", 3: "3: Circle", 4: "4: Square", 5: "5: Cross", 6: "6: Two Rows", 7: "7: Field", 8: "8: Random" }
 
 
 
 
 
 
 
 
 
 
51
 
52
  hivex_envs = [
53
  {
@@ -86,7 +96,7 @@ def restart():
86
  def download_leaderboard_dataset():
87
  path = snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")
88
  return path
89
-
90
 
91
  def get_total_models():
92
  total_models = 0
@@ -94,7 +104,7 @@ def get_total_models():
94
  model_ids = get_model_ids(hivex_env["hivex_env"])
95
  total_models += len(model_ids)
96
  return total_models
97
-
98
 
99
  def get_model_ids(hivex_env):
100
  api = HfApi()
@@ -130,7 +140,11 @@ def update_leaderboard_dataset_parallel(hivex_env, path):
130
  row["Task"] = results["task"]["name"]
131
  if "pattern-id" in results["task"] or "difficulty-id" in results["task"]:
132
  key = "Pattern" if "pattern-id" in results["task"] else "Difficulty"
133
- row[key] = pattern_map[results["task"]["pattern-id"]] if "pattern-id" in results["task"] else results["task"]["difficulty-id"]
 
 
 
 
134
 
135
  results_metrics = results["metrics"]
136
 
@@ -146,7 +160,7 @@ def update_leaderboard_dataset_parallel(hivex_env, path):
146
 
147
  # ranked_dataframe = rank_dataframe(pd.DataFrame.from_records(data))
148
  ranked_dataframe = pd.DataFrame.from_records(data)
149
-
150
  new_history = ranked_dataframe
151
  file_path = path + "/" + hivex_env + ".csv"
152
  new_history.to_csv(file_path, index=False)
@@ -187,7 +201,7 @@ def get_data(rl_env, task_id, path) -> pd.DataFrame:
187
  filtered_data = filtered_data.drop(columns=["Task"])
188
 
189
  # Drop columns that have no data (all values are NaN)
190
- filtered_data = filtered_data.dropna(axis=1, how='all')
191
 
192
  # Drop columns where all values are 0.0
193
  filtered_data = filtered_data.loc[:, (filtered_data != 0.0).any(axis=0)]
@@ -201,6 +215,7 @@ def get_data(rl_env, task_id, path) -> pd.DataFrame:
201
 
202
  return filtered_data
203
 
 
204
  def get_task(rl_env, task_id, path) -> str:
205
  """
206
  Get the task name from the leaderboard dataset based on the rl_env and task_id.
@@ -223,13 +238,24 @@ def get_task(rl_env, task_id, path) -> str:
223
  def convert_to_title_case(text: str) -> str:
224
  # Replace underscores with spaces
225
  text = text.replace("_", " ")
226
-
227
  # Convert each word to title case (capitalize the first letter)
228
  title_case_text = text.title()
229
-
230
  return title_case_text
231
 
232
 
 
 
 
 
 
 
 
 
 
 
 
233
  run_update_dataset()
234
 
235
  block = gr.Blocks(css=custom_css) # Attach the custom CSS here
@@ -247,34 +273,47 @@ with block:
247
  <h1 style="font-weight: bold;">HIVEX Leaderboard</h1>
248
  </div>
249
  """
250
- )
251
  with gr.Row(elem_id="header-row"):
252
- gr.HTML(f"<p style='text-align: center;'>Total models: {get_total_models()}</p>")
 
 
253
  with gr.Row(elem_id="header-row"):
254
- gr.HTML(f"<p style='text-align: center;'>Get started πŸš€ on our <a href='https://github.com/hivex-research/hivex'>GitHub repository</a>!</p>")
 
 
255
 
256
  path_ = download_leaderboard_dataset()
257
  # gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
258
  # ENVIRONMENT TABS
259
- with gr.Tabs() as tabs: # elem_classes="tab-buttons"
260
  for env_index in range(0, len(hivex_envs)):
261
  hivex_env = hivex_envs[env_index]
262
  with gr.Tab(f"{hivex_env['title']}") as env_tabs:
263
  # ADD CHECK BOX GROUP TO SELECT DIFFICULTY / PATTERN IDs
264
-
 
 
 
 
265
  # TASK TABS
266
  for task_id in range(0, hivex_env["task_count"]):
267
- task_title = convert_to_title_case(get_task(hivex_env["hivex_env"], task_id, path_))
 
 
268
  with gr.TabItem(f"Task {task_id}: {task_title}"):
269
  with gr.Row():
270
  data = get_data(hivex_env["hivex_env"], task_id, path_)
271
  row_count = len(data) # Number of rows in the data
272
-
273
  gr_dataframe = gr.components.Dataframe(
274
  value=data,
275
  headers=["User", "Model"],
276
  datatype=["markdown", "markdown"],
277
- row_count=(row_count, 'fixed') # Set to the exact number of rows in the data
 
 
 
278
  )
279
 
280
 
 
47
  """
48
 
49
  # Pattern: 0 Default, 1 Grid, 2 Chain, 3 Circle, 4 Square, 5 Cross, 6 Two_Rows, 7 Field, 8 Random
50
+ pattern_map = {
51
+ 0: "0: Default",
52
+ 1: "1: Grid",
53
+ 2: "2: Chain",
54
+ 3: "3: Circle",
55
+ 4: "4: Square",
56
+ 5: "5: Cross",
57
+ 6: "6: Two Rows",
58
+ 7: "7: Field",
59
+ 8: "8: Random",
60
+ }
61
 
62
  hivex_envs = [
63
  {
 
96
  def download_leaderboard_dataset():
97
  path = snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")
98
  return path
99
+
100
 
101
  def get_total_models():
102
  total_models = 0
 
104
  model_ids = get_model_ids(hivex_env["hivex_env"])
105
  total_models += len(model_ids)
106
  return total_models
107
+
108
 
109
  def get_model_ids(hivex_env):
110
  api = HfApi()
 
140
  row["Task"] = results["task"]["name"]
141
  if "pattern-id" in results["task"] or "difficulty-id" in results["task"]:
142
  key = "Pattern" if "pattern-id" in results["task"] else "Difficulty"
143
+ row[key] = (
144
+ pattern_map[results["task"]["pattern-id"]]
145
+ if "pattern-id" in results["task"]
146
+ else results["task"]["difficulty-id"]
147
+ )
148
 
149
  results_metrics = results["metrics"]
150
 
 
160
 
161
  # ranked_dataframe = rank_dataframe(pd.DataFrame.from_records(data))
162
  ranked_dataframe = pd.DataFrame.from_records(data)
163
+
164
  new_history = ranked_dataframe
165
  file_path = path + "/" + hivex_env + ".csv"
166
  new_history.to_csv(file_path, index=False)
 
201
  filtered_data = filtered_data.drop(columns=["Task"])
202
 
203
  # Drop columns that have no data (all values are NaN)
204
+ filtered_data = filtered_data.dropna(axis=1, how="all")
205
 
206
  # Drop columns where all values are 0.0
207
  filtered_data = filtered_data.loc[:, (filtered_data != 0.0).any(axis=0)]
 
215
 
216
  return filtered_data
217
 
218
+
219
  def get_task(rl_env, task_id, path) -> str:
220
  """
221
  Get the task name from the leaderboard dataset based on the rl_env and task_id.
 
238
  def convert_to_title_case(text: str) -> str:
239
  # Replace underscores with spaces
240
  text = text.replace("_", " ")
241
+
242
  # Convert each word to title case (capitalize the first letter)
243
  title_case_text = text.title()
244
+
245
  return title_case_text
246
 
247
 
248
+ def get_difficulty_pattern_ids_and_key(rl_env, path):
249
+ csv_path = path + "/" + rl_env + ".csv"
250
+ data = pd.read_csv(csv_path)
251
+
252
+ key = "Pattern" if "Pattern" in data.columns else "Difficulty"
253
+ # Get the unique values in the "Difficulty" column
254
+ difficulty_pattern_ids = data[key].unique()
255
+
256
+ key, difficulty_pattern_ids
257
+
258
+
259
  run_update_dataset()
260
 
261
  block = gr.Blocks(css=custom_css) # Attach the custom CSS here
 
273
  <h1 style="font-weight: bold;">HIVEX Leaderboard</h1>
274
  </div>
275
  """
276
+ )
277
  with gr.Row(elem_id="header-row"):
278
+ gr.HTML(
279
+ f"<p style='text-align: center;'>Total models: {get_total_models()}</p>"
280
+ )
281
  with gr.Row(elem_id="header-row"):
282
+ gr.HTML(
283
+ f"<p style='text-align: center;'>Get started πŸš€ on our <a href='https://github.com/hivex-research/hivex'>GitHub repository</a>!</p>"
284
+ )
285
 
286
  path_ = download_leaderboard_dataset()
287
  # gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")
288
  # ENVIRONMENT TABS
289
+ with gr.Tabs() as tabs: # elem_classes="tab-buttons"
290
  for env_index in range(0, len(hivex_envs)):
291
  hivex_env = hivex_envs[env_index]
292
  with gr.Tab(f"{hivex_env['title']}") as env_tabs:
293
  # ADD CHECK BOX GROUP TO SELECT DIFFICULTY / PATTERN IDs
294
+ dp_key, difficulty_pattern_ids = get_difficulty_pattern_ids_and_key(
295
+ hivex_env["hivex_env"], path_
296
+ )
297
+ gr.CheckboxGroup(difficulty_pattern_ids, label=dp_key)
298
+
299
  # TASK TABS
300
  for task_id in range(0, hivex_env["task_count"]):
301
+ task_title = convert_to_title_case(
302
+ get_task(hivex_env["hivex_env"], task_id, path_)
303
+ )
304
  with gr.TabItem(f"Task {task_id}: {task_title}"):
305
  with gr.Row():
306
  data = get_data(hivex_env["hivex_env"], task_id, path_)
307
  row_count = len(data) # Number of rows in the data
308
+
309
  gr_dataframe = gr.components.Dataframe(
310
  value=data,
311
  headers=["User", "Model"],
312
  datatype=["markdown", "markdown"],
313
+ row_count=(
314
+ row_count,
315
+ "fixed",
316
+ ), # Set to the exact number of rows in the data
317
  )
318
 
319