"""Evaluate predicted SQL queries against gold queries using test-suite databases.

Each prediction passes only if it produces the same result as the gold query
on every database in the gold entry's test suite.  Judgements are computed in
parallel and optionally memoized in a pickle-backed cache shared across runs.
"""
import argparse
import os
import pickle as pkl
import time
from collections import defaultdict
from itertools import repeat
from multiprocessing import Manager, Pool, cpu_count
from typing import Any, Dict, List, Tuple

import tqdm

from .exec_eval import exec_on_db, result_eq

# Use roughly a third of the available cores, but always at least one process.
NUM_PROCESSES = cpu_count() // 3
if NUM_PROCESSES == 0:
    NUM_PROCESSES = 1

# A predicted query is given MULTIPLICATIVE_OVERHEAD times the gold query's
# measured runtime, plus ADDITIVE_OVERHEAD seconds, before it is timed out.
MULTIPLICATIVE_OVERHEAD = 3
ADDITIVE_OVERHEAD = 30
# Hard timeout (seconds) for executing the gold query itself.
GOLD_TIMEOUT = 100

cache_path = "cache.pkl"

# Manager-backed dict so the worker processes in the Pool share one cache.
m = Manager()
cache = m.dict()


def load_predictions(f_path: str) -> List[str]:
    """Read one predicted SQL query per line from *f_path*, stripped of whitespace."""
    preds = []
    with open(f_path, "r") as in_file:
        for line in in_file:
            preds.append(line.strip())
    return preds


def acc(l, idxes=None):
    """Return the fraction of truthy entries of *l* at positions *idxes*.

    When *idxes* is None, every position is considered.  An empty index list
    yields 0.0 instead of raising ZeroDivisionError.
    """
    if idxes is None:
        idxes = list(range(len(l)))
    if not idxes:
        return 0.0
    c = sum(1 for idx in idxes if l[idx])
    return float(c) / len(idxes)


# the input is a tuple of gold_dict, model prediction and whether to use cache
# and the output is whether the model prediction passes the entire test suite
def judge(args: Tuple[Dict[str, Any], str, bool]) -> bool:
    gold_dict, pred, use_cache = args
    testsuite_paths = gold_dict["testsuite"]
    gold_query = gold_dict["query"]
    # Row order is only significant when the gold query itself orders rows.
    order_matters = "order by" in gold_query.lower()
    db_path = gold_dict["db_path"]

    # if already computed sometime before
    # and cache allowed, directly return the result
    k = (db_path, gold_query, pred)
    if use_cache and k in cache:
        return cache[k]

    pass_all_testcase = True
    for testcase_path in testsuite_paths:
        start = time.time()
        flg, gold_result = exec_on_db(testcase_path, gold_query, timeout=GOLD_TIMEOUT)
        duration = time.time() - start
        # Allow the prediction proportionally more time than the gold query took.
        timeout = ADDITIVE_OVERHEAD + MULTIPLICATIVE_OVERHEAD * duration

        if flg != "result":
            # NOTE(review): a gold-query failure skips this testcase entirely,
            # so a prediction can still pass even if every gold execution
            # failed — confirm this lenient policy is intended.
            print("Warning: executing gold query results in an exception")
            continue
        flg, pred_result = exec_on_db(testcase_path, pred, timeout=int(timeout))
        if flg != "result":
            pass_all_testcase = False
            break
        if not result_eq(gold_result, pred_result, order_matters):
            pass_all_testcase = False
            break

    # save the results in the cache
    if use_cache:
        cache[k] = pass_all_testcase
    return pass_all_testcase


# cache is a dictionary
# the key is a ternary tuple (empty_database_path, SQL1, SQL2)
# the value is whether SQL1 and SQL2 are equivalent, judged by the test suites
def load_cache() -> Dict[Tuple[str, str, str], bool]:
    """Populate the shared cache from *cache_path* (if it exists) and return it."""
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            # Copy into the Manager dict so worker processes see the entries.
            for k, v in pkl.load(f).items():
                cache[k] = v
    return cache


def save_cache():
    """Dump the shared cache to *cache_path* as a plain dict."""
    with open(cache_path, "wb") as f:
        pkl.dump(dict(cache), f)


def main(
    preds: List[str],
    gold_file: str = "classical_test.pkl",
    verbose: bool = True,
    num_processes: int = NUM_PROCESSES,
    subset: str = "full",
    use_cache: bool = True,
) -> List[bool]:
    """Judge every prediction against its gold test suite in parallel.

    Args:
        preds: predicted SQL queries, aligned with the gold entries.
        gold_file: pickle file holding the list of gold dicts.
        verbose: print overall and per-database accuracy when True.
        num_processes: size of the worker pool.
        subset: a single database id to restrict evaluation to, or "full".
        use_cache: reuse/record judgements in the shared cache.

    Returns:
        One boolean per prediction: whether it passed its full test suite.
    """
    with open(gold_file, "rb") as f:
        gold_dicts = pkl.load(f)
    if subset != "full":
        gold_dicts = [
            d
            for d in gold_dicts
            if d["db_path"] == "database/{db_id}/{db_id}.sqlite".format(db_id=subset)
        ]
    assert len(gold_dicts) == len(
        preds
    ), "number of gold and prediction should be equal"

    # Group example indices by database so per-database accuracy can be reported.
    group_name2idxes = defaultdict(list)
    for idx, gold_dict in enumerate(gold_dicts):
        group_name2idxes[gold_dict["db_id"]].append(idx)

    with Pool(num_processes) as pool:
        result = list(
            tqdm.tqdm(
                pool.imap(judge, zip(gold_dicts, preds, repeat(use_cache, len(preds)))),
                total=len(gold_dicts),
            )
        )
    if verbose:
        print("overall accuracy: ", acc(result))
        for group, idxes in group_name2idxes.items():
            print("accuracy for ", group, acc(result, idxes))
    return result


if __name__ == "__main__":
    start = time.time()
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gold",
        dest="gold",
        type=str,
        default="classical_test.pkl",
        help="the path to the predicted queries",
    )
    parser.add_argument(
        "--pred", dest="pred", type=str, help="the path to the predicted queries"
    )
    parser.add_argument(
        "--out_file", type=str, required=True, help="the output file path"
    )
    parser.add_argument(
        "--num_processes",
        type=int,  # without this, a user-supplied value is a str and Pool() fails
        default=NUM_PROCESSES,
        help="number of processes to use",
    )
    parser.add_argument(
        "--subset",
        default="full",
        choices=(
            "atis",
            "advising",
            "academic",
            "imdb",
            "restaurants",
            "geography",
            "scholar",
            "yelp",
            "full",
        ),
        help="which subset to evaluate on.",
    )
    parser.add_argument(
        "--disable_cache",
        default=False,
        action="store_true",
        help="whether to directly apply previously computed result and cache the current results. "
        "use this flag to disable caching.",
    )
    args = parser.parse_args()

    preds = load_predictions(args.pred)
    assert not os.path.exists(args.out_file), (
        "output file path %s already exists" % args.out_file
    )

    use_cache = not args.disable_cache
    if use_cache:
        load_cache()

    result = main(
        preds=preds,
        gold_file=args.gold,
        verbose=True,
        num_processes=args.num_processes,
        subset=args.subset,
        use_cache=use_cache,
    )

    with open(args.out_file, "wb") as f:
        pkl.dump(result, f)
    print("total time used: ", time.time() - start)
    if use_cache:
        save_cache()