# Evaluate predicted SQL queries against gold queries using per-database test suites.
import argparse | |
from typing import List, Dict, Any, Tuple | |
import pickle as pkl | |
import tqdm | |
from .exec_eval import exec_on_db, result_eq | |
import os | |
from collections import defaultdict | |
import time | |
from multiprocessing import cpu_count, Pool, Manager | |
from itertools import repeat | |
# Use roughly a third of the available cores for parallel judging;
# guarantee at least one worker process.
NUM_PROCESSES = cpu_count() // 3
if NUM_PROCESSES == 0:
    NUM_PROCESSES = 1
# Prediction timeout (seconds) = ADDITIVE_OVERHEAD + MULTIPLICATIVE_OVERHEAD * gold runtime.
MULTIPLICATIVE_OVERHEAD = 3
ADDITIVE_OVERHEAD = 30
# Hard timeout (seconds) for executing the gold query itself.
GOLD_TIMEOUT = 100
# On-disk location of the pickled result cache.
cache_path = "cache.pkl"
# Manager-backed dict so worker processes all share one result cache.
# NOTE(review): creating a Manager at import time spawns a server process
# whenever this module is imported — confirm this is intended.
m = Manager()
cache = m.dict()
def load_predictions(f_path: str) -> List[str]:
    """Read one predicted SQL query per line from f_path, stripped of surrounding whitespace."""
    with open(f_path, "r") as in_file:
        return [line.strip() for line in in_file]
def acc(l, idxes=None):
    """Return the fraction of truthy entries of l at positions idxes.

    Args:
        l: sequence of boolean-ish judgement results.
        idxes: indices of l to score; defaults to all indices.

    Returns:
        Accuracy in [0, 1] as a float; 0.0 when there is nothing to score
        (the original raised ZeroDivisionError on an empty index list).
    """
    if idxes is None:
        idxes = range(len(l))
    if not idxes:
        # robustness fix: avoid ZeroDivisionError on empty input
        return 0.0
    return float(sum(1 for idx in idxes if l[idx])) / len(idxes)
# the input is a tuple of (gold_dict, model prediction, whether to use the cache)
# and the output is whether the model prediction passes the entire test suite
def judge(args: Tuple[Dict[str, Any], str, bool]) -> bool:
    """Return whether a predicted query passes every test case in the gold
    entry's test suite; verdicts are memoized in the shared cache when allowed.
    """
    gold_dict, pred, use_cache = args
    gold_query = gold_dict["query"]
    db_path = gold_dict["db_path"]
    # row order only matters when the gold query imposes one
    order_matters = "order by" in gold_query.lower()
    cache_key = (db_path, gold_query, pred)
    # reuse a previously computed verdict when caching is enabled
    if use_cache and cache_key in cache:
        return cache[cache_key]
    passed = True
    for testcase_path in gold_dict["testsuite"]:
        gold_start = time.time()
        flg, gold_result = exec_on_db(testcase_path, gold_query, timeout=GOLD_TIMEOUT)
        elapsed = time.time() - gold_start
        if flg != "result":
            # best-effort: skip test cases where the gold query itself fails
            print("Warning: executing gold query results in an exception")
            continue
        # give the prediction a time budget proportional to the gold runtime
        pred_timeout = int(ADDITIVE_OVERHEAD + MULTIPLICATIVE_OVERHEAD * elapsed)
        flg, pred_result = exec_on_db(testcase_path, pred, timeout=pred_timeout)
        if flg != "result" or not result_eq(gold_result, pred_result, order_matters):
            passed = False
            break
    # record the verdict for future runs
    if use_cache:
        cache[cache_key] = passed
    return passed
# cache is a dictionary
# the key is a 3-tuple (empty_database_path, SQL1, SQL2)
# the value is whether SQL1 and SQL2 are equivalent, judged by the test suites
def load_cache() -> Dict[Tuple[str, str, str], bool]:
    """Load previously computed equivalence verdicts from disk into the shared cache.

    Keys are (empty_database_path, SQL1, SQL2); values say whether the two
    queries were judged equivalent by the test suites. Missing cache file is
    treated as an empty cache.

    Returns:
        The shared (Manager-backed) cache dict, updated in place.
    """
    if os.path.exists(cache_path):
        # fix: close the file deterministically instead of relying on GC;
        # also skip the pointless intermediate Manager dict copy
        with open(cache_path, "rb") as f:
            d = pkl.load(f)
        for k, v in d.items():
            cache[k] = v
    return cache
