|
import gradio as gr |
|
import json |
|
import logging |
|
import multiprocessing |
|
import os |
|
import pickle |
|
import threading |
|
import time |
|
from collections import Counter, defaultdict |
|
from concurrent.futures import ProcessPoolExecutor, as_completed, wait, FIRST_COMPLETED |
|
from datetime import datetime |
|
from typing import Any, Dict, List, Tuple |
|
from warnings import warn |
|
import gc |
|
|
|
import numpy as np |
|
from huggingface_hub import HfApi |
|
from bigcodebench.data import get_bigcodebench, get_bigcodebench_hash, load_solutions |
|
from bigcodebench.data.utils import CACHE_DIR |
|
from bigcodebench.eval import PASS, compatible_eval_result, estimate_pass_at_k, untrusted_check |
|
from bigcodebench.gen.util import trusted_check |
|
from apscheduler.schedulers.background import BackgroundScheduler |
|
|
|
# Hugging Face Space that this evaluator app restarts (see restart_space).
REPO_ID = "bigcode/bigcodebench-evaluator"

# Hub access token taken from the environment; None if the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Shared Hub client, authenticated with HF_TOKEN when one is available.
API = HfApi(token=HF_TOKEN)

# Type alias: a (str, List[bool]) pair.
Result = Tuple[str, List[bool]]
|
|
|
|
|
def get_groundtruth(n_workers, problems, hashcode, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit):
    """Return the ground-truth runtime of every task, cached on disk by dataset hash.

    Each problem's canonical solution (``complete_prompt`` + newline +
    ``canonical_solution``) is run against ``problem["test"]`` through
    ``trusted_check`` in a process pool. The ``time`` value reported for each
    ``task_id`` is collected into a dict. The dict is pickled to
    ``CACHE_DIR/<hashcode>.pkl``, and later calls with the same hash load it
    instead of re-running anything.

    Args:
        n_workers: Maximum number of worker processes in the pool.
        problems: Mapping whose values are problem dicts providing
            ``complete_prompt``, ``canonical_solution``, ``test`` and ``task_id``.
        hashcode: Dataset hash; used as the cache file name.
        check_gt_only: Not used by this function; kept for call-site compatibility.
        max_as_limit, max_data_limit, max_stack_limit, min_time_limit:
            Limits forwarded unchanged to ``trusted_check``.

    Returns:
        Dict mapping ``task_id`` to the ``time`` returned by ``trusted_check``.
        Callers treat ``None`` as "canonical solution failed".
    """
    import tempfile  # local: only needed for the atomic cache write below

    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        try:
            with open(cache_file, "rb") as f:
                return pickle.load(f)
        except (EOFError, pickle.UnpicklingError) as err:
            # A truncated/corrupt cache would otherwise break every later call;
            # fall through and recompute it instead.
            logging.warning("Ignoring unreadable ground-truth cache %s: %s", cache_file, err)

    os.makedirs(CACHE_DIR, exist_ok=True)

    expected_time = {}
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = [
            executor.submit(
                trusted_check,
                problem["complete_prompt"] + "\n" + problem["canonical_solution"],
                problem["test"],
                problem["task_id"],
                max_as_limit,
                max_data_limit,
                max_stack_limit,
                min_time_limit,
            )
            for problem in problems.values()
        ]
        for future in as_completed(futures):
            result = future.result()
            expected_time[result["task_id"]] = result["time"]

    # Only persist when at least one task produced a truthy time, so an
    # all-failed run is not cached.
    if any(expected_time.values()):
        # Write to a unique temp file and atomically rename it into place, so
        # concurrent callers (or a crash mid-write) never leave a partially
        # written pickle at cache_file.
        fd, tmp_path = tempfile.mkstemp(dir=CACHE_DIR, suffix=".pkl.tmp")
        try:
            with os.fdopen(fd, "wb") as f:
                pickle.dump(expected_time, f)
            os.replace(tmp_path, cache_file)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    return expected_time
