# app.py (HIVEX leaderboard Space) - last updated by philippds: "Update app.py", commit 693b619 (verified), 10.5 kB.
import os
import json
import requests
import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.repocard import metadata_load
from apscheduler.schedulers.background import BackgroundScheduler
from tqdm.contrib.concurrent import thread_map
from utils import make_clickable_model, make_clickable_user
from typing import List # Add this import statement
# Hub dataset that stores one leaderboard CSV per HIVEX environment.
DATASET_REPO_URL = (
    "https://huggingface.co/datasets/hivex-research/hivex-leaderboard-data"
)
DATASET_REPO_ID = "hivex-research/hivex-leaderboard-data"
# Optional access token; None when the HF_TOKEN environment variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Shared Hub client, used to restart the Space and to upload the dataset.
# (The gr.Blocks app itself is created once, with the custom CSS, right before
# the UI layout is built.)
api = HfApi(token=HF_TOKEN)
# Disabled rule that would enlarge the environment tab buttons; it pairs with
# the commented-out elem_classes="tab-buttons" on gr.Tabs. To enable it, move
# it into custom_css below:
# .tab-buttons button {
# font-size: 20px;
# }
# Page-level CSS handed to gr.Blocks(css=...): widen the page to 95% of the
# viewport and keep dataframe cells on one line (overflow is clipped with an
# ellipsis).
custom_css = """
/* Full width space */
.gradio-container {
max-width: 95%!important;
}
.gr-dataframe table {
width: auto;
}
.gr-dataframe td, .gr-dataframe th {
white-space: nowrap;
text-overflow: ellipsis;
overflow: hidden;
width: 1%;
}
"""
# Display labels for the integer "pattern-id" found in model-card task
# metadata. Each label keeps its numeric prefix, e.g. 6 -> "6: Two Rows".
_PATTERN_NAMES = (
    "Default",
    "Grid",
    "Chain",
    "Circle",
    "Square",
    "Cross",
    "Two Rows",
    "Field",
    "Random",
)
pattern_map = {
    pattern_id: f"{pattern_id}: {label}"
    for pattern_id, label in enumerate(_PATTERN_NAMES)
}
# Environments shown on the leaderboard, one tab each. For every entry:
#   title      - tab label
#   hivex_env  - used as the list_models filter and as the CSV file stem
#   task_count - number of task sub-tabs (task ids 0 .. task_count - 1)
hivex_envs = [
    {"title": title, "hivex_env": env_id, "task_count": n_tasks}
    for title, env_id, n_tasks in (
        ("Wind Farm Control", "hivex-wind-farm-control", 2),
        ("Wildfire Resource Management", "hivex-wildfire-resource-management", 3),
        ("Drone-Based Reforestation", "hivex-drone-based-reforestation", 7),
        ("Ocean Plastic Collection", "hivex-ocean-plastic-collection", 4),
        ("Aerial Wildfire Suppression", "hivex-aerial-wildfire-suppression", 9),
    )
]
def restart():
    """Log a marker line, then ask the Hub to restart the leaderboard Space."""
    space_repo_id = "hivex-research/hivex-leaderboard"
    print("RESTART")
    api.restart_space(repo_id=space_repo_id)
def download_leaderboard_dataset():
    """
    Download a snapshot of the leaderboard dataset repository.

    :return: local directory path of the snapshot, as returned by snapshot_download
    """
    return snapshot_download(repo_type="dataset", repo_id=DATASET_REPO_ID)
def get_total_models():
    """Return how many Hub models are listed across all HIVEX environments."""
    return sum(len(get_model_ids(env["hivex_env"])) for env in hivex_envs)
def get_model_ids(hivex_env):
    """
    List the ids of Hub models matching one HIVEX environment.

    Reuses the module-level ``api`` client rather than constructing a new
    ``HfApi`` (which also shadowed that global) on every call, and reads the
    current ``ModelInfo.id`` attribute instead of the legacy ``modelId``.

    :param hivex_env: environment id passed as ``list_models(filter=...)``
    :return: list of model ids ("owner/name")
    """
    return [model.id for model in api.list_models(filter=hivex_env)]
def get_metadata(model_id):
    """
    Download a model's README.md and return its parsed card metadata.

    :param model_id: Hub model id
    :return: the result of ``metadata_load`` on the downloaded README
        (presumably None when the card has no YAML header - verify), or None
        when the download fails with an HTTP error such as a missing README.md
    """
    try:
        card_path = hf_hub_download(model_id, filename="README.md", etag_timeout=180)
    except requests.exceptions.HTTPError:
        # e.g. 404: the repository has no README.md
        return None
    return metadata_load(card_path)
