Spaces:
Sleeping
Sleeping
import gc
import json
import logging
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed, wait, FIRST_COMPLETED
from datetime import datetime
from typing import Any, Dict, List, Tuple
from warnings import warn

import gradio as gr
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import HfApi

from bigcodebench.data import get_bigcodebench, get_bigcodebench_hash, load_solutions
from bigcodebench.data.utils import CACHE_DIR
from bigcodebench.eval import PASS, compatible_eval_result, estimate_pass_at_k, untrusted_check
from bigcodebench.gen.util import trusted_check
# Hub repository of this evaluator Space; restart_space() asks the Hub to restart it.
REPO_ID = "bigcode/bigcodebench-evaluator"

# Access token taken from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
API = HfApi(token=HF_TOKEN)

# Type alias for a check result, presumably (status, per-test flags) — TODO confirm
# against bigcodebench.eval.untrusted_check.
Result = Tuple[str, List[bool]]
def get_groundtruth(n_workers, problems, hashcode, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit):
    """Run every task's canonical solution and record its execution time.

    Args:
        n_workers: Number of worker processes for the ProcessPoolExecutor.
        problems: Mapping of task id -> problem dict; each problem must provide
            "complete_prompt", "canonical_solution", "test" and "task_id".
        hashcode: Dataset hash used to name the pickle cache file in CACHE_DIR.
        check_gt_only: When True and a cache file exists, return the cached
            timings instead of recomputing them.
        max_as_limit, max_data_limit, max_stack_limit, min_time_limit:
            Resource limits forwarded unchanged to trusted_check.

    Returns:
        Dict mapping task_id -> the "time" value reported by trusted_check.
        Callers treat a None value as a task whose ground truth failed.
    """
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")

    # NOTE(review): the cache is only *read* when check_gt_only is True; a regular
    # evaluation always recomputes the timings and overwrites the cache below.
    # Confirm this is intended (fresh timings per evaluation) before changing it.
    if check_gt_only and os.path.exists(cache_file):
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)

    expected_time = {}
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = [
            executor.submit(
                trusted_check,
                problem["complete_prompt"] + "\n" + problem["canonical_solution"],
                problem["test"],
                problem["task_id"],
                max_as_limit,
                max_data_limit,
                max_stack_limit,
                min_time_limit,
            )
            for problem in problems.values()
        ]
        for future in as_completed(futures):
            result = future.result()
            expected_time[result["task_id"]] = result["time"]

    # Only persist the cache when at least one task produced a truthy timing.
    if any(expected_time.values()):
        with open(cache_file, "wb") as f:
            pickle.dump(expected_time, f)
    return expected_time
