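# --- Web file-viewer header residue (not program code); kept as comments so the module parses ---
# Terry Zhuo
# update
# 5d7ffc1
# raw | history | blame
# 9.38 kB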
import gradio as gr
import json
import logging
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed, wait, FIRST_COMPLETED
from datetime import datetime
from typing import Any, Dict, List, Tuple
from warnings import warn
import gc
import numpy as np
from huggingface_hub import HfApi
from bigcodebench.data import get_bigcodebench, get_bigcodebench_hash, load_solutions
from bigcodebench.data.utils import CACHE_DIR
from bigcodebench.eval import PASS, compatible_eval_result, estimate_pass_at_k, untrusted_check
from bigcodebench.gen.util import trusted_check
from apscheduler.schedulers.background import BackgroundScheduler
# Space repository that restart_space() asks the Hub to restart.
REPO_ID = "bigcode/bigcodebench-evaluator"

# Hub access token taken from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Shared Hub client, authenticated with HF_TOKEN.
API = HfApi(token=HF_TOKEN)

# Alias used in check_correctness's return annotation.
# NOTE(review): presumably the (status, per-test details) pair produced by
# untrusted_check -- confirm against bigcodebench.eval.
Result = Tuple[str, List[bool]]
def get_groundtruth(n_workers, problems, hashcode, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit):
    """Return the ``task_id -> time`` mapping for the canonical solutions, cached on disk.

    The mapping is stored as ``<CACHE_DIR>/<hashcode>.pkl``. When that file
    exists and can be unpickled it is returned directly; otherwise every
    problem's ``complete_prompt`` + ``canonical_solution`` is run through
    ``trusted_check`` in a process pool and the reported ``"time"`` values are
    collected.

    Args:
        n_workers: Size of the process pool.
        problems: Mapping whose values provide ``complete_prompt``,
            ``canonical_solution``, ``test`` and ``task_id``.
        hashcode: Cache key; used as the cache file's base name.
        check_gt_only: Accepted for signature compatibility; not used here.
        max_as_limit, max_data_limit, max_stack_limit, min_time_limit:
            Forwarded unchanged to ``trusted_check``.

    Returns:
        dict mapping each task id to the ``"time"`` value reported by
        ``trusted_check`` for that task.
    """
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        try:
            # SECURITY NOTE: unpickling can execute arbitrary code, so CACHE_DIR
            # must only ever contain files written by this application.
            with open(cache_file, "rb") as f:
                return pickle.load(f)
        except (EOFError, pickle.UnpicklingError) as exc:
            # A truncated or corrupt cache (e.g. the process was stopped while
            # writing it) must not break every later call: recompute instead.
            logging.warning("Ignoring unreadable ground-truth cache %s: %s", cache_file, exc)

    os.makedirs(CACHE_DIR, exist_ok=True)

    expected_time = {}
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = [
            executor.submit(
                trusted_check,
                problem["complete_prompt"] + "\n" + problem["canonical_solution"],
                problem["test"],
                problem["task_id"],
                max_as_limit,
                max_data_limit,
                max_stack_limit,
                min_time_limit,
            )
            for problem in problems.values()
        ]
        for future in as_completed(futures):
            result = future.result()
            expected_time[result["task_id"]] = result["time"]

    # Persist only when at least one task reported a truthy time, so a run in
    # which nothing produced a timing is retried on the next call.
    if any(expected_time.values()):
        # Write to a sibling temp file and rename it into place so readers
        # never observe a partially written pickle.
        tmp_file = f"{cache_file}.{os.getpid()}.tmp"
        with open(tmp_file, "wb") as f:
            pickle.dump(expected_time, f)
        os.replace(tmp_file, cache_file)

    return expected_time
def check_correctness(
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Result]:
    """Run one candidate solution through ``untrusted_check`` and package the outcome.

    Args:
        completion_id: Index of this sample among the samples for its task.
        problem: Problem record; ``task_id``, ``test`` and ``entry_point`` are read.
        solution: Source code to check.
        max_as_limit, max_data_limit, max_stack_limit, min_time_limit,
        gt_time_limit: Forwarded unchanged to ``untrusted_check``.
        identifier: Opaque tag echoed back under ``"_identifier"``.

    Returns:
        dict with keys ``completion_id``, ``task_id``, ``_identifier``,
        ``solution`` and ``base`` (whatever ``untrusted_check`` returned).
    """
    outcome = dict(
        completion_id=completion_id,
        task_id=problem["task_id"],
        _identifier=identifier,
        solution=solution,
    )
    checker_args = (
        solution,
        problem["test"],
        problem["entry_point"],
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    outcome["base"] = untrusted_check(*checker_args)
    return outcome
def evaluate(
    split: str,
    subset: str,
    samples: str,
    pass_k: str="1,5,10",
    parallel: int = -1,
    min_time_limit: float = 1,
    max_as_limit: int = 30 * 1024,
    max_data_limit: int = 30 * 1024,
    max_stack_limit: int = 10,
    check_gt_only: bool = False,
    no_gt: bool = False,
):
    """Check a samples file against BigCodeBench problems and summarise pass@k.

    Args:
        split: Recorded verbatim in the summary; not otherwise used here.
        subset: Forwarded to ``get_bigcodebench`` and ``get_bigcodebench_hash``.
        samples: Path of the samples file handed to ``load_solutions``.
            Ignored (replaced by a dummy name) when ``check_gt_only`` is true.
        pass_k: Comma-separated k values; entries that are not plain digit
            strings are dropped.
        parallel: Number of worker processes; values below 1 select half of
            the CPU count (at least 1).
        min_time_limit, max_as_limit, max_data_limit, max_stack_limit:
            Forwarded unchanged to the ground-truth and sample checkers.
        check_gt_only: Only compute/load the ground truth; no sample is run.
        no_gt: Skip the ground-truth run; every sample then gets the fallback
            ``gt_time_limit`` of 20.

    Returns:
        ``(results, pass_at_k)``. ``results`` holds a ``"date"`` stamp and,
        under ``"eval"``, a per-task list of ``{task_id, solution, status,
        details}`` records ordered by completion id. ``pass_at_k`` holds the
        ``pass@k`` scores plus ``model``, ``split``, ``subset``,
        ``calibrated``, ``gt_pass_rate`` and ``failed_tasks``.

    Raises:
        ValueError: If no samples file is given and ``check_gt_only`` is false.
        AssertionError: If sample identifiers are not unique or some problem
            has no sample.
    """
    pass_k = [int(k.strip()) for k in pass_k.split(',') if k.strip().isdigit()]

    if parallel < 1:
        n_workers = max(1, multiprocessing.cpu_count() // 2)
    else:
        # ProcessPoolExecutor needs an integral worker count; coerce in case a
        # float is passed in.
        n_workers = int(parallel)

    if check_gt_only:
        samples = "__dummy__.jsonl"
    elif not samples:
        # Fail early with a readable message instead of a TypeError deep in
        # the sample-loading code.
        raise ValueError("A samples file is required unless check_gt_only is set")

    problems = get_bigcodebench(subset=subset)
    dataset_hash = get_bigcodebench_hash(subset=subset)

    if not no_gt:
        expected_time = get_groundtruth(n_workers, problems, dataset_hash, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit)
    else:
        expected_time = {task_id: None for task_id in problems}

    # NOTE: with no_gt every entry is None, so gt_pass_rate is 0 and
    # failed_tasks lists every problem.
    gt_pass_rate = np.mean([1 if v is not None else 0 for k, v in expected_time.items() if k in problems])
    failed_tasks = [k for k, v in expected_time.items() if v is None and k in problems]

    # Loop-invariant: the file name decides whether solutions get the
    # code-prompt prefix and how the summary is labelled.
    calibrated = "sanitized_calibrated" in samples

    pass_at_k = dict()
    results = {
        "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
        "eval": {},
    }

    if not check_gt_only:
        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> list of check results
            remainings = set()

            for sample in load_solutions(samples):
                task_id = sample["task_id"]
                if task_id not in problems:
                    continue
                solution = (
                    sample["solution"]
                    if "solution" in sample
                    else problems[task_id]["complete_prompt"] + sample["completion"]
                )
                if calibrated:
                    solution = problems[task_id]["code_prompt"] + "\n pass\n" + solution
                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    problems[task_id],
                    solution,
                    max_as_limit,
                    max_data_limit,
                    max_stack_limit,
                    sample["_identifier"],
                    min_time_limit,
                    # Fall back to 20 when the ground truth has no (or a
                    # falsy) time for this task, including a missing entry.
                    expected_time.get(task_id) or 20,
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Missing problems in unfinished"
            assert len(completion_id) == len(problems), "Missing problems in samples"

            for future in as_completed(futures):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)
                # Drop the references right away and collect; presumably this
                # keeps peak memory down while many results are in flight.
                del future, result
                gc.collect()

        # sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )

        # Calculate pass@k.
        total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
        base_correct = np.array(
            [
                sum(r["status"] == PASS for r in res)
                for key, res in results["eval"].items()
                if key in problems
            ]
        )

        # pass@k is reported only when every problem has at least k samples;
        # an empty result set reports none (min() of an empty array raises).
        min_samples = total.min() if total.size else 0
        pass_at_k.update({
            f"pass@{k}": estimate_pass_at_k(total, base_correct, k).mean()
            for k in pass_k
            if min_samples >= k
        })

        del problems, futures
        gc.collect()

    pass_at_k["model"] = os.path.basename(samples).split("--bigcodebench-")[0]
    pass_at_k["split"] = split
    pass_at_k["subset"] = subset
    pass_at_k["calibrated"] = calibrated
    pass_at_k["gt_pass_rate"] = gt_pass_rate
    pass_at_k["failed_tasks"] = failed_tasks

    return results, pass_at_k
# Gradio front end for `evaluate`: one input widget per parameter, listed in
# the same order as the function signature.
_input_components = [
    gr.Dropdown(["complete", "instruct"], label="BigCodeBench Split"),
    gr.Dropdown(["full", "hard"], label="BigCodeBench Subset"),
    gr.File(label="Samples Path (.jsonl)"),
    gr.Textbox(label="Pass k Values (comma-separated)", value="1,5,10"),
    gr.Slider(-1, multiprocessing.cpu_count(), step=1, label="Parallel Workers", value=-1),
    gr.Slider(0.1, 10, step=0.1, label="Min Time Limit", value=1),
    gr.Slider(1, 100 * 1024, step=1024, label="Max AS Limit", value=30 * 1024),
    gr.Slider(1, 100 * 1024, step=1024, label="Max Data Limit", value=30 * 1024),
    gr.Slider(1, 100, step=1, label="Max Stack Limit", value=10),
    gr.Checkbox(label="Check GT Only"),
    gr.Checkbox(label="No GT"),
]

# Two JSON panes, filled from the two values `evaluate` returns, in order.
_output_components = [
    gr.JSON(label="Results"),
    gr.JSON(label="Eval Results"),
]

interface = gr.Interface(
    fn=evaluate,
    inputs=_input_components,
    outputs=_output_components,
)

# NOTE(review): None presumably lifts the default per-event concurrency cap --
# confirm against the Gradio queue documentation.
interface.queue(default_concurrency_limit=None)
def preload_gt():
    """Run `evaluate` in ground-truth-only mode for the "full" and then the "hard" subset."""
    for subset_name in ("full", "hard"):
        evaluate(split="complete", subset=subset_name, samples="", check_gt_only=True)
def restart_space():
    """Ask the Hub to restart the Space identified by ``REPO_ID``.

    Any ``Exception`` raised by the Hub call is logged, with its traceback,
    and swallowed, so a failed restart never propagates to the caller.
    """
    logging.info("Restarting space with repo ID: %s", REPO_ID)
    try:
        # Now restart the space
        API.restart_space(repo_id=REPO_ID, token=HF_TOKEN)
    except Exception as e:
        # logging.exception records the traceback as well as the message,
        # which a plain logging.error call would drop.
        logging.exception("Failed to restart space: %s", e)
    else:
        logging.info("Space restarted successfully.")
# NOTE: the `if __name__ == "__main__":` guard is commented out, so the startup
# sequence below also runs whenever this module is imported.

# Run the ground-truth-only evaluations before the UI starts serving.
preload_gt()

# Schedule restart_space() to run in the background every 5 hours.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, trigger="interval", hours=5)
scheduler.start()

interface.launch(show_error=True)