def _parse_metric_value(value):
    """Return the mean of a metric value such as "12.3 +/- 4.5" as a float.

    Numbers are converted directly; strings keep only the part before "+/-".
    """
    if isinstance(value, (int, float)):
        return float(value)
    return float(str(value).split("+/-")[0].strip())


def _metadata_to_row(model_id, meta):
    """
    Flatten one model card's metadata into a leaderboard row.

    Uses the first result of the first ``model-index`` entry: task id and
    name, an optional pattern id (mapped through ``pattern_map``) or difficulty
    id, and every metric (parsed with ``_parse_metric_value``).

    :return: the row dict, or None when a required field is missing or a metric
        cannot be parsed, so one malformed card cannot abort the whole update
    """
    try:
        results = meta["model-index"][0]["results"][0]
        task = results["task"]
        row = {
            "User": model_id.split("/")[0],
            "Model": model_id,
            "Task-ID": task["task-id"],
            "Task": task["name"],
        }
        # "Pattern" takes precedence when a card carries both ids.
        if "pattern-id" in task:
            pattern_id = task["pattern-id"]
            # Unknown pattern ids fall back to the raw id instead of raising.
            row["Pattern"] = pattern_map.get(pattern_id, pattern_id)
        elif "difficulty-id" in task:
            row["Difficulty"] = task["difficulty-id"]
        for metric in results["metrics"]:
            row[metric["name"]] = _parse_metric_value(metric["value"])
    except (KeyError, IndexError, TypeError, ValueError, AttributeError) as err:
        print(f"Skipping {model_id}: unexpected model-index metadata ({err!r})")
        return None
    return row


def update_leaderboard_dataset_parallel(hivex_env, path):
    """
    Rebuild the leaderboard CSV for one environment from its models' cards.

    Model cards are fetched with ``thread_map``. Models without readable
    metadata, or whose metadata is malformed, are left out.

    :param hivex_env: environment id used to list models and to name the CSV
    :param path: directory the ``<hivex_env>.csv`` file is written to
    :return: the DataFrame that was written
    """
    model_ids = get_model_ids(hivex_env)

    def process_model(model_id):
        meta = get_metadata(model_id)
        if meta is None:
            return None
        return _metadata_to_row(model_id, meta)

    rows = thread_map(process_model, model_ids, desc="Processing models")
    data = [row for row in rows if row is not None]
    if data:
        leaderboard = pd.DataFrame.from_records(data)
    else:
        # from_records([]) has no columns, which leaves a CSV that readers
        # cannot filter by "Task-ID"; keep the core columns so the file is a
        # valid, empty table.
        leaderboard = pd.DataFrame(columns=["User", "Model", "Task-ID", "Task"])
    leaderboard.to_csv(os.path.join(path, f"{hivex_env}.csv"), index=False)
    return leaderboard
def run_update_dataset():
    """
    Regenerate every environment's leaderboard CSV and upload them to the Hub.

    The CSVs are written into the downloaded dataset snapshot directory, and
    that whole folder is then uploaded back to ``DATASET_REPO_ID``.
    """
    # NOTE(review): path_ comes from snapshot_download, so the CSVs are
    # rewritten inside the Hub cache snapshot; a dedicated local_dir may be
    # safer - verify against the huggingface_hub version in use.
    path_ = download_leaderboard_dataset()
    for env in hivex_envs:
        update_leaderboard_dataset_parallel(env["hivex_env"], path_)
    api.upload_folder(
        folder_path=path_,
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Update dataset",
    )
def get_data(rl_env, task_id, path) -> pd.DataFrame:
    """
    Load one environment's leaderboard and return the rows of a single task.

    Reads ``<path>/<rl_env>.csv`` and keeps the rows whose "Task-ID" equals
    ``task_id``. It then drops the "Task-ID" and "Task" columns (both appear
    in the task's tab title instead) and any column whose values for this task
    are all NaN or all 0.0. Finally it turns the "User" and "Model" cells into
    clickable links.

    :param rl_env: environment id, also the CSV file stem
    :param task_id: task id to select
    :param path: directory containing the leaderboard CSV files
    :return: filtered data as a pandas DataFrame without the Task-ID/Task columns
    """
    data = pd.read_csv(os.path.join(path, f"{rl_env}.csv"))
    filtered_data = data[data["Task-ID"] == task_id].drop(columns=["Task-ID", "Task"])
    # Drop columns that have no data (all values are NaN).
    filtered_data = filtered_data.dropna(axis=1, how="all")
    # Drop columns where all values are 0.0; copy so the link columns below
    # are written to an independent frame, not a view of `data`.
    filtered_data = filtered_data.loc[:, (filtered_data != 0.0).any(axis=0)].copy()
    for column, to_link in (("User", make_clickable_user), ("Model", make_clickable_model)):
        # Both columns are dropped above when the task has no rows, so guard.
        if column in filtered_data.columns:
            filtered_data[column] = filtered_data[column].map(to_link)
    return filtered_data