|
|
|
|
|
def check_correctness(
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Any]:
    """Run one candidate solution against its problem's tests.

    Args:
        completion_id: Index of this sample among the samples for the same task;
            echoed back so results can be re-ordered after parallel execution.
        problem: Problem dict; ``task_id``, ``test`` and ``entry_point`` are read.
        solution: Full source code of the candidate solution.
        max_as_limit, max_data_limit, max_stack_limit, min_time_limit,
        gt_time_limit: Limits forwarded unchanged to ``untrusted_check``.
        identifier: Opaque sample identifier, echoed back as ``_identifier``.

    Returns:
        Dict with ``completion_id``, ``task_id``, ``_identifier``, ``solution``
        and ``base``. ``base`` is whatever ``untrusted_check`` returns; the caller
        unpacks it as ``(status, details)``. The value types are mixed, hence
        ``Dict[str, Any]``; the previous ``Dict[str, Result]`` annotation was
        incorrect.
    """
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        solution,
        problem["test"],
        problem["entry_point"],
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    return ret
|
|
|
|
|
def evaluate(
    split: str,
    subset: str,
    samples: str,
    pass_k: str="1,5,10",
    parallel: int = -1,
    min_time_limit: float = 1,
    max_as_limit: int = 30 * 1024,
    max_data_limit: int = 30 * 1024,
    max_stack_limit: int = 10,
    check_gt_only: bool = False,
    no_gt: bool = False,
):
    """Evaluate a BigCodeBench samples file and compute pass@k.

    Args:
        split: Benchmark split name; only recorded in the summary.
        subset: Dataset subset passed to ``get_bigcodebench`` (e.g. "full"/"hard").
        samples: Path of the samples file given to ``load_solutions``.
            Ignored when ``check_gt_only`` is true.
        pass_k: Comma-separated k values; non-digit entries are skipped.
        parallel: Worker count; values < 1 mean half the CPU count (at least 1).
        min_time_limit, max_as_limit, max_data_limit, max_stack_limit:
            Limits forwarded to the ground-truth and candidate checks.
        check_gt_only: Only compute/load ground-truth timings; skip samples.
        no_gt: Skip ground truth entirely; every task then gets a time limit of 20.

    Returns:
        ``(results, pass_at_k)``.
        - ``results`` has a ``date`` and, per ``task_id``, a list of
          ``{task_id, solution, status, details}`` ordered by completion id.
        - ``pass_at_k`` holds ``pass@k`` for each k that every task has at least
          k samples for, plus ``model``, ``split``, ``subset``, ``calibrated``,
          ``gt_pass_rate`` and ``failed_tasks``.

    Raises:
        ValueError: no samples file was given, sample ``_identifier`` values are
            not unique, or the samples do not cover every task of the subset.
    """
    pass_k = [int(k.strip()) for k in pass_k.split(",") if k.strip().isdigit()]
    n_workers = parallel if parallel >= 1 else max(1, multiprocessing.cpu_count() // 2)

    if check_gt_only:
        # Placeholder; only used to derive the "model" field at the end.
        samples = "__dummy__.jsonl"
    elif not samples:
        raise ValueError("No samples file was provided.")

    problems = get_bigcodebench(subset=subset)
    dataset_hash = get_bigcodebench_hash(subset=subset)

    if no_gt:
        expected_time = {task_id: None for task_id in problems}
    else:
        expected_time = get_groundtruth(n_workers, problems, dataset_hash, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit)

    gt_pass_rate = np.mean([1 if v is not None else 0 for k, v in expected_time.items() if k in problems])
    failed_tasks = [k for k, v in expected_time.items() if v is None and k in problems]

    pass_at_k = {}
    results = {
        "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
        "eval": {},
    }

    if not check_gt_only:
        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)
            remainings = set()

            for sample in load_solutions(samples):
                task_id = sample["task_id"]
                if task_id not in problems:
                    continue
                solution = (
                    sample["solution"]
                    if "solution" in sample
                    else problems[task_id]["complete_prompt"] + sample["completion"]
                )
                if "sanitized-calibrated" in samples:
                    solution = problems[task_id]["code_prompt"] + "\n pass\n" + solution
                remainings.add(sample["_identifier"])
                futures.append(
                    executor.submit(
                        check_correctness,
                        completion_id[task_id],
                        problems[task_id],
                        solution,
                        max_as_limit,
                        max_data_limit,
                        max_stack_limit,
                        sample["_identifier"],
                        min_time_limit,
                        # Fall back to 20 when the ground truth has no (truthy)
                        # time for this task, including a stale cache lacking it.
                        expected_time.get(task_id) or 20,
                    )
                )
                completion_id[task_id] += 1
                n_samples += 1

            # Explicit checks instead of `assert`: asserts are stripped under
            # `python -O`, which would silently score a partial submission.
            error = None
            if n_samples != len(remainings):
                error = "Missing problems in unfinished"
            elif len(completion_id) != len(problems):
                error = "Missing problems in samples"
            if error is not None:
                # Don't make the executor's shutdown wait for work we discard.
                for pending in futures:
                    pending.cancel()
                raise ValueError(error)

            for future in as_completed(futures):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)
                del future, result
                gc.collect()

        # Restore per-task submission order (as_completed yields in finish order).
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            entries = []
            for res in task_results:
                stat, details = res["base"]
                entries.append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )
            results["eval"][task_id] = entries

        total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
        base_correct = np.array(
            [sum(r["status"] == PASS for r in res) for key, res in results["eval"].items() if key in problems]
        )

        # Only report k values every task has enough samples for; guard the
        # empty case because np.min on an empty array raises.
        if total.size:
            pass_at_k.update({
                f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean()
                for k in pass_k
                if total.min() >= k
            })

        del problems, futures
        gc.collect()

    pass_at_k["model"] = os.path.basename(samples).split("--bigcodebench-")[0]
    pass_at_k["split"] = split
    pass_at_k["subset"] = subset
    pass_at_k["calibrated"] = "sanitized-calibrated" in samples
    pass_at_k["gt_pass_rate"] = gt_pass_rate
    pass_at_k["failed_tasks"] = failed_tasks

    return results, pass_at_k
