import json
import multiprocessing
import os
import pickle
import threading
import time
from collections import Counter, defaultdict
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, as_completed, wait
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from warnings import warn

import gradio as gr
import numpy as np
from termcolor import cprint
from tqdm import tqdm

from bigcodebench.data import get_bigcodebench, get_bigcodebench_hash, load_solutions
from bigcodebench.data.utils import CACHE_DIR
from bigcodebench.eval import PASS, compatible_eval_result, estimate_pass_at_k, untrusted_check
from bigcodebench.gen.util import trusted_check

# (status, details) pair, as unpacked from check_correctness()["base"] below.
# NOTE(review): the List[bool] details type is not verified here; confirm against untrusted_check.
Result = Tuple[str, List[bool]]


def get_groundtruth(
    n_workers,
    problems,
    hashcode,
    check_gt_only,
    max_as_limit,
    max_data_limit,
    max_stack_limit,
    min_time_limit,
):
    """Return the runtime of every canonical solution, using an on-disk cache.

    Each problem's ``complete_prompt + "\\n" + canonical_solution`` is run through
    ``trusted_check`` in a process pool, and the ``"time"`` field of each result
    is recorded under its task id.

    Args:
        n_workers: Number of worker processes.
        problems: Mapping of task id -> problem dict providing ``complete_prompt``,
            ``canonical_solution``, ``test`` and ``task_id``.
        hashcode: Dataset hash; the cache file is ``CACHE_DIR/<hashcode>.pkl``.
        check_gt_only: When True, an existing cache file is deleted so the ground
            truth is always re-executed.
        max_as_limit, max_data_limit, max_stack_limit, min_time_limit: Limits
            forwarded unchanged to ``trusted_check``.

    Returns:
        Dict mapping task id -> the ``"time"`` value reported by ``trusted_check``.
        Callers in this module treat ``None`` as "ground truth failed".
    """
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        if check_gt_only:
            # Force a fresh run: the caller explicitly asked to re-check the ground truth.
            os.remove(cache_file)
        else:
            print(f"Load from ground-truth from {cache_file}")
            # NOTE: unpickling is only acceptable because this cache is assumed to be
            # written solely by this function (trusted local data).
            with open(cache_file, "rb") as f:
                return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("\nAsserting the groundtruth...")
    tbegin = time.time()

    expected_time = {}
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = [
            executor.submit(
                trusted_check,
                problem["complete_prompt"] + "\n" + problem["canonical_solution"],
                problem["test"],
                problem["task_id"],
                max_as_limit,
                max_data_limit,
                max_stack_limit,
                min_time_limit,
            )
            for problem in problems.values()
        ]
        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            expected_time[result["task_id"]] = result["time"]

    print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")

    # Skip caching when every entry is falsy (no usable ground-truth time at all),
    # so the next call re-runs instead of reusing a useless cache.
    if any(expected_time.values()):
        with open(cache_file, "wb") as f:
            pickle.dump(expected_time, f)

    return expected_time


def check_correctness(
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    max_as_limit: float,
    max_data_limit: float,
    max_stack_limit: float,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit: float = 2.0,
) -> Dict[str, Any]:
    """Check one candidate ``solution`` against ``problem``'s tests.

    Returns:
        Dict with the bookkeeping keys ``completion_id``, ``task_id``,
        ``_identifier`` and ``solution``, plus ``base``: the value returned by
        ``untrusted_check`` (consumed by ``evaluate`` as a ``(status, details)`` pair).
    """
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "solution": solution,
    }
    ret["base"] = untrusted_check(
        solution,
        problem["test"],
        problem["entry_point"],
        max_as_limit,
        max_data_limit,
        max_stack_limit,
        min_time_limit,
        gt_time_limit,
    )
    return ret


def _parse_pass_k(pass_k: str) -> List[int]:
    """Parse a comma-separated list of k values for pass@k.

    Entries that are not plain decimal integers are ignored, as is 0 (pass@0 is
    meaningless). ``isdecimal`` is used rather than ``isdigit`` because the latter
    accepts characters such as ``'²'`` that ``int()`` rejects.
    """
    values = []
    for part in pass_k.split(","):
        part = part.strip()
        if part.isdecimal() and int(part) > 0:
            values.append(int(part))
    return values


def _result_path(samples: str, subset: str) -> str:
    """Return the eval-results JSON path associated with ``samples``.

    Subsets other than "full" add a ``<subset>_`` prefix to the file name. For a
    ``.jsonl`` file only the trailing extension is replaced, so a ``.jsonl``
    appearing elsewhere in the path (e.g. in a directory name) is left intact.
    """
    extra = subset + "_" if subset != "full" else ""
    if os.path.isdir(samples):
        return os.path.join(samples, f"{extra}eval_results.json")
    assert samples.endswith(".jsonl")
    return samples[: -len(".jsonl")] + f"_{extra}eval_results.json"


