# terryyz's picture
# Update app.py
# aca9a0c verified
# raw
# history blame
# 12.7 kB
import gradio as gr
import json
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed, wait, FIRST_COMPLETED
from datetime import datetime
from typing import Any, Dict, List, Tuple
from warnings import warn
import numpy as np
from termcolor import cprint
from tqdm import tqdm
from bigcodebench.data import get_bigcodebench, get_bigcodebench_hash, load_solutions
from bigcodebench.data.utils import CACHE_DIR
from bigcodebench.eval import PASS, compatible_eval_result, estimate_pass_at_k, untrusted_check
from bigcodebench.gen.util import trusted_check
# Type alias used in check_correctness's return annotation: a string paired
# with a list of booleans. Presumably (status, per-test outcomes) -- confirm
# against what untrusted_check actually returns.
Result = Tuple[str, List[bool]]
def get_groundtruth(n_workers, problems, hashcode, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit):
    """Run each canonical solution against its tests and collect the timings.

    The result is cached as ``<hashcode>.pkl`` inside ``CACHE_DIR``. An
    existing cache file is returned directly, unless ``check_gt_only`` is
    set, in which case the file is deleted and everything is recomputed.

    Args:
        n_workers: Number of worker processes for the pool.
        problems: Mapping whose values are problem dicts providing the
            ``complete_prompt``, ``canonical_solution``, ``test`` and
            ``task_id`` keys.
        hashcode: Identifier used to name the cache file.
        check_gt_only: When true, drop any cached result before computing.
        max_as_limit: Forwarded unchanged to ``trusted_check``.
        max_data_limit: Forwarded unchanged to ``trusted_check``.
        max_stack_limit: Forwarded unchanged to ``trusted_check``.
        min_time_limit: Forwarded unchanged to ``trusted_check``.

    Returns:
        Dict mapping every ``task_id`` to the ``"time"`` entry of the
        corresponding ``trusted_check`` result.
    """
    cache_path = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    has_cache = os.path.exists(cache_path)

    if has_cache and check_gt_only:
        # A ground-truth-only run always starts from scratch.
        os.remove(cache_path)
    elif has_cache:
        print(f"Load from ground-truth from {cache_path}")
        # NOTE: unpickling is only safe as long as CACHE_DIR is trusted.
        with open(cache_path, "rb") as cache:
            return pickle.load(cache)

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("\nAsserting the groundtruth...")
    started = time.time()

    timings = {}
    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        # One trusted check per problem: prompt + canonical solution vs. tests.
        pending = [
            pool.submit(
                trusted_check,
                task["complete_prompt"] + "\n" + task["canonical_solution"],
                task["test"],
                task["task_id"],
                max_as_limit,
                max_data_limit,
                max_stack_limit,
                min_time_limit,
            )
            for task in problems.values()
        ]

        for finished in tqdm(as_completed(pending), total=len(pending)):
            outcome = finished.result()
            timings[outcome["task_id"]] = outcome["time"]

    print(f"Expected outputs computed in {time.time() - started:.2f}s")

    # Only persist the cache when at least one recorded time is truthy.
    if any(timings.values()):
        with open(cache_path, "wb") as cache:
            pickle.dump(timings, cache)

    return timings