def check_correctness(
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Any]:
    """Check one candidate solution against its task's tests via untrusted_check.

    Args:
        completion_id: Index of this sample among the samples for the same task;
            evaluate() uses it to restore submission order.
        problem: Problem dict; "task_id", "test" and "entry_point" are read.
        solution: Full source code of the candidate solution.
        max_as_limit, max_data_limit, max_stack_limit: Resource limits forwarded
            to untrusted_check.
        identifier: Opaque sample identifier echoed back as "_identifier".
        min_time_limit, gt_time_limit: Time limits forwarded to untrusted_check.

    Returns:
        A dict with "completion_id", "task_id", "_identifier", "solution" and
        "base". "base" is the untrusted_check result, which evaluate() unpacks
        as a (status, details) pair.
    """
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        solution,
        problem["test"],
        problem["entry_point"],
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    return ret
def evaluate(
    split: str,
    subset: str,
    samples: str,
    pass_k: str = "1,5,10",
    parallel: int = -1,
    min_time_limit: float = 1,
    max_as_limit: int = 30 * 1024,
    max_data_limit: int = 30 * 1024,
    max_stack_limit: int = 10,
    check_gt_only: bool = False,
    no_gt: bool = False,
):
    """Evaluate a samples file against BigCodeBench and compute pass@k.

    Args:
        split: Split name (e.g. "complete" / "instruct"); only recorded in the summary.
        subset: Subset passed to get_bigcodebench (e.g. "full" / "hard").
        samples: Path to the samples .jsonl file. Ignored when check_gt_only is True.
        pass_k: Comma-separated k values; entries that are not plain digits are dropped.
        parallel: Number of worker processes. Values below 1 mean half the CPU count
            (at least 1).
        min_time_limit, max_as_limit, max_data_limit, max_stack_limit: Limits forwarded
            to the ground-truth run and to every sample check.
        check_gt_only: Only compute or load the ground-truth timings; skip sample evaluation.
        no_gt: Skip the ground-truth run. Every task then gets the fallback
            gt_time_limit of 20.

    Returns:
        (results, pass_at_k). results holds a "date" stamp and, under "eval", a
        per-task list of {"task_id", "solution", "status", "details"} records.
        pass_at_k holds the "pass@k" scores plus model/split/subset/calibration
        metadata, the ground-truth pass rate and the failed ground-truth tasks.

    Raises:
        ValueError: If no samples file is given when one is needed, if sample
            identifiers are duplicated, or if the samples do not cover every task.
    """
    pass_k = [int(k.strip()) for k in pass_k.split(",") if k.strip().isdigit()]
    # UI sliders may deliver numbers as floats; ProcessPoolExecutor needs an int.
    n_workers = max(1, multiprocessing.cpu_count() // 2) if parallel < 1 else int(parallel)

    if check_gt_only:
        samples = "__dummy__.jsonl"
    elif not samples:
        # Fail early with a clear message instead of an obscure error in load_solutions().
        raise ValueError("No samples file provided; please upload a .jsonl file.")

    problems = get_bigcodebench(subset=subset)
    dataset_hash = get_bigcodebench_hash(subset=subset)

    if no_gt:
        expected_time = {task_id: None for task_id in problems}
    else:
        expected_time = get_groundtruth(
            n_workers, problems, dataset_hash, check_gt_only,
            max_as_limit, max_data_limit, max_stack_limit, min_time_limit,
        )

    # A task's ground truth "passes" when its canonical run produced a non-None timing.
    gt_pass_rate = np.mean([1 if v is not None else 0 for k, v in expected_time.items() if k in problems])
    failed_tasks = [k for k, v in expected_time.items() if v is None and k in problems]

    pass_at_k = {}
    results = {
        "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
        "eval": {},
    }

    if not check_gt_only:
        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()  # task_id -> number of samples submitted so far
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of check_correctness results
            remainings = set()  # _identifier of every submitted sample

            for sample in load_solutions(samples):
                task_id = sample["task_id"]
                if task_id not in problems:
                    continue
                solution = (
                    sample["solution"]
                    if "solution" in sample
                    else problems[task_id]["complete_prompt"] + sample["completion"]
                )
                if "sanitized-calibrated" in samples:
                    # Calibrated samples get the task's code prompt plus a stub body prepended.
                    solution = problems[task_id]["code_prompt"] + "\n pass\n" + solution
                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    problems[task_id],
                    solution,
                    max_as_limit,
                    max_data_limit,
                    max_stack_limit,
                    sample["_identifier"],
                    min_time_limit,
                    # Fall back to 20 when the task has no (or a falsy) ground-truth time.
                    expected_time[task_id] if expected_time[task_id] else 20,
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            # Explicit checks rather than `assert`, so validation survives `python -O`.
            if n_samples != len(remainings):
                raise ValueError("Duplicate `_identifier` values in samples")
            if len(completion_id) != len(problems):
                raise ValueError("Missing problems in samples")

            for future in as_completed(futures):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)
                # Drop references eagerly to keep peak memory down on large runs.
                del future, result
                gc.collect()

        # Sort each task's results by completion_id so they follow submission order.
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            records = []
            for res in task_results:
                stat, details = res["base"]
                records.append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )
            results["eval"][task_id] = records

        # Per-task sample counts and pass counts, in the same task order.
        total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
        base_correct = np.array([
            sum(r["status"] == PASS for r in res)
            for key, res in results["eval"].items()
            if key in problems
        ])

        pass_at_k.update({
            f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean()
            for k in pass_k
            # Only report pass@k when every task has at least k samples.
            if total.min() >= k
        })

        del problems, futures
        gc.collect()

    # The model name is the file-name prefix before "--bigcodebench-".
    pass_at_k["model"] = os.path.basename(samples).split("--bigcodebench-")[0]
    pass_at_k["split"] = split
    pass_at_k["subset"] = subset
    pass_at_k["calibrated"] = "sanitized-calibrated" in samples
    pass_at_k["gt_pass_rate"] = gt_pass_rate
    pass_at_k["failed_tasks"] = failed_tasks
    return results, pass_at_k
