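"""Gradio app for the HIVEX leaderboard Space.

Parses result metadata from HIVEX model cards on the Hugging Face Hub, rebuilds
the per-environment CSVs in hivex-research/hivex-leaderboard-data, and renders
one leaderboard tab per environment and task.
"""
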
import os

import gradio as gr
import pandas as pd
import requests
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.repocard import metadata_load
from tqdm.contrib.concurrent import thread_map

from utils import make_clickable_model, make_clickable_user

DATASET_REPO_URL = (
    "https://huggingface.co/datasets/hivex-research/hivex-leaderboard-data"
)
DATASET_REPO_ID = "hivex-research/hivex-leaderboard-data"
HF_TOKEN = os.environ.get("HF_TOKEN")

# The token needs write access: it is used to restart this Space and to push
# the refreshed leaderboard CSVs back to the dataset repo.
api = HfApi(token=HF_TOKEN)

custom_css = """
/* Use the full page width */
.gradio-container {
    max-width: 95% !important;
}

.gr-dataframe table {
    width: auto;
}

.gr-dataframe td, .gr-dataframe th {
    white-space: nowrap;
    text-overflow: ellipsis;
    overflow: hidden;
    width: 1%;
}
"""

# Human-readable labels for the numeric pattern ids found in model metadata.
pattern_map = {
    0: "0: Default",
    1: "1: Grid",
    2: "2: Chain",
    3: "3: Circle",
    4: "4: Square",
    5: "5: Cross",
    6: "6: Two Rows",
    7: "7: Field",
    8: "8: Random",
}

# One entry per HIVEX environment: display title, Hub model tag, and the
# number of tasks shown on the leaderboard.
hivex_envs = [
    {
        "title": "Wind Farm Control",
        "hivex_env": "hivex-wind-farm-control",
        "task_count": 2,
    },
    {
        "title": "Wildfire Resource Management",
        "hivex_env": "hivex-wildfire-resource-management",
        "task_count": 3,
    },
    {
        "title": "Drone-Based Reforestation",
        "hivex_env": "hivex-drone-based-reforestation",
        "task_count": 7,
    },
    {
        "title": "Ocean Plastic Collection",
        "hivex_env": "hivex-ocean-plastic-collection",
        "task_count": 4,
    },
    {
        "title": "Aerial Wildfire Suppression",
        "hivex_env": "hivex-aerial-wildfire-suppression",
        "task_count": 9,
    },
]


def restart():
    """Restart the Space so the leaderboard is rebuilt from fresh Hub data."""
    print("RESTART")
    api.restart_space(repo_id="hivex-research/hivex-leaderboard")


def download_leaderboard_dataset():
    """Download a local snapshot of the leaderboard dataset and return its path."""
    path = snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")
    return path


def get_total_models():
    """Count all Hub models tagged for any HIVEX environment."""
    total_models = 0
    for hivex_env in hivex_envs:
        model_ids = get_model_ids(hivex_env["hivex_env"])
        total_models += len(model_ids)
    return total_models


def get_model_ids(hivex_env):
    """Return the ids of all Hub models tagged with the given environment."""
    models = api.list_models(filter=hivex_env)
    model_ids = [x.modelId for x in models]
    return model_ids


def get_metadata(model_id):
    """Load the metadata block from a model's README.md, or None on HTTP errors."""
    try:
        readme_path = hf_hub_download(model_id, filename="README.md", etag_timeout=180)
        return metadata_load(readme_path)
    except requests.exceptions.HTTPError:
        # 404: the repo has no README.md
        return None
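

# Illustrative sketch (an assumption, not copied from any real model card) of
# the `model-index` metadata that update_leaderboard_dataset_parallel() parses:
#
#   model-index:
#     - results:
#         - task:
#             task-id: 0
#             name: main_task
#             pattern-id: 1          # or difficulty-id, depending on the env
#           metrics:
#             - name: cumulative_reward
#               value: "12.34 +/- 1.2"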


def update_leaderboard_dataset_parallel(hivex_env, path):
    """Rebuild the CSV for one environment from all tagged models on the Hub."""
    model_ids = get_model_ids(hivex_env)

    def process_model(model_id):
        meta = get_metadata(model_id)
        if meta is None:
            return None
        user_id = model_id.split("/")[0]
        row = {}
        row["User"] = user_id
        row["Model"] = model_id
        results = meta["model-index"][0]["results"][0]
        row["Task-ID"] = results["task"]["task-id"]
        row["Task"] = results["task"]["name"]
        if "pattern-id" in results["task"] or "difficulty-id" in results["task"]:
            key = "Pattern" if "pattern-id" in results["task"] else "Difficulty"
            row[key] = (
                pattern_map[results["task"]["pattern-id"]]
                if "pattern-id" in results["task"]
                else results["task"]["difficulty-id"]
            )

        # Metric values are stored as "mean +/- std"; keep only the mean.
        for result in results["metrics"]:
            row[result["name"]] = float(result["value"].split("+/-")[0].strip())

        return row

    # Fetch and parse model cards concurrently.
    data = list(thread_map(process_model, model_ids, desc="Processing models"))

    # Drop models whose README could not be loaded.
    data = [row for row in data if row is not None]

    ranked_dataframe = pd.DataFrame.from_records(data)
    file_path = path + "/" + hivex_env + ".csv"
    ranked_dataframe.to_csv(file_path, index=False)

    return ranked_dataframe


def run_update_dataset():
    """Refresh every environment's CSV locally, then push them to the dataset repo."""
    path_ = download_leaderboard_dataset()
    for hivex_env in hivex_envs:
        update_leaderboard_dataset_parallel(hivex_env["hivex_env"], path_)

    api.upload_folder(
        folder_path=path_,
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Update dataset",
    )
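

# make_clickable_user / make_clickable_model (imported from utils) are assumed
# to wrap ids in HTML anchors, roughly:
#   f'<a href="https://huggingface.co/{model_id}">{model_id}</a>'
# which is why the User and Model columns are rendered with the "markdown"
# datatype in the Dataframe components below.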


def get_data(rl_env, task_id, path) -> pd.DataFrame:
    """
    Get data from rl_env, filter by the given task_id, and drop the Task-ID and
    Task columns, any columns with no data (all NaN), and any all-zero columns.
    :return: filtered data as a pandas DataFrame
    """
    csv_path = path + "/" + rl_env + ".csv"
    data = pd.read_csv(csv_path)

    filtered_data = data[data["Task-ID"] == task_id]
    filtered_data = filtered_data.drop(columns=["Task-ID", "Task"])

    # Drop columns that are entirely NaN, then columns where every value is 0.0.
    filtered_data = filtered_data.dropna(axis=1, how="all")
    filtered_data = filtered_data.loc[:, (filtered_data != 0.0).any(axis=0)]

    # Turn user and model ids into clickable links.
    filtered_data["User"] = filtered_data["User"].apply(make_clickable_user)
    filtered_data["Model"] = filtered_data["Model"].apply(make_clickable_model)

    return filtered_data


def get_task(rl_env, task_id, path) -> str:
    """
    Get the task name from the leaderboard dataset for the given rl_env and task_id.
    :return: the task name as a string, or "Task not found"
    """
    csv_path = path + "/" + rl_env + ".csv"
    data = pd.read_csv(csv_path)

    task_row = data[data["Task-ID"] == task_id]
    if not task_row.empty:
        return task_row.iloc[0]["Task"]
    return "Task not found"


def convert_to_title_case(text: str) -> str:
    """Turn a snake_case task name into Title Case, e.g. "main_task" -> "Main Task"."""
    return text.replace("_", " ").title()


def get_difficulty_pattern_ids_and_key(rl_env, path):
    """Return ("Pattern" or "Difficulty", unique ids) for rl_env, or (None, [])."""
    csv_path = path + "/" + rl_env + ".csv"
    data = pd.read_csv(csv_path)

    if "Pattern" in data.columns:
        key = "Pattern"
        difficulty_pattern_ids = data[key].unique()
    elif "Difficulty" in data.columns:
        key = "Difficulty"
        difficulty_pattern_ids = data[key].unique()
    else:
        # This environment has neither pattern nor difficulty variants.
        key = None
        difficulty_pattern_ids = []

    return key, difficulty_pattern_ids


# Build the leaderboard CSVs once at startup, before the UI is constructed.
run_update_dataset()

block = gr.Blocks(css=custom_css)
with block:
    with gr.Row(elem_id="header-row"):
        gr.HTML(
            """
            <div style="width: 50%; margin: 0 auto; text-align: center;">
                <img
                    src="https://huggingface.co/spaces/hivex-research/hivex-leaderboard/resolve/main/hivex_logo.png"
                    alt="hivex logo"
                    style="width: 100px; display: inline-block; border-radius: 20px;"
                />
                <h1 style="font-weight: bold;">HIVEX Leaderboard</h1>
            </div>
            """
        )
    with gr.Row(elem_id="header-row"):
        gr.HTML(f"<p style='text-align: center;'>Total models: {get_total_models()}</p>")
    with gr.Row(elem_id="header-row"):
        gr.HTML(
            "<p style='text-align: center;'>Get started on our <a href='https://github.com/hivex-research/hivex'>GitHub repository</a>!</p>"
        )

    path_ = download_leaderboard_dataset()

    with gr.Tabs():
        for hivex_env in hivex_envs:
            with gr.Tab(hivex_env["title"]):
                dp_key, difficulty_pattern_ids = get_difficulty_pattern_ids_and_key(
                    hivex_env["hivex_env"], path_
                )
                if dp_key is not None:
                    # Display-only for now: no change event is wired to this filter.
                    gr.CheckboxGroup(
                        [str(dp_id) for dp_id in difficulty_pattern_ids], label=dp_key
                    )

                for task_id in range(hivex_env["task_count"]):
                    task_title = convert_to_title_case(
                        get_task(hivex_env["hivex_env"], task_id, path_)
                    )
                    with gr.TabItem(f"Task {task_id}: {task_title}"):
                        with gr.Row():
                            data = get_data(hivex_env["hivex_env"], task_id, path_)
                            gr.Dataframe(
                                value=data,
                                headers=["User", "Model"],
                                datatype=["markdown", "markdown"],
                                row_count=(len(data), "fixed"),
                            )


# Restart the Space daily so run_update_dataset() picks up newly pushed models.
scheduler = BackgroundScheduler()
scheduler.add_job(restart, "interval", seconds=86400)
scheduler.start()

block.launch()