def check_correctness(
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Result]:
    """Check one candidate solution and bundle the outcome with its metadata.

    Args:
        completion_id: Index of this completion; copied into the result.
        problem: Problem dict; its ``task_id``, ``test`` and ``entry_point``
            keys are read.
        solution: Source code of the candidate solution.
        max_as_limit: Forwarded unchanged to ``untrusted_check``.
        max_data_limit: Forwarded unchanged to ``untrusted_check``.
        max_stack_limit: Forwarded unchanged to ``untrusted_check``.
        identifier: Opaque sample identifier, echoed back as ``_identifier``.
        min_time_limit: Forwarded unchanged to ``untrusted_check``.
        gt_time_limit: Forwarded unchanged to ``untrusted_check``.

    Returns:
        Dict with the keys ``completion_id``, ``task_id``, ``_identifier``,
        ``solution`` and ``base``, where ``base`` holds whatever
        ``untrusted_check`` returned.
    """
    outcome = dict(
        completion_id=completion_id,
        task_id=problem["task_id"],
        _identifier=identifier,
        solution=solution,
    )
    limits = (
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    outcome["base"] = untrusted_check(
        solution, problem["test"], problem["entry_point"], *limits
    )
    return outcome
def evaluate(
    split: str,
    subset: str,
    samples: str,
    pass_k: str = "1,5,10",
    parallel: int = None,
    min_time_limit: float = 1,
    max_as_limit: int = 30 * 1024,
    max_data_limit: int = 30 * 1024,
    max_stack_limit: int = 10,
    check_gt_only: bool = False,
    no_gt: bool = False,
):
    """Evaluate generated solutions on BigCodeBench and compute pass@k.

    If a results file already exists next to ``samples`` it is loaded and no
    solution is re-executed; otherwise every sample is checked in a process
    pool. This function only reads that results file, it never writes it.

    Args:
        split: Not used by this function body.
        subset: Dataset subset passed to ``get_bigcodebench``; any value
            other than ``"full"`` is also used as a prefix in the results
            file name.
        samples: Path to a ``.jsonl`` samples file or to a directory of
            samples. Ignored (replaced by a dummy name) when
            ``check_gt_only`` is true.
        pass_k: Comma-separated ``k`` values; non-numeric entries are dropped.
        parallel: Number of worker processes; ``None`` means half of the CPU
            count (at least 1).
        min_time_limit: Forwarded to the ground-truth and solution checks.
        max_as_limit: Forwarded to the ground-truth and solution checks.
        max_data_limit: Forwarded to the ground-truth and solution checks.
        max_stack_limit: Forwarded to the ground-truth and solution checks.
        check_gt_only: Only validate the canonical solutions and return the
            ground-truth summary (when no results file is found).
        no_gt: Skip the ground-truth run; every task then gets the fallback
            time limit and is listed in ``failed_tasks``.

    Returns:
        Dict with one ``"pass@k"`` entry for each requested ``k`` that does
        not exceed the smallest per-task sample count, plus ``gt_pass_rate``
        and ``failed_tasks``. In ``check_gt_only`` mode without a results
        file, only ``gt_pass_rate`` and ``failed_tasks`` are returned.

    Raises:
        ValueError: If ``samples`` is ``None`` and ``check_gt_only`` is false.
        AssertionError: If ``samples`` is neither a directory nor a
            ``.jsonl`` path, or if the samples do not cover every problem.
    """
    pass_k = [int(k.strip()) for k in pass_k.split(',') if k.strip().isdigit()]

    if parallel is None:
        n_workers = max(1, multiprocessing.cpu_count() // 2)
    else:
        n_workers = parallel

    if check_gt_only:
        # Bypass the samples entirely; only the derived result path is used.
        samples = "__dummy__.jsonl"
    elif samples is None:
        # Fail with a readable message instead of a TypeError from os.path.
        raise ValueError(
            "No samples file was provided; supply a .jsonl file or enable check_gt_only."
        )

    extra = subset + "_" if subset != "full" else ""
    if os.path.isdir(samples):
        result_path = os.path.join(samples, f"{extra}eval_results.json")
    else:
        assert samples.endswith(".jsonl")
        result_path = samples.replace(".jsonl", f"_{extra}eval_results.json")

    problems = get_bigcodebench(subset=subset)
    dataset_hash = get_bigcodebench_hash(subset=subset)

    if not no_gt:
        expected_time = get_groundtruth(n_workers, problems, dataset_hash, check_gt_only, max_as_limit, max_data_limit, max_stack_limit, min_time_limit)
    else:
        expected_time = {task_id: None for task_id in problems}

    # A task counts as passing ground truth when a time was recorded for it.
    gt_pass_rate = np.mean([1 if v is not None else 0 for k, v in expected_time.items() if k in problems])
    failed_tasks = [k for k, v in expected_time.items() if v is None and k in problems]

    if os.path.isfile(result_path):
        with open(result_path, "r") as f:
            results = json.load(f)
        results = compatible_eval_result(results)
    else:
        if check_gt_only:
            if gt_pass_rate > 0.99:
                cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}", "green")
            else:
                cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}\nPlease be cautious!", "red")
            if len(failed_tasks) > 0:
                cprint(f"Failed tasks: {failed_tasks}", "red")
            return {"gt_pass_rate": float(gt_pass_rate), "failed_tasks": failed_tasks}

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> check_correctness results
            remainings = set()

            print("Reading samples...")
            for sample in tqdm(load_solutions(samples)):
                task_id = sample["task_id"]

                if task_id not in problems:
                    warn(
                        f"Task {task_id} is found in the samples but not found in the dataset"
                    )
                    continue
                solution = (
                    sample["solution"]
                    if "solution" in sample
                    else problems[task_id]["complete_prompt"] + sample["completion"]
                )
                if "sanitized-calibrated" in samples:
                    solution = problems[task_id]["code_prompt"] + "\n pass\n" + solution
                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    problems[task_id],
                    solution,
                    max_as_limit,
                    max_data_limit,
                    max_stack_limit,
                    sample["_identifier"],
                    min_time_limit,
                    # Fall back to 20 when no ground-truth time is available.
                    expected_time[task_id] if expected_time[task_id] else 20
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Missing problems in unfinished"
            assert len(completion_id) == len(problems), "Missing problems in samples"

            def stucking_checker():
                # Warn whenever a full 240s window passes without any
                # future completing; stops once all futures are done.
                not_done = futures
                while len(not_done) > 0:
                    done, not_done = wait(not_done, timeout=240, return_when=FIRST_COMPLETED)

                    if len(done) == 0:
                        warn("No samples have finished testing in the last 240s")
                        warn(f"{len(remainings)} samples to be tested: {remainings}")

            threading.Thread(target=stucking_checker).start()

            for future in tqdm(as_completed(futures), total=n_samples):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)

        # sort the results for each problem by completion_id
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )

    # Calculate pass@k.
    total = np.array([len(r) for k, r in results["eval"].items() if k in problems])
    base_correct = np.array(
        [
            sum(r["status"] == PASS for r in res)
            for key, res in results["eval"].items()
            if key in problems
        ]
    )

    if total.size:
        # pass@k is only reported when every task has at least k samples;
        # the minimum is computed once instead of once per k.
        fewest_samples = total.min()
        pass_at_k = {
            f"pass@{k}": float(estimate_pass_at_k(total, base_correct, k).mean())
            for k in pass_k
            if fewest_samples >= k
        }
    else:
        # No evaluated task belongs to the dataset: min() of an empty array
        # would raise, so report no pass@k entries.
        pass_at_k = {}

    pass_at_k["gt_pass_rate"] = float(gt_pass_rate)
    pass_at_k["failed_tasks"] = failed_tasks
    return pass_at_k
