"""HIVEX leaderboard Space.

On startup:
1. Rebuild one CSV per HIVEX environment from the model cards on the Hub.
2. Push the CSVs to the leaderboard dataset repo.
3. Serve them as a tabbed Gradio leaderboard.

A background job restarts the Space once a day.
"""

import json
import os
from typing import List, Optional

import gradio as gr
import pandas as pd
import requests
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
from huggingface_hub.repocard import metadata_load
from tqdm.contrib.concurrent import thread_map

from utils import make_clickable_model, make_clickable_user

DATASET_REPO_URL = (
    "https://huggingface.co/datasets/hivex-research/hivex-leaderboard-data"
)
DATASET_REPO_ID = "hivex-research/hivex-leaderboard-data"
SPACE_REPO_ID = "hivex-research/hivex-leaderboard"
HF_TOKEN = os.environ.get("HF_TOKEN")

# Interval between scheduled restarts of this Space (one day).
RESTART_INTERVAL_SECONDS = 86400

api = HfApi(token=HF_TOKEN)

custom_css = """
/* Full width space */
.gradio-container {
    max-width: 95%!important;
}
.gr-dataframe table {
    width: auto;
}
.gr-dataframe td, .gr-dataframe th {
    white-space: nowrap;
    text-overflow: ellipsis;
    overflow: hidden;
    width: 1%;
}
"""

# Pattern: 0 Default, 1 Grid, 2 Chain, 3 Circle, 4 Square, 5 Cross, 6 Two_Rows, 7 Field, 8 Random
pattern_map = {
    0: "0: Default",
    1: "1: Grid",
    2: "2: Chain",
    3: "3: Circle",
    4: "4: Square",
    5: "5: Cross",
    6: "6: Two Rows",
    7: "7: Field",
    8: "8: Random",
}

hivex_envs = [
    {
        "title": "Wind Farm Control",
        "hivex_env": "hivex-wind-farm-control",
        "task_count": 2,
    },
    {
        "title": "Wildfire Resource Management",
        "hivex_env": "hivex-wildfire-resource-management",
        "task_count": 3,
    },
    {
        "title": "Drone-Based Reforestation",
        "hivex_env": "hivex-drone-based-reforestation",
        "task_count": 7,
    },
    {
        "title": "Ocean Plastic Collection",
        "hivex_env": "hivex-ocean-plastic-collection",
        "task_count": 4,
    },
    {
        "title": "Aerial Wildfire Suppression",
        "hivex_env": "hivex-aerial-wildfire-suppression",
        "task_count": 9,
    },
]

# Columns written to an environment CSV even when no model has usable metadata,
# so that pd.read_csv on it never raises EmptyDataError.
BASE_COLUMNS = ["User", "Model", "Task-ID", "Task"]
# Row-identity columns that get_data never drops.
IDENTITY_COLUMNS = ("User", "Model")
# Label (non-metric) columns used by the checkbox filter.
LABEL_COLUMNS = ("Pattern", "Difficulty")


def restart():
    """Restart the leaderboard Space (run periodically by the scheduler)."""
    print("RESTART")
    api.restart_space(repo_id=SPACE_REPO_ID)


def download_leaderboard_dataset():
    """Download a snapshot of the leaderboard dataset repo.

    :return: local filesystem path of the snapshot
    """
    return snapshot_download(repo_id=DATASET_REPO_ID, repo_type="dataset")


def _csv_path(rl_env, path):
    """Return the path of the CSV for environment ``rl_env`` inside ``path``."""
    return os.path.join(path, f"{rl_env}.csv")


def _read_env_csv(rl_env, path) -> pd.DataFrame:
    """Load the leaderboard CSV for environment ``rl_env`` from ``path``."""
    return pd.read_csv(_csv_path(rl_env, path))


def get_model_ids(hivex_env) -> List[str]:
    """Return the ids of all Hub models matched by the ``hivex_env`` filter."""
    return [model.modelId for model in api.list_models(filter=hivex_env)]


def get_total_models():
    """Return the number of Hub models summed over all HIVEX environments."""
    return sum(len(get_model_ids(env["hivex_env"])) for env in hivex_envs)


def get_metadata(model_id):
    """Download a model's README.md and return its YAML metadata.

    :return: the parsed metadata, or None if the README could not be fetched
        (e.g. a 404 because the repo has no README.md, or a network error)
    """
    try:
        readme_path = hf_hub_download(model_id, filename="README.md", etag_timeout=180)
    except requests.exceptions.RequestException as err:
        # Best effort: one unreachable model card must not abort the whole update.
        print(f"Skipping {model_id}: could not download README.md ({err})")
        return None
    return metadata_load(readme_path)