def _compute_pass_at_k(eval_results: Dict[str, list], problems, k_values: List[int]) -> Dict[str, float]:
    """Estimate pass@k over the tasks of ``eval_results`` that exist in ``problems``.

    Only k values not exceeding the smallest per-task sample count are reported.
    An empty result set yields an empty dict (``np.min`` on an empty array would raise).
    """
    totals, corrects = [], []
    # One pass keeps the per-task totals and correct counts aligned.
    for task_id, task_results in eval_results.items():
        if task_id not in problems:
            continue
        totals.append(len(task_results))
        corrects.append(sum(r["status"] == PASS for r in task_results))

    total = np.array(totals)
    base_correct = np.array(corrects)
    if total.size == 0:
        return {}
    return {
        f"pass@{k}": float(estimate_pass_at_k(total, base_correct, k).mean())
        for k in k_values
        if total.min() >= k
    }


def evaluate(
    split: str,
    subset: str,
    samples: str,
    pass_k: str = "1,5,10",
    parallel: Optional[int] = None,
    min_time_limit: float = 1,
    max_as_limit: int = 30 * 1024,
    max_data_limit: int = 30 * 1024,
    max_stack_limit: int = 10,
    check_gt_only: bool = False,
    no_gt: bool = False,
):
    """Evaluate a BigCodeBench samples file and return pass@k statistics.

    If an eval-results JSON already exists next to ``samples``, it is loaded
    instead of re-executing the samples. This function does not write that file.

    Args:
        split: Benchmark split name. Currently unused; kept for interface compatibility.
        subset: Dataset subset passed to ``get_bigcodebench`` (e.g. "full", "hard").
        samples: Path (str or path-like) to a ``.jsonl`` samples file or a directory.
            Ignored when ``check_gt_only`` is set.
        pass_k: Comma-separated k values; invalid or zero entries are ignored.
        parallel: Worker process count; defaults to half the CPU count (at least 1).
        min_time_limit, max_as_limit, max_data_limit, max_stack_limit: Limits
            forwarded to the ground-truth and candidate checkers.
        check_gt_only: Only (re)run the ground truth and report its pass rate.
        no_gt: Skip the ground-truth run. Every task then uses the fallback time
            limit of 20 and is reported as a ground-truth failure.

    Returns:
        Dict with ``"pass@k"`` entries (only for k not exceeding the smallest
        per-task sample count), ``"gt_pass_rate"`` and ``"failed_tasks"``. In
        ``check_gt_only`` mode without an existing results file, only the last two.

    Raises:
        ValueError: If ``samples`` is None and ``check_gt_only`` is False.
    """
    k_values = _parse_pass_k(pass_k)
    if parallel is None:
        n_workers = max(1, multiprocessing.cpu_count() // 2)
    else:
        # Accept float-valued counts (e.g. from a UI slider); the pool needs an int.
        n_workers = int(parallel)

    if check_gt_only:
        # Placeholder path: in GT-only mode no samples are loaded.
        samples = "__dummy__.jsonl"
    elif samples is None:
        raise ValueError("`samples` must be a .jsonl file or a directory of samples")
    samples = os.fspath(samples)
    result_path = _result_path(samples, subset)

    problems = get_bigcodebench(subset=subset)
    dataset_hash = get_bigcodebench_hash(subset=subset)

    if no_gt:
        expected_time = {task_id: None for task_id in problems}
    else:
        expected_time = get_groundtruth(
            n_workers,
            problems,
            dataset_hash,
            check_gt_only,
            max_as_limit,
            max_data_limit,
            max_stack_limit,
            min_time_limit,
        )

    gt_pass_rate = np.mean([1 if v is not None else 0 for k, v in expected_time.items() if k in problems])
    failed_tasks = [k for k, v in expected_time.items() if v is None and k in problems]

    if os.path.isfile(result_path):
        # Reuse previously computed per-sample results instead of re-executing.
        with open(result_path, "r") as f:
            results = compatible_eval_result(json.load(f))
    else:
        if check_gt_only:
            if gt_pass_rate > 0.99:
                cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}", "green")
            else:
                cprint(f"Groundtruth pass rate: {gt_pass_rate:.3f}\nPlease be cautious!", "red")
            if failed_tasks:
                cprint(f"Failed tasks: {failed_tasks}", "red")
            return {"gt_pass_rate": float(gt_pass_rate), "failed_tasks": failed_tasks}

        results = {
            "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            "eval": {},
        }

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = []
            completion_id = Counter()  # task_id -> number of samples submitted so far
            n_samples = 0
            eval_results = defaultdict(list)  # task_id -> check_correctness results
            remainings = set()  # identifiers of samples not yet finished

            print("Reading samples...")
            for sample in tqdm(load_solutions(samples)):
                task_id = sample["task_id"]
                if task_id not in problems:
                    warn(f"Task {task_id} is found in the samples but not found in the dataset")
                    continue
                solution = (
                    sample["solution"]
                    if "solution" in sample
                    else problems[task_id]["complete_prompt"] + sample["completion"]
                )
                if "sanitized-calibrated" in samples:
                    solution = problems[task_id]["code_prompt"] + "\n pass\n" + solution
                remainings.add(sample["_identifier"])
                args = (
                    completion_id[task_id],
                    problems[task_id],
                    solution,
                    max_as_limit,
                    max_data_limit,
                    max_stack_limit,
                    sample["_identifier"],
                    min_time_limit,
                    # Fall back to a limit of 20 when the ground truth has no usable
                    # time for this task (failed, skipped, or absent from expected_time).
                    expected_time.get(task_id) or 20,
                )
                futures.append(executor.submit(check_correctness, *args))
                completion_id[task_id] += 1
                n_samples += 1

            assert n_samples == len(remainings), "Missing problems in unfinished"
            assert len(completion_id) == len(problems), "Missing problems in samples"

            def stucking_checker():
                # Watchdog: warn whenever 240s pass without any sample finishing.
                not_done = futures
                while not_done:
                    done, not_done = wait(not_done, timeout=240, return_when=FIRST_COMPLETED)
                    if not done:
                        warn("No samples have finished testing in the last 240s")
                        warn(f"{len(remainings)} samples to be tested: {remainings}")

            # Daemon: the watchdog only reports progress and must not keep the process alive.
            threading.Thread(target=stucking_checker, daemon=True).start()

            for future in tqdm(as_completed(futures), total=n_samples):
                result = future.result()
                remainings.remove(result["_identifier"])
                eval_results[result["task_id"]].append(result)

        # Order each task's results by completion_id so output is deterministic.
        for task_id, task_results in eval_results.items():
            task_results.sort(key=lambda x: x["completion_id"])
            results["eval"][task_id] = []
            for res in task_results:
                stat, details = res["base"]
                results["eval"][task_id].append(
                    {
                        "task_id": task_id,
                        "solution": res["solution"],
                        "status": stat,
                        "details": details,
                    }
                )

    pass_at_k = _compute_pass_at_k(results["eval"], problems, k_values)
    pass_at_k["gt_pass_rate"] = float(gt_pass_rate)
    pass_at_k["failed_tasks"] = failed_tasks
    return pass_at_k