# mode = "-calibrated" if "sanitized-calibrated" in samples else ""
# extra = subset.capitalize()
# split = split.capitalize()
# cprint(f"BigCodeBench-{split}{mode} ({extra})", "green")
# if no_gt:
# cprint(f"Groundtruth is not checked", "yellow")
# else:
# if gt_pass_rate > 0.99:
# cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}", "green")
# else:
# cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}\nPlease be cautious!", "red")
# if len(failed_tasks) > 0:
# cprint(f"Failed tasks: {failed_tasks}", "red")
# for k, v in pass_at_k.items():
# cprint(f"{k}:\t{v:.3f}", "green")
# # save results
# if os.path.isfile(result_path):
# decision = ""
# while decision.lower() not in ["y", "n"]:
# print(f"{result_path} already exists. Press [Y/N] to overwrite or exit...")
# decision = input()
# if decision.lower() == "y":
# # mv the file to a backup
# new_path = result_path + ".bak"
# while os.path.isfile(new_path):
# new_path += ".bak"
# os.rename(result_path, new_path)
# print(f"Backup {result_path} to {new_path}")
# if not os.path.isfile(result_path):
# with open(result_path, "w") as f:
# json.dump(results, f, indent=2)
# if save_pass_rate:
# pass_at_k_path = result_path.replace("_eval_results.json", "_pass_at_k.json")
# pass_at_k["model"] = os.path.basename(samples).split("--bigcodebench-")[0]
# pass_at_k["calibrated"] = "sanitized-calibrated" in samples
# pass_at_k["subset"] = subset
# def save_pass_at_k():
# with open(pass_at_k_path, "w") as f:
# json.dump(pass_at_k, f, indent=2)
# if os.path.isfile(pass_at_k_path):
# saved_pass_at_k = json.load(open(pass_at_k_path, "r"))
# # compare saved_pass_at_k with pass_at_k
# for k in saved_pass_at_k.keys():
# if pass_at_k[k] != saved_pass_at_k[k]:
# cprint(f"Warning: {k} is different from the saved one", "yellow")
# # ask user whether to save the pass@k
# decision = ""
# while decision.lower() not in ["y", "n"]:
# print(f"Save pass@k to {pass_at_k_path}? [Y/N]")
# decision = input()
# if decision.lower() == "y":
# save_pass_at_k()
# else:
# save_pass_at_k()
def run_gradio():
    """Build a Gradio form around ``evaluate`` and launch it.

    The widgets are listed in the same order as ``evaluate``'s parameters,
    because the interface passes their values positionally. The result is
    rendered as plain text.
    """
    form_inputs = [
        gr.Dropdown(["complete", "instruct"], label="Split"),
        gr.Dropdown(["full", "hard"], label="Subset"),
        gr.File(label="Samples Path (.jsonl)"),
        gr.Textbox(label="Pass k Values (comma-separated)", value="1,5,10"),
        gr.Slider(1, multiprocessing.cpu_count(), step=1, label="Parallel Workers"),
        gr.Slider(0.1, 10, step=0.1, label="Min Time Limit", value=1),
        gr.Slider(1, 100 * 1024, step=1024, label="Max AS Limit", value=30 * 1024),
        gr.Slider(1, 100 * 1024, step=1024, label="Max Data Limit", value=30 * 1024),
        gr.Slider(1, 100, step=1, label="Max Stack Limit", value=10),
        gr.Checkbox(label="Check GT Only"),
        gr.Checkbox(label="No GT"),
    ]

    demo = gr.Interface(fn=evaluate, inputs=form_inputs, outputs="text")
    # The concurrency limit is configured on the queue rather than on the
    # Interface itself.
    demo.queue(default_concurrency_limit=None)
    demo.launch(show_error=True)
# Start the Gradio UI only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    run_gradio()
    # evaluate("complete", "hard", "meta-llama--Llama-3.2-3B-Instruct--bigcodebench-instruct--vllm-0-1.jsonl")