def _parse_metric_value(value) -> float:
    """Parse a metric value such as ``"12.3 +/- 4.5"`` (or a bare number) into its mean."""
    return float(str(value).split("+/-")[0].strip())


def _build_leaderboard_row(model_id) -> Optional[dict]:
    """Build one leaderboard row from a model card's ``model-index`` metadata.

    :return: dict with User, Model, Task-ID, Task, optional Pattern/Difficulty
        and one entry per metric; None when the card is missing or malformed
    """
    meta = get_metadata(model_id)
    if meta is None:
        return None
    try:
        results = meta["model-index"][0]["results"][0]
        task = results["task"]
        row = {
            "User": model_id.split("/")[0],
            "Model": model_id,
            "Task-ID": task["task-id"],
            "Task": task["name"],
        }
        if "pattern-id" in task:
            pattern_id = task["pattern-id"]
            # Fall back to the raw id rather than failing on an unknown pattern.
            row["Pattern"] = pattern_map.get(pattern_id, str(pattern_id))
        elif "difficulty-id" in task:
            row["Difficulty"] = task["difficulty-id"]
        for metric in results["metrics"]:
            row[metric["name"]] = _parse_metric_value(metric["value"])
    except (KeyError, IndexError, TypeError, ValueError) as err:
        # A single malformed card would otherwise crash thread_map and the app start.
        print(f"Skipping {model_id}: malformed model-index metadata ({err!r})")
        return None
    return row


def update_leaderboard_dataset_parallel(hivex_env, path):
    """Rebuild the CSV for ``hivex_env`` under ``path`` from the Hub model cards.

    Model cards are fetched concurrently. Models without usable metadata are skipped.

    :return: the leaderboard rows that were written, as a DataFrame
    """
    model_ids = get_model_ids(hivex_env)
    rows = thread_map(_build_leaderboard_row, model_ids, desc="Processing models")
    leaderboard = pd.DataFrame.from_records([row for row in rows if row is not None])
    if leaderboard.empty:
        # Keep a header so later readers get an empty table instead of EmptyDataError.
        leaderboard = pd.DataFrame(columns=BASE_COLUMNS)
    leaderboard.to_csv(_csv_path(hivex_env, path), index=False)
    return leaderboard


def run_update_dataset():
    """Refresh every environment's CSV and upload the folder to the dataset repo.

    :return: local path of the updated dataset snapshot
    """
    path = download_leaderboard_dataset()
    for env in hivex_envs:
        update_leaderboard_dataset_parallel(env["hivex_env"], path)
    api.upload_folder(
        folder_path=path,
        repo_id=DATASET_REPO_ID,
        repo_type="dataset",
        commit_message="Update dataset",
    )
    return path


def get_data(rl_env, task_id, path) -> pd.DataFrame:
    """
    Get data from rl_env, filter by the given task_id, and drop the Task-ID and Task columns.
    Also drops columns with no information for this task: columns whose values are all NaN,
    and metric columns whose values are all 0.0. User and Model are always kept. Pattern and
    Difficulty are labels, so they are dropped only when entirely NaN.

    :return: filtered data as a pandas DataFrame, with User and Model as clickable links
    """
    data = _read_env_csv(rl_env, path)
    filtered_data = (
        data[data["Task-ID"] == task_id].drop(columns=["Task-ID", "Task"]).copy()
    )

    def _droppable(column) -> bool:
        if column in IDENTITY_COLUMNS:
            return False
        values = filtered_data[column]
        if values.isna().all():
            return True
        return column not in LABEL_COLUMNS and bool((values == 0.0).all())

    filtered_data = filtered_data.drop(
        columns=[column for column in filtered_data.columns if _droppable(column)]
    )

    # Convert User and Model columns to clickable links
    filtered_data["User"] = filtered_data["User"].map(make_clickable_user)
    filtered_data["Model"] = filtered_data["Model"].map(make_clickable_model)
    return filtered_data


def get_task(rl_env, task_id, path) -> str:
    """
    Get the task name from the leaderboard dataset based on the rl_env and task_id.

    :return: The task name as a string, or "Task not found" if no row has that task_id
    """
    data = _read_env_csv(rl_env, path)
    task_names = data.loc[data["Task-ID"] == task_id, "Task"]
    if task_names.empty:
        return "Task not found"
    return task_names.iloc[0]


