from functools import partial
import json
import re
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List

import gradio as gr
import tenacity
from datatrove.io import get_datafolder, _get_true_fs
from datatrove.utils.stats import MetricStatsDict

from src.logic.graph_settings import Grouping


def find_folders(base_folder: str, path: str) -> List[str]:
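    """List the immediate subdirectories of `path` within `base_folder`, or an empty list if `path` does not exist."""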
    base_folder_df = get_datafolder(base_folder)
    if not base_folder_df.exists(path):
        return []

    # Imported lazily so the HF filesystem is only required when actually used.
    from huggingface_hub import HfFileSystem

    extra_options = {}
    if isinstance(_get_true_fs(base_folder_df.fs), HfFileSystem):
        # Skip fetching detailed file info from the Hub; we only need names and types.
        extra_options["expand_info"] = False

    return [
        folder
        for folder, info in base_folder_df.find(path, maxdepth=1, withdirs=True, detail=True, **extra_options).items()
        if info["type"] == "directory" and folder.rstrip("/") != path.rstrip("/")
    ]


def fetch_datasets(base_folder: str, progress=gr.Progress()):
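    """Return the sorted list of dataset folders found at the root of `base_folder`."""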
    datasets = sorted(progress.tqdm(find_folders(base_folder, "")))
    if not datasets:
        raise ValueError("No datasets found")
    return datasets, None


def fetch_groups(base_folder: str, datasets: List[str], old_groups: str, type: str = "intersection"):
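    """Update the grouping dropdown with the groups available across `datasets` (their intersection or union)."""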
    if not datasets:
        return gr.update(choices=[], value=None)

    with ThreadPoolExecutor() as executor:
        groups = list(executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, run)], datasets))
    if len(groups) == 0:
        return gr.update(choices=[], value=None)

    if type == "intersection":
        new_choices = set.intersection(*(set(g) for g in groups))
    else:
        new_choices = set.union(*(set(g) for g in groups))

    # Keep the previous selection if it is still available; otherwise auto-select a lone choice.
    value = None
    if old_groups and old_groups in new_choices:
        value = old_groups
    if not value and len(new_choices) == 1:
        value = next(iter(new_choices))

    return gr.Dropdown(choices=sorted(new_choices), value=value)


def fetch_metrics(base_folder: str, datasets: List[str], group: str, old_metrics: str, type: str = "intersection"):
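    """Update the metric dropdown with the metrics available for `group` across `datasets`."""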
    if not group:
        return gr.update(choices=[], value=None)

    with ThreadPoolExecutor() as executor:
        metrics = list(
            executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, f"{run}/{group}")], datasets)
        )
    if len(metrics) == 0:
        return gr.update(choices=[], value=None)

    if type == "intersection":
        new_choices = set.intersection(*(set(s) for s in metrics))
    else:
        new_choices = set.union(*(set(s) for s in metrics))

    # Keep the previous selection if it is still available; otherwise auto-select a lone choice.
    value = None
    if old_metrics and old_metrics in new_choices:
        value = old_metrics
    if not value and len(new_choices) == 1:
        value = next(iter(new_choices))

    return gr.Dropdown(choices=sorted(new_choices), value=value)


def reverse_search(base_folder: str, possible_datasets: List[str], grouping: str, metric_name: str) -> str:
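    """Return the datasets (one per line) for which `metric_name` exists under `grouping`."""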
    with ThreadPoolExecutor() as executor:
        found_datasets = list(executor.map(
            lambda dataset: dataset if metric_exists(base_folder, dataset, metric_name, grouping) else None,
            possible_datasets))
    found_datasets = [dataset for dataset in found_datasets if dataset is not None]
    return "\n".join(found_datasets)


def reverse_search_add(datasets: List[str], reverse_search_results: str) -> List[str]:
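    """Merge the reverse-search results into the currently selected datasets, dropping duplicates."""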
    datasets = datasets or []
    # Filter out empty lines so an empty search result does not add an empty dataset name.
    new_datasets = [line for line in reverse_search_results.strip().split("\n") if line]
    return list(set(datasets + new_datasets))


def metric_exists(base_folder: str, path: str, metric_name: str, group_by: str) -> bool:
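    """Check whether a metric.json file exists for `path`, `group_by` and `metric_name`."""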
    base_folder = get_datafolder(base_folder)
    return base_folder.exists(f"{path}/{group_by}/{metric_name}/metric.json")


@tenacity.retry(stop=tenacity.stop_after_attempt(5))
def load_metrics(base_folder: str, path: str, metric_name: str, group_by: str) -> MetricStatsDict:
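    """Load and parse a metric.json file, retrying up to 5 times on failure (e.g. transient remote-filesystem errors)."""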
    base_folder = get_datafolder(base_folder)
    with base_folder.open(f"{path}/{group_by}/{metric_name}/metric.json") as f:
        json_metric = json.load(f)
        return MetricStatsDict.from_dict(json_metric)


def load_data(dataset_path: str, base_folder: str, grouping: str, metric_name: str) -> MetricStatsDict:
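    """Thin wrapper around `load_metrics` whose argument order suits `functools.partial` mapping over dataset paths."""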
    return load_metrics(base_folder, dataset_path, metric_name, grouping)


def fetch_graph_data(
    base_folder: str,
    datasets: List[str],
    metric_name: str,
    grouping: Grouping,
    progress=gr.Progress(),
):
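    """Load the selected metric for every dataset in parallel, reporting progress in the UI."""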
    if not datasets or not metric_name or not grouping:
        return None, None

    with ThreadPoolExecutor() as pool:
        data = list(
            progress.tqdm(
                pool.map(
                    partial(load_data, base_folder=base_folder, metric_name=metric_name, grouping=grouping),
                    datasets,
                ),
                total=len(datasets),
                desc="Loading data...",
            )
        )

    data = {path: result for path, result in zip(datasets, data)}
    return data, None


def update_datasets_with_regex(regex: str, selected_runs: List[str], all_runs: List[str]):
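    """Add every run matching `regex` to the current selection; with no regex, clear the selection."""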
    if not regex:
        return []
    new_datasets = {run for run in all_runs if re.search(regex, run)}
    if not new_datasets:
        return selected_runs
    return sorted(new_datasets.union(selected_runs or []))