def run_gradio():
    """Build and launch the Gradio web UI that wraps ``evaluate``.

    The input components must stay in the same order as ``evaluate``'s parameters
    (split, subset, samples, pass_k, parallel, min_time_limit, max_as_limit,
    max_data_limit, max_stack_limit, check_gt_only, no_gt), because
    ``gr.Interface`` passes input values positionally.
    """
    n_cpus = multiprocessing.cpu_count()
    interface = gr.Interface(
        fn=evaluate,
        inputs=[
            # Explicit defaults so evaluate() never receives None here; it
            # concatenates ``subset`` into result file names.
            gr.Dropdown(["complete", "instruct"], label="Split", value="complete"),
            gr.Dropdown(["full", "hard"], label="Subset", value="full"),
            gr.File(label="Samples Path (.jsonl)"),
            gr.Textbox(label="Pass k Values (comma-separated)", value="1,5,10"),
            # Mirror evaluate()'s own default of half the available cores.
            gr.Slider(1, n_cpus, step=1, label="Parallel Workers", value=max(1, n_cpus // 2)),
            gr.Slider(0.1, 10, step=0.1, label="Min Time Limit", value=1),
            # Range starts at 1024 so the 30 * 1024 default lies on the 1024-step grid
            # (with a minimum of 1, the grid is 1, 1025, 2049, ... and excludes it).
            gr.Slider(1024, 100 * 1024, step=1024, label="Max AS Limit", value=30 * 1024),
            gr.Slider(1024, 100 * 1024, step=1024, label="Max Data Limit", value=30 * 1024),
            gr.Slider(1, 100, step=1, label="Max Stack Limit", value=10),
            gr.Checkbox(label="Check GT Only"),
            gr.Checkbox(label="No GT"),
        ],
        outputs="text",
    )
    # No per-event concurrency cap: several evaluations may run at the same time.
    interface.queue(default_concurrency_limit=None)
    interface.launch(show_error=True)


if __name__ == "__main__":
    run_gradio()
    # Example of a direct (non-UI) call:
    # evaluate("complete", "hard", "meta-llama--Llama-3.2-3B-Instruct--bigcodebench-instruct--vllm-0-1.jsonl")