from functools import lru_cache, partial
import os
import json
import re
import tempfile
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict

from datatrove.io import get_datafolder, _get_true_fs
from datatrove.utils.stats import MetricStatsDict
import gradio as gr
import tenacity

from src.logic.graph_settings import Grouping


def find_folders(base_folder: str, path: str) -> List[str]:
    """Return the immediate sub-directories of `path` inside `base_folder`.

    Args:
        base_folder: root location understood by `get_datafolder` (local path,
            s3://..., hf://..., etc.).
        path: relative path inside the datafolder to list.

    Returns:
        A list of directory paths (relative to the datafolder), excluding
        `path` itself. Empty list when `path` does not exist.
    """
    base_folder_df = get_datafolder(base_folder)
    if not base_folder_df.exists(path):
        return []

    # Deferred third-party import: only needed to special-case HF repos.
    from huggingface_hub import HfFileSystem

    extra_options = {}
    # Fetching full file info from the HF hub is slow and we only need
    # names + entry types here.
    if isinstance(_get_true_fs(base_folder_df.fs), HfFileSystem):
        extra_options["expand_info"] = False  # speed up

    # BUGFIX: the original returned a generator expression despite the
    # List[str] annotation and the `return []` fallback above — both return
    # paths now yield a real list.
    return [
        folder
        for folder, info in base_folder_df.find(
            path, maxdepth=1, withdirs=True, detail=True, **extra_options
        ).items()
        if info["type"] == "directory" and folder.rstrip("/") != path.rstrip("/")
    ]


def fetch_datasets(base_folder: str, progress=gr.Progress()):
    """List all top-level dataset folders under `base_folder`.

    Raises:
        ValueError: when no dataset folders are found.

    Returns:
        (sorted dataset list, None) — the None clears a paired gradio output.
    """
    datasets = sorted(progress.tqdm(find_folders(base_folder, "")))
    if len(datasets) == 0:
        raise ValueError("No datasets found")
    return datasets, None


def _choices_update(per_dataset: List[List[str]], type: str, old_value: str):
    """Combine per-dataset folder name listings into a gradio Dropdown update.

    Args:
        per_dataset: one list of folder names per selected dataset.
        type: "intersection" keeps names present in every dataset; anything
            else takes the union.
        old_value: previously selected value — kept if still a valid choice.

    Returns:
        A gr.Dropdown/gr.update with sorted choices; a sole remaining choice
        is auto-selected.
    """
    if len(per_dataset) == 0:
        return gr.update(choices=[], value=None)
    sets = [set(names) for names in per_dataset]
    new_choices = set.intersection(*sets) if type == "intersection" else set.union(*sets)
    value = old_value if old_value and old_value in new_choices else None
    if not value and len(new_choices) == 1:
        value = next(iter(new_choices))
    return gr.Dropdown(choices=sorted(new_choices), value=value)


def fetch_groups(base_folder: str, datasets: List[str], old_groups: str, type: str = "intersection"):
    """List grouping folders available across the selected datasets."""
    if not datasets:
        return gr.update(choices=[], value=None)
    with ThreadPoolExecutor() as executor:
        groups = list(
            executor.map(
                lambda run: [Path(x).name for x in find_folders(base_folder, run)],
                datasets,
            )
        )
    return _choices_update(groups, type, old_groups)


def fetch_metrics(base_folder: str, datasets: List[str], group: str, old_metrics: str, type: str = "intersection"):
    """List metric folders available across the selected datasets for `group`."""
    if not group:
        return gr.update(choices=[], value=None)
    with ThreadPoolExecutor() as executor:
        metrics = list(
            executor.map(
                lambda run: [Path(x).name for x in find_folders(base_folder, f"{run}/{group}")],
                datasets,
            )
        )
    return _choices_update(metrics, type, old_metrics)


def reverse_search(base_folder: str, possible_datasets: List[str], grouping: str, metric_name: str) -> str:
    """Return (newline-joined) the datasets that actually contain the metric."""
    with ThreadPoolExecutor() as executor:
        found_datasets = list(
            executor.map(
                lambda dataset: dataset if metric_exists(base_folder, dataset, metric_name, grouping) else None,
                possible_datasets,
            )
        )
    return "\n".join(dataset for dataset in found_datasets if dataset is not None)


def reverse_search_add(datasets: List[str], reverse_search_results: str) -> List[str]:
    """Merge reverse-search result lines into the current dataset selection.

    Returns a de-duplicated list (order unspecified, as before).
    """
    datasets = datasets or []
    # BUGFIX: "".strip().split("\n") yields [""], which used to inject a
    # bogus empty-string dataset; keep only non-empty lines.
    found = [line for line in reverse_search_results.strip().split("\n") if line]
    return list(set(datasets + found))


def metric_exists(base_folder: str, path: str, metric_name: str, group_by: str) -> bool:
    """Check whether `path` holds a metric.json for this metric/grouping."""
    df = get_datafolder(base_folder)
    return df.exists(f"{path}/{group_by}/{metric_name}/metric.json")


@tenacity.retry(
    stop=tenacity.stop_after_attempt(5),
    # Backoff between attempts — retrying a (possibly remote) filesystem
    # immediately 5 times in a row is unlikely to help.
    wait=tenacity.wait_exponential(multiplier=0.5, max=8),
)
def load_metrics(base_folder: str, path: str, metric_name: str, group_by: str) -> MetricStatsDict:
    """Load and parse a metric.json into a MetricStatsDict (with retries)."""
    df = get_datafolder(base_folder)
    with df.open(f"{path}/{group_by}/{metric_name}/metric.json") as f:
        json_metric = json.load(f)
    return MetricStatsDict.from_dict(json_metric)


def load_data(dataset_path: str, base_folder: str, grouping: str, metric_name: str) -> MetricStatsDict:
    """Argument-order adapter around load_metrics for use with partial/map."""
    return load_metrics(base_folder, dataset_path, metric_name, grouping)


def fetch_graph_data(
    base_folder: str,
    datasets: List[str],
    metric_name: str,
    grouping: Grouping,
    progress=gr.Progress(),
):
    """Load the selected metric for every dataset, in parallel.

    Returns:
        ({dataset_path: MetricStatsDict}, None) — the None clears a paired
        gradio output. (None, None) when the inputs are incomplete.
    """
    if not datasets or not metric_name or not grouping:
        return None, None

    with ThreadPoolExecutor() as pool:
        data = list(
            progress.tqdm(
                pool.map(
                    partial(load_data, base_folder=base_folder, metric_name=metric_name, grouping=grouping),
                    datasets,
                ),
                total=len(datasets),
                desc="Loading data...",
            )
        )

    data = {path: result for path, result in zip(datasets, data)}
    return data, None


def update_datasets_with_regex(regex: str, selected_runs: List[str], all_runs: List[str]):
    """Add every run matching `regex` to the current selection.

    An empty regex clears the selection; a regex matching nothing leaves the
    selection unchanged. NOTE(review): an invalid regex raises re.error and
    surfaces as a gradio error, same as before.
    """
    if not regex:
        return []
    matched = {run for run in all_runs if re.search(regex, run)}
    if not matched:
        return selected_runs
    return sorted(matched.union(selected_runs or []))