def get_task(rl_env, task_id, path) -> str:
    """
    Look up a task's name in an environment's leaderboard CSV.

    :param rl_env: environment id, also the CSV file stem
    :param task_id: value to match in the "Task-ID" column
    :param path: directory containing the leaderboard CSV files
    :return: the "Task" value of the first matching row, or "Task not found"
    """
    leaderboard = pd.read_csv(f"{path}/{rl_env}.csv")
    names = leaderboard.loc[leaderboard["Task-ID"] == task_id, "Task"]
    if names.empty:
        return "Task not found"
    return names.iloc[0]
def convert_to_title_case(text: str) -> str:
    """
    Turn a snake_case name into a space-separated, title-cased label.

    Example: "main_task" -> "Main Task". Capitalisation follows ``str.title``.
    """
    pieces = text.split("_")
    return " ".join(pieces).title()
def get_difficulty_pattern_ids_and_key(rl_env, path):
    """
    Find which variant column an environment's leaderboard uses and list its values.

    "Pattern" is preferred over "Difficulty" when both columns exist. Missing
    values (rows whose model card had no pattern/difficulty id) are skipped so
    they do not appear as a "nan" choice.

    :param rl_env: environment id, also the CSV file stem
    :param path: directory containing the leaderboard CSV files
    :return: ``(key, ids)`` where ``key`` is "Pattern", "Difficulty" or None,
        and ``ids`` are the distinct non-missing values (``[]`` when key is None)
    """
    data = pd.read_csv(os.path.join(path, f"{rl_env}.csv"))
    for key in ("Pattern", "Difficulty"):
        if key in data.columns:
            return key, data[key].dropna().unique()
    # Neither 'Pattern' nor 'Difficulty' column exists for this environment.
    return None, []
# Rebuild and upload the leaderboard dataset once at startup, before the UI
# reads it.
run_update_dataset()

block = gr.Blocks(css=custom_css)  # Attach the custom CSS here
with block:
    with gr.Row(elem_id="header-row"):
        # TITLE IMAGE
        gr.HTML(
            """
<div style="width: 50%; margin: 0 auto; text-align: center;">
<img
src="https://huggingface.co/spaces/hivex-research/hivex-leaderboard/resolve/main/hivex_logo.png"
alt="hivex logo"
style="width: 100px; display: inline-block; border-radius:20px;"
/>
<h1 style="font-weight: bold;">HIVEX Leaderboard</h1>
</div>
"""
        )
    with gr.Row(elem_id="header-row"):
        gr.HTML(
            f"<p style='text-align: center;'>Total models: {get_total_models()}</p>"
        )
    with gr.Row(elem_id="header-row"):
        # The rocket emoji was stored mis-encoded ("πŸš€"); restored to U+1F680.
        gr.HTML(
            "<p style='text-align: center;'>Get started 🚀 on our <a href='https://github.com/hivex-research/hivex'>GitHub repository</a>!</p>"
        )
    # Snapshot directory holding the per-environment CSVs read below.
    path_ = download_leaderboard_dataset()
    # ENVIRONMENT TABS
    with gr.Tabs():  # elem_classes="tab-buttons"
        for hivex_env in hivex_envs:
            with gr.Tab(hivex_env["title"]):
                dp_key, difficulty_pattern_ids = get_difficulty_pattern_ids_and_key(
                    hivex_env["hivex_env"], path_
                )
                if dp_key is not None:
                    # NOTE(review): no event handler is attached, so selecting
                    # values here does not filter the task tables below.
                    gr.CheckboxGroup(
                        [str(dp_id) for dp_id in difficulty_pattern_ids], label=dp_key
                    )
                # TASK TABS: one sub-tab per task id 0 .. task_count - 1
                for task_id in range(hivex_env["task_count"]):
                    task_title = convert_to_title_case(
                        get_task(hivex_env["hivex_env"], task_id, path_)
                    )
                    with gr.TabItem(f"Task {task_id}: {task_title}"):
                        with gr.Row():
                            data = get_data(hivex_env["hivex_env"], task_id, path_)
                            gr.components.Dataframe(
                                value=data,
                                headers=["User", "Model"],
                                datatype=["markdown", "markdown"],
                                # Exactly as many rows as this task has entries.
                                row_count=(len(data), "fixed"),
                            )

# Restart the Space every 86400 s (24 h); on restart this script runs again,
# including the run_update_dataset() refresh above.
scheduler = BackgroundScheduler()
scheduler.add_job(restart, "interval", seconds=86400)
scheduler.start()
block.launch()