|
|
|
|
|
|
|
# Gradio UI: the inputs map positionally onto evaluate()'s parameters
# (split, subset, samples, pass_k, parallel, min_time_limit, max_as_limit,
# max_data_limit, max_stack_limit, check_gt_only, no_gt).
interface = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.Dropdown(["complete", "instruct"], label="BigCodeBench Split"),
        gr.Dropdown(["full", "hard"], label="BigCodeBench Subset"),
        gr.File(label="Samples Path (.jsonl)"),
        gr.Textbox(label="Pass k Values (comma-separated)", value="1,5,10"),
        gr.Slider(-1, multiprocessing.cpu_count(), step=1, label="Parallel Workers", value=-1),
        gr.Slider(0.1, 10, step=0.1, label="Min Time Limit", value=1),
        # A slider's reachable values are minimum + k * step. With minimum=1 and
        # step=1024 the 30 * 1024 default was off-grid, and user picks became
        # 1 + k * 1024. Starting at 1024 keeps every value a multiple of 1024.
        gr.Slider(1024, 100 * 1024, step=1024, label="Max AS Limit", value=30 * 1024),
        gr.Slider(1024, 100 * 1024, step=1024, label="Max Data Limit", value=30 * 1024),
        gr.Slider(1, 100, step=1, label="Max Stack Limit", value=10),
        gr.Checkbox(label="Check GT Only"),
        gr.Checkbox(label="No GT"),
    ],
    outputs=[
        gr.JSON(label="Results"),
        gr.JSON(label="Eval Results"),
    ],
)

# No per-event concurrency cap: multiple evaluations may run at once.
interface.queue(default_concurrency_limit=None)
|
|
|
|
|
def preload_gt():
    """Warm the ground-truth timing cache for both benchmark subsets.

    Calls ``evaluate`` in ground-truth-only mode on the "complete" split,
    first for the "full" subset and then for "hard". Return values are
    discarded; the point is the cache side effect of ``get_groundtruth``.
    """
    for gt_subset in ("full", "hard"):
        evaluate(split="complete", subset=gt_subset, samples="", check_gt_only=True)
|
|
|
|
|
def restart_space():
    """Request a restart of the ``REPO_ID`` Space, then re-warm the GT cache.

    Any exception is logged with its full traceback and swallowed, so this is
    best-effort and never raises to its caller.
    """
    logging.info("Restarting space with repo ID: %s", REPO_ID)
    try:
        API.restart_space(repo_id=REPO_ID, token=HF_TOKEN)
        preload_gt()
        logging.info("Space restarted successfully.")
    except Exception as e:
        # logging.exception keeps the traceback that logging.error discarded.
        logging.exception("Failed to restart space: %s", e)
|
|
|
|
|
|
|
# Script entry: populate the ground-truth timing cache for both subsets before
# serving, so the first user request does not pay for it. This must run before
# launch(), which blocks.
preload_gt()

# Start the Gradio server; show_error=True surfaces exceptions raised by
# evaluate() in the UI.
interface.launch(show_error=True)
|
|