def convert_to_title_case(text: str) -> str:
    """Replace underscores with spaces and capitalize each word."""
    return text.replace("_", " ").title()


def get_difficulty_pattern_ids_and_key(rl_env, path):
    """Return which label column an environment uses and its distinct values.

    :return: ``(key, ids)``. ``key`` is "Pattern", "Difficulty" or None (when neither
        column exists). ``ids`` is a list of the distinct non-missing values.
    """
    data = _read_env_csv(rl_env, path)
    for key in ("Pattern", "Difficulty"):
        if key in data.columns:
            # A list (not a numpy array) so callers can test it for emptiness safely.
            return key, data[key].dropna().unique().tolist()
    return None, []


def filter_checkbox_data(rl_env, task_id, selected_values, path):
    """
    Filters the data based on the selected difficulty/pattern values.

    :param selected_values: checkbox labels (strings); an empty selection means "show all"
    """
    data = get_data(rl_env, task_id, path)
    if not selected_values:
        return data
    filter_column = "Pattern" if "Pattern" in data.columns else "Difficulty"
    if filter_column not in data.columns:
        # No row of this task carries a pattern/difficulty label, so nothing matches.
        return data.iloc[0:0]
    # Checkbox labels are str(...) of the ids, while Difficulty is read back from the CSV
    # as a number: compare as strings so that numeric ids can match.
    return data[data[filter_column].astype(str).isin(selected_values)]


def filter_data(rl_env, task_id, selected_values, path):
    """
    Filters the data based on the selected difficulty/pattern values.
    """
    return filter_checkbox_data(rl_env, task_id, selected_values, path)


def update_filtered_data(selected_values, rl_env, task_id, path):
    """Gradio callback: re-filter one task table when the checkbox selection changes."""
    return filter_data(rl_env, task_id, selected_values, path)


path_ = run_update_dataset()

block = gr.Blocks(css=custom_css)  # Attach the custom CSS here

with block:
    # NOTE(review): the HTML markup of the three header snippets below appears to have
    # been stripped from this copy of the file (only their text survived, e.g. the logo's
    # alt text "hivex logo"). Restore the original tags/URLs from version control.
    with gr.Row(elem_id="header-row"):
        # TITLE IMAGE
        gr.HTML(
            """
            hivex logo
            HIVEX Leaderboard
            """
        )
    with gr.Row(elem_id="header-row"):
        gr.HTML(f"Total models: {get_total_models()}")
    with gr.Row(elem_id="header-row"):
        gr.HTML("Get started 🚀 on our GitHub repository!")

    # ENVIRONMENT TABS
    with gr.Tabs():
        for hivex_env in hivex_envs:
            env_name = hivex_env["hivex_env"]
            with gr.Tab(hivex_env["title"]):
                dp_key, difficulty_pattern_ids = get_difficulty_pattern_ids_and_key(
                    env_name, path_
                )
                selected_checkboxes = None
                if dp_key is not None and difficulty_pattern_ids:
                    selected_checkboxes = gr.CheckboxGroup(
                        [str(dp_id) for dp_id in difficulty_pattern_ids], label=dp_key
                    )
                else:
                    gr.HTML(
                        "No difficulty or pattern data available for this environment."
                    )

                # Task tables are shown for every environment, with or without a filter.
                for task_id in range(hivex_env["task_count"]):
                    task_title = convert_to_title_case(
                        get_task(env_name, task_id, path_)
                    )
                    with gr.TabItem(f"Task {task_id}: {task_title}"):
                        # Display initial data
                        data = get_data(env_name, task_id, path_)
                        gr_dataframe = gr.DataFrame(
                            value=data,
                            headers=["User", "Model"],
                            datatype=["markdown", "markdown"],
                            row_count=(len(data), "fixed"),
                        )
                        if selected_checkboxes is not None:
                            # gr.State carries this tab's constants into the callback
                            # without rendering visible input widgets.
                            selected_checkboxes.change(
                                fn=update_filtered_data,
                                inputs=[
                                    selected_checkboxes,
                                    gr.State(env_name),
                                    gr.State(task_id),
                                    gr.State(path_),
                                ],
                                outputs=gr_dataframe,
                            )

scheduler = BackgroundScheduler()
scheduler.add_job(restart, "interval", seconds=RESTART_INTERVAL_SECONDS)
scheduler.start()

block.launch()