# Gradio front end. The input components map positionally onto evaluate()'s
# parameters; the two JSON outputs receive evaluate()'s (results, pass_at_k) tuple.
interface = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.Dropdown(choices=["complete", "instruct"], label="BigCodeBench Split"),
        gr.Dropdown(choices=["full", "hard"], label="BigCodeBench Subset"),
        gr.File(label="Samples Path (.jsonl)"),
        gr.Textbox(label="Pass k Values (comma-separated)", value="1,5,10"),
        gr.Slider(minimum=-1, maximum=multiprocessing.cpu_count(), step=1, value=-1, label="Parallel Workers"),
        gr.Slider(minimum=0.1, maximum=10, step=0.1, value=1, label="Min Time Limit"),
        gr.Slider(minimum=1, maximum=100 * 1024, step=1024, value=30 * 1024, label="Max AS Limit"),
        gr.Slider(minimum=1, maximum=100 * 1024, step=1024, value=30 * 1024, label="Max Data Limit"),
        gr.Slider(minimum=1, maximum=100, step=1, value=10, label="Max Stack Limit"),
        gr.Checkbox(label="Check GT Only"),
        gr.Checkbox(label="No GT"),
    ],
    outputs=[
        gr.JSON(label="Results"),
        # NOTE(review): this panel receives the pass@k summary dict, not per-sample
        # results; the label may be misleading.
        gr.JSON(label="Eval Results"),
    ],
)
# Queue requests with no default per-event concurrency limit.
interface.queue(default_concurrency_limit=None)
# NOTE(review): interface.launch() is not called in the visible code (it was
# commented out here); confirm the app is launched elsewhere.
def preload_gt():
    """Compute or load ground-truth timings for the "full" and "hard" subsets.

    Each call goes through evaluate() with check_gt_only=True, so no samples
    are evaluated.
    """
    for subset_name in ("full", "hard"):
        evaluate(split="complete", subset=subset_name, samples="", check_gt_only=True)
def restart_space():
    """Ask the Hugging Face Hub to restart this Space, then rerun preload_gt().

    Meant to run as a scheduled job. Any exception is logged with its traceback
    instead of being raised.
    """
    logging.info("Restarting space with repo ID: %s", REPO_ID)
    try:
        API.restart_space(repo_id=REPO_ID, token=HF_TOKEN)
        # NOTE(review): a Space restart presumably replaces this process, and the module
        # calls preload_gt() at start-up, so this call may be redundant work — confirm.
        preload_gt()
        logging.info("Space restarted successfully.")
    except Exception as e:
        # Job boundary: keep the full traceback, not just the message.
        logging.exception("Failed to restart space: %s", e)
# Start-up: warm the ground-truth timings, then schedule an hourly Space restart.
# NOTE(review): this runs at import time; the `if __name__ == "__main__":` guard
# was commented out in the original.
preload_gt()

scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, trigger="interval", hours=1)
logging.info("Scheduler initialized to restart space every 1 hour.")
scheduler.start()