# dump the cache | |
def save_cache():
    """Persist the shared cache to cache_path as a plain pickled dict."""
    # fix: use a context manager so the handle is flushed and closed deterministically
    with open(cache_path, "wb") as f:
        pkl.dump(dict(cache), f)
def main(
    preds: List[str],
    gold_file: str = "classical_test.pkl",
    verbose: bool = True,
    num_processes: int = NUM_PROCESSES,
    subset: str = "full",
    use_cache: bool = True,
) -> List[bool]:
    """Judge each predicted query against its gold entry in parallel.

    Args:
        preds: predicted SQL queries, aligned index-for-index with gold_file.
        gold_file: pickle of gold dicts (keys used here: "query", "db_path",
            "db_id", "testsuite").
        verbose: print overall and per-database accuracy.
        num_processes: size of the worker pool.
        subset: restrict evaluation to one database id, or "full" for all.
        use_cache: reuse/record verdicts in the shared cache.

    Returns:
        One bool per prediction: True iff it passed its whole test suite.
    """
    # fix: close the gold file deterministically
    with open(gold_file, "rb") as f:
        gold_dicts = pkl.load(f)
    if subset != "full":
        gold_dicts = [
            d
            for d in gold_dicts
            if d["db_path"] == "database/{db_id}/{db_id}.sqlite".format(db_id=subset)
        ]
    assert len(gold_dicts) == len(
        preds
    ), "number of gold and prediction should be equal"
    # group example indices by database id for per-database accuracy reporting
    group_name2idxes = defaultdict(list)
    for idx, gold_dict in enumerate(gold_dicts):
        group_name2idxes[gold_dict["db_id"]].append(idx)
    with Pool(num_processes) as pool:
        result = list(
            tqdm.tqdm(
                pool.imap(judge, zip(gold_dicts, preds, repeat(use_cache, len(preds)))),
                total=len(gold_dicts),
            )
        )
    if verbose:
        print("overall accuracy: ", acc(result))
        for group, idxes in group_name2idxes.items():
            print("accuracy for ", group, acc(result, idxes))
    return result
if __name__ == "__main__":
    start = time.time()
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gold",
        dest="gold",
        type=str,
        default="classical_test.pkl",
        # fix: help text was a copy-paste of --pred's
        help="the path to the gold queries",
    )
    parser.add_argument(
        "--pred", dest="pred", type=str, help="the path to the predicted queries"
    )
    parser.add_argument(
        "--out_file", type=str, required=True, help="the output file path"
    )
    parser.add_argument(
        "--num_processes",
        # fix: without type=int argparse yields a str and Pool(str) raises TypeError
        type=int,
        default=NUM_PROCESSES,
        help="number of processes to use",
    )
    parser.add_argument(
        "--subset",
        default="full",
        choices=(
            "atis",
            "advising",
            "academic",
            "imdb",
            "restaurants",
            "geography",
            "scholar",
            "yelp",
            "full",
        ),
        help="which subset to evaluate on.",
    )
    parser.add_argument(
        "--disable_cache",
        default=False,
        action="store_true",
        help="whether to directly apply previously computed result and cache the current results. "
        "use this flag to disable caching.",
    )
    args = parser.parse_args()

    preds = load_predictions(args.pred)
    # refuse to clobber an existing result file
    assert not os.path.exists(args.out_file), (
        "output file path %s already exists" % args.out_file
    )

    use_cache = not args.disable_cache
    if use_cache:
        load_cache()

    result = main(
        preds=preds,
        gold_file=args.gold,
        verbose=True,
        num_processes=args.num_processes,
        subset=args.subset,
        use_cache=use_cache,
    )
    # fix: close the output file deterministically
    with open(args.out_file, "wb") as out_f:
        pkl.dump(result, out_f)
    print("total time used: ", time.time() - start)
    if use_cache:
        save_cache()