################################
# Spider-style SQL parse-tree layout used throughout this module:
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################

import os
import json
import sqlite3
import argparse

from .process_sql import get_schema, Schema, get_sql
from .exec_eval import eval_exec_match

# Difficulty/category buckets under which scores are accumulated and reported.
LEVELS = ["easy", "medium", "hard", "duckdb", "ddl", "all"]
# Interaction-turn buckets (multi-turn / SParC-style sessions).
TURNS = ["turn 1", "turn 2", "turn 3", "turn 4", "turn > 4"]
# Clause-level components scored by partial matching.
PARTIAL_TYPES = [
    "select",
    "select(no AGG)",
    "where",
    "where(no OP)",
    "group(no Having)",
    "group",
    "order",
    "and/or",
    "IUEN",
    "keywords",
]

# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True

CLAUSE_KEYWORDS = (
    "select",
    "from",
    "where",
    "group",
    "order",
    "limit",
    "intersect",
    "union",
    "except",
)
JOIN_KEYWORDS = ("join", "on", "as")
WHERE_OPS = (
    "not",
    "between",
    "=",
    ">",
    "<",
    ">=",
    "<=",
    "!=",
    "in",
    "like",
    "is",
    "exists",
)
UNIT_OPS = ("none", "-", "+", "*", "/")
AGG_OPS = ("none", "max", "min", "count", "sum", "avg")
TABLE_TYPE = {
    "sql": "sql",
    "table_unit": "table_unit",
}
COND_OPS = ("and", "or")
SQL_OPS = ("intersect", "union", "except")
ORDER_OPS = ("desc", "asc")
# Keyword groups used by eval_hardness to bucket query difficulty.
HARDNESS = {
    "component1": ("where", "group", "order", "limit", "join", "or", "like"),
    "component2": ("except", "union", "intersect"),
}


def condition_has_or(conds):
    """Return True if any conjunction in the condition list is 'or'."""
    # Odd indices of a condition list hold the 'and'/'or' connectors.
    return "or" in conds[1::2]


def condition_has_like(conds):
    """Return True if any cond_unit in the condition uses the LIKE operator."""
    return WHERE_OPS.index("like") in [cond_unit[1] for cond_unit in conds[::2]]


def condition_has_sql(conds):
    """Return True if any cond_unit compares against a nested subquery."""
    for cond_unit in conds[::2]:
        val1, val2 = cond_unit[3], cond_unit[4]
        if val1 is not None and type(val1) is dict:
            return True
        if val2 is not None and type(val2) is dict:
            return True
    return False


def val_has_op(val_unit):
    """Return True if the val_unit applies an arithmetic unit op."""
    return val_unit[0] != UNIT_OPS.index("none")


def has_agg(unit):
    """Return True if the unit's first element is a non-'none' aggregation id."""
    return unit[0] != AGG_OPS.index("none")


def accuracy(count, total):
    """All-or-nothing accuracy: 1 iff every item matched."""
    if count == total:
        return 1
    return 0


def recall(count, total):
    """All-or-nothing recall: 1 iff every item matched."""
    if count == total:
        return 1
    return 0


def F1(acc, rec):
    """Harmonic mean of accuracy and recall (0 when both are 0)."""
    if (acc + rec) == 0:
        return 0
    return (2.0 * acc * rec) / (acc + rec)


def get_scores(count, pred_total, label_total):
    """Return (acc, rec, f1) as an all-or-nothing triple.

    Scores are 1 only when prediction and label have the same number of
    components AND every one of them matched; otherwise all three are 0.
    """
    if pred_total != label_total:
        return 0, 0, 0
    elif count == pred_total:
        return 1, 1, 1
    return 0, 0, 0


def eval_sel(pred, label):
    """Compare SELECT clauses.

    Returns (label_total, pred_total, cnt, cnt_wo_agg) where cnt counts exact
    (agg, val_unit) matches and cnt_wo_agg ignores the aggregation.
    Note: matched items are removed from the label lists (multiset matching),
    which mutates label["select"][1] in place — upstream Spider behavior.
    """
    pred_sel = pred["select"][1]
    label_sel = label["select"][1]
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_where(pred, label):
    """Compare WHERE cond_units; cnt_wo_agg matches on the val_unit only.

    Returns (label_total, pred_total, cnt, cnt_wo_agg). Uses multiset
    matching (matched items removed from the label lists).
    """
    pred_conds = [unit for unit in pred["where"][::2]]
    label_conds = [unit for unit in label["where"][::2]]
    label_wo_agg = [unit[2] for unit in label_conds]
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
        if unit[2] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[2])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_group(pred, label):
    """Compare GROUP BY columns, ignoring table qualifiers.

    Returns (label_total, pred_total, cnt).
    """
    pred_cols = [unit[1] for unit in pred["groupBy"]]
    label_cols = [unit[1] for unit in label["groupBy"]]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    # Strip "table." prefixes so columns compare by bare name.
    pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
    label_cols = [
        label.split(".")[1] if "." in label else label for label in label_cols
    ]
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return label_total, pred_total, cnt


def eval_having(pred, label):
    """Score GROUP BY + HAVING as a single unit.

    cnt is 1 only when both queries group by identical columns and have
    identical HAVING clauses. Returns (label_total, pred_total, cnt).
    """
    pred_total = label_total = cnt = 0
    if len(pred["groupBy"]) > 0:
        pred_total = 1
    if len(label["groupBy"]) > 0:
        label_total = 1

    pred_cols = [unit[1] for unit in pred["groupBy"]]
    label_cols = [unit[1] for unit in label["groupBy"]]
    if (
        pred_total == label_total == 1
        and pred_cols == label_cols
        and pred["having"] == label["having"]
    ):
        cnt = 1

    return label_total, pred_total, cnt


def eval_order(pred, label):
    """Score ORDER BY (+ LIMIT presence) as a single unit.

    cnt is 1 when the ORDER BY clauses are identical and LIMIT is either
    present in both or absent in both. Returns (label_total, pred_total, cnt).
    """
    pred_total = label_total = cnt = 0
    if len(pred["orderBy"]) > 0:
        pred_total = 1
    if len(label["orderBy"]) > 0:
        label_total = 1
    if (
        len(label["orderBy"]) > 0
        and pred["orderBy"] == label["orderBy"]
        and (
            (pred["limit"] is None and label["limit"] is None)
            or (pred["limit"] is not None and label["limit"] is not None)
        )
    ):
        cnt = 1
    return label_total, pred_total, cnt


def eval_and_or(pred, label):
    """Compare the sets of 'and'/'or' connectors used in WHERE."""
    pred_ao = pred["where"][1::2]
    label_ao = label["where"][1::2]
    pred_ao = set(pred_ao)
    label_ao = set(label_ao)

    if pred_ao == label_ao:
        return 1, 1, 1
    return len(pred_ao), len(label_ao), 0


def get_nestedSQL(sql):
    """Collect all nested subqueries (condition values and IUE components)."""
    nested = []
    for cond_unit in sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2]:
        if type(cond_unit[3]) is dict:
            nested.append(cond_unit[3])
        if type(cond_unit[4]) is dict:
            nested.append(cond_unit[4])
    if sql["intersect"] is not None:
        nested.append(sql["intersect"])
    if sql["except"] is not None:
        nested.append(sql["except"])
    if sql["union"] is not None:
        nested.append(sql["union"])
    return nested


def eval_nested(pred, label):
    """Exact-match a pair of optional nested queries.

    Returns (label_total, pred_total, cnt) with totals counting presence.
    """
    label_total = 0
    pred_total = 0
    cnt = 0
    if pred is not None:
        pred_total += 1
    if label is not None:
        label_total += 1
    if pred is not None and label is not None:
        partial_scores = Evaluator.eval_partial_match(pred, label)
        cnt += Evaluator.eval_exact_match(pred, label, partial_scores)
    return label_total, pred_total, cnt


def eval_IUEN(pred, label):
    """Score INTERSECT / EXCEPT / UNION components via eval_nested."""
    lt1, pt1, cnt1 = eval_nested(pred["intersect"], label["intersect"])
    lt2, pt2, cnt2 = eval_nested(pred["except"], label["except"])
    lt3, pt3, cnt3 = eval_nested(pred["union"], label["union"])
    label_total = lt1 + lt2 + lt3
    pred_total = pt1 + pt2 + pt3
    cnt = cnt1 + cnt2 + cnt3
    return label_total, pred_total, cnt


def get_keywords(sql):
    """Return the set of SQL keywords present in the parsed query."""
    res = set()
    if len(sql["where"]) > 0:
        res.add("where")
    if len(sql["groupBy"]) > 0:
        res.add("group")
    if len(sql["having"]) > 0:
        res.add("having")
    if len(sql["orderBy"]) > 0:
        res.add(sql["orderBy"][0])  # 'asc' or 'desc'
        res.add("order")
    if sql["limit"] is not None:
        res.add("limit")
    if sql["except"] is not None:
        res.add("except")
    if sql["union"] is not None:
        res.add("union")
    if sql["intersect"] is not None:
        res.add("intersect")

    # or keyword
    ao = sql["from"]["conds"][1::2] + sql["where"][1::2] + sql["having"][1::2]
    if len([token for token in ao if token == "or"]) > 0:
        res.add("or")

    cond_units = sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2]
    # not keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
        res.add("not")

    # in keyword
    if (
        len(
            [
                cond_unit
                for cond_unit in cond_units
                if cond_unit[1] == WHERE_OPS.index("in")
            ]
        )
        > 0
    ):
        res.add("in")

    # like keyword
    if (
        len(
            [
                cond_unit
                for cond_unit in cond_units
                if cond_unit[1] == WHERE_OPS.index("like")
            ]
        )
        > 0
    ):
        res.add("like")

    return res


def eval_keywords(pred, label):
    """Compare the keyword sets of prediction and label."""
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = 0

    for k in pred_keywords:
        if k in label_keywords:
            cnt += 1
    return label_total, pred_total, cnt


def count_agg(units):
    """Count units whose first element is a non-'none' aggregation id."""
    return len([unit for unit in units if has_agg(unit)])


def count_component1(sql):
    """Count 'component1' hardness features (clauses, joins, OR, LIKE)."""
    count = 0
    if len(sql["where"]) > 0:
        count += 1
    if len(sql["groupBy"]) > 0:
        count += 1
    if len(sql["orderBy"]) > 0:
        count += 1
    if sql["limit"] is not None:
        count += 1
    if len(sql["from"]["table_units"]) > 0:  # JOIN
        count += len(sql["from"]["table_units"]) - 1

    ao = sql["from"]["conds"][1::2] + sql["where"][1::2] + sql["having"][1::2]
    count += len([token for token in ao if token == "or"])
    cond_units = sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2]
    count += len(
        [
            cond_unit
            for cond_unit in cond_units
            if cond_unit[1] == WHERE_OPS.index("like")
        ]
    )

    return count


def count_component2(sql):
    """Count 'component2' hardness features (nested subqueries / IUE)."""
    nested = get_nestedSQL(sql)
    return len(nested)


def count_others(sql):
    """Count 'others' hardness features (multiple aggs/columns/conds/groups)."""
    count = 0
    # number of aggregation
    agg_count = count_agg(sql["select"][1])
    agg_count += count_agg(sql["where"][::2])
    agg_count += count_agg(sql["groupBy"])
    if len(sql["orderBy"]) > 0:
        agg_count += count_agg(
            [unit[1] for unit in sql["orderBy"][1] if unit[1]]
            + [unit[2] for unit in sql["orderBy"][1] if unit[2]]
        )
    # NOTE(review): this applies count_agg to the raw condition list (units
    # are cond_units and 'and'/'or' strings), matching upstream Spider.
    agg_count += count_agg(sql["having"])
    if agg_count > 1:
        count += 1

    # number of select columns
    if len(sql["select"][1]) > 1:
        count += 1

    # number of where conditions
    if len(sql["where"]) > 1:
        count += 1

    # number of group by clauses
    if len(sql["groupBy"]) > 1:
        count += 1

    return count


class Evaluator:
    """Accumulates exact-match, partial-match and execution scores.

    Scores are bucketed by difficulty level / category, by turn index, and by
    an overall "joint_all" session bucket used for multi-turn data.
    """

    def __init__(
        self,
        db_dir,
        kmaps,
        etype,
        plug_value,
        keep_distinct,
        progress_bar_for_each_datapoint,
    ):
        self.db_dir = db_dir
        self.kmaps = kmaps
        self.etype = etype  # 'all', 'exec' or 'match'
        self.plug_value = plug_value
        self.keep_distinct = keep_distinct
        self.progress_bar_for_each_datapoint = progress_bar_for_each_datapoint
        self.db_paths = {}  # db_name -> path, filled lazily
        self.schemas = {}  # db_name -> Schema, filled lazily

        self.scores = {}
        for turn in TURNS:
            self.scores[turn] = {"count": 0, "exact": 0.0, "exec": 0}
        # "joint_all" aggregates whole-session correctness; evaluate()
        # increments it once per session, so it needs a bucket too
        # (previously missing, which raised KeyError).
        for level in LEVELS + ["joint_all"]:
            self.scores[level] = self._new_level_bucket()

    @staticmethod
    def _new_level_bucket():
        """Return a fresh per-level score accumulator."""
        return {
            "count": 0,
            "exact": 0.0,
            "exec": 0,
            "partial": {
                type_: {
                    "acc": 0.0,
                    "rec": 0.0,
                    "f1": 0.0,
                    "acc_count": 0,
                    "rec_count": 0,
                }
                for type_ in PARTIAL_TYPES
            },
        }

    def eval_hardness(self, sql):
        """Bucket a parsed query into easy/medium/hard/extra difficulty."""
        count_comp1_ = count_component1(sql)
        count_comp2_ = count_component2(sql)
        count_others_ = count_others(sql)

        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or (
            count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0
        ):
            return "medium"
        elif (
            (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0)
            or (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0)
            or (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1)
        ):
            return "hard"
        else:
            return "extra"

    @classmethod
    def eval_exact_match(cls, pred, label, partial_scores):
        """Return 1 iff all partial f1s are 1 and the FROM tables match."""
        for key, score in partial_scores.items():
            if score["f1"] != 1:
                return 0

        if len(label["from"]["table_units"]) > 0:
            label_tables = sorted(label["from"]["table_units"])
            pred_tables = sorted(pred["from"]["table_units"])
            return label_tables == pred_tables
        return 1

    @classmethod
    def eval_partial_match(cls, pred, label):
        """Compute per-component (acc, rec, f1, totals) for every PARTIAL_TYPE."""
        res = {}

        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["select"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res["select(no AGG)"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["where"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res["where(no OP)"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt = eval_group(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["group(no Having)"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt = eval_having(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["group"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt = eval_order(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["order"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt = eval_and_or(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["and/or"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt = eval_IUEN(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["IUEN"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt = eval_keywords(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["keywords"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        return res

    def evaluate_one(
        self,
        db_name,
        gold,
        predicted,
        setup_sql,
        validate_sql,
        turn_scores,
        idx,
        category,
    ):
        """Score one (gold, predicted) pair and fold it into the accumulators.

        idx is the 0-based turn index within a session; category is the score
        bucket this datapoint belongs to. Returns a per-item result dict with
        the prediction, the gold query and the scores that were computed.
        """
        if db_name not in self.db_paths:
            db_path = os.path.join(self.db_dir, db_name, db_name + ".duckdb")
            self.db_paths[db_name] = db_path
            self.schemas[db_name] = Schema(get_schema(db_path))

        # Map the 0-based turn index to the TURNS bucket name.
        if idx > 3:
            idx = "> 4"
        else:
            idx += 1
        turn_id = "turn " + str(idx)

        # The caller-supplied category is used directly as a score bucket.
        # NOTE(review): evaluate() passes "" here — create unknown buckets
        # lazily so scoring does not crash; confirm the intended categories.
        hardness = category
        if hardness not in self.scores:
            self.scores[hardness] = self._new_level_bucket()

        self.scores[turn_id]["count"] += 1
        self.scores[hardness]["count"] += 1
        self.scores["all"]["count"] += 1

        if self.etype in ["all", "match"]:
            schema = self.schemas[db_name]
            g_sql = get_sql(schema, gold)
            # (a duplicate scores[hardness]["count"] increment was removed
            # here; it inflated the denominator used by finalize())
            try:
                p_sql = get_sql(schema, predicted)
            except Exception:
                # If p_sql is not valid, then we will use an empty sql to
                # evaluate with the correct sql
                p_sql = {
                    "except": None,
                    "from": {"conds": [], "table_units": []},
                    "groupBy": [],
                    "having": [],
                    "intersect": None,
                    "limit": None,
                    "orderBy": [],
                    "select": [False, []],
                    "union": None,
                    "where": [],
                }

        if self.etype in ["all", "exec"]:
            exec_score = eval_exec_match(
                db=self.db_paths[db_name],
                p_str=predicted,
                g_str=gold,
                setup_sql=setup_sql,
                validate_sql=validate_sql,
                plug_value=self.plug_value,
                keep_distinct=self.keep_distinct,
                progress_bar_for_each_datapoint=self.progress_bar_for_each_datapoint,
            )
            if exec_score:
                self.scores[hardness]["exec"] += 1
                self.scores[turn_id]["exec"] += 1
                self.scores["all"]["exec"] += 1
                turn_scores["exec"].append(1)
            else:
                turn_scores["exec"].append(0)

        if self.etype in ["all", "match"]:
            # rebuild sql for value evaluation
            kmap = self.kmaps[db_name]
            g_valid_col_units = build_valid_col_units(
                g_sql["from"]["table_units"], schema
            )
            g_sql = rebuild_sql_val(g_sql)
            g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
            p_valid_col_units = build_valid_col_units(
                p_sql["from"]["table_units"], schema
            )
            p_sql = rebuild_sql_val(p_sql)
            p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)

            partial_scores = self.eval_partial_match(p_sql, g_sql)
            exact_score = self.eval_exact_match(p_sql, g_sql, partial_scores)
            if exact_score == 0:
                turn_scores["exact"].append(0)
                print("{} pred: {}".format(hardness, predicted))
                print("{} gold: {}".format(hardness, gold))
                print("")
            else:
                turn_scores["exact"].append(1)
            self.scores[turn_id]["exact"] += exact_score
            self.scores[hardness]["exact"] += exact_score
            self.scores["all"]["exact"] += exact_score

            for type_ in PARTIAL_TYPES:
                if partial_scores[type_]["pred_total"] > 0:
                    self.scores[hardness]["partial"][type_]["acc"] += partial_scores[
                        type_
                    ]["acc"]
                    self.scores[hardness]["partial"][type_]["acc_count"] += 1
                if partial_scores[type_]["label_total"] > 0:
                    self.scores[hardness]["partial"][type_]["rec"] += partial_scores[
                        type_
                    ]["rec"]
                    self.scores[hardness]["partial"][type_]["rec_count"] += 1
                self.scores[hardness]["partial"][type_]["f1"] += partial_scores[type_][
                    "f1"
                ]
                if partial_scores[type_]["pred_total"] > 0:
                    self.scores["all"]["partial"][type_]["acc"] += partial_scores[
                        type_
                    ]["acc"]
                    self.scores["all"]["partial"][type_]["acc_count"] += 1
                if partial_scores[type_]["label_total"] > 0:
                    self.scores["all"]["partial"][type_]["rec"] += partial_scores[
                        type_
                    ]["rec"]
                    self.scores["all"]["partial"][type_]["rec_count"] += 1
                self.scores["all"]["partial"][type_]["f1"] += partial_scores[type_][
                    "f1"
                ]

        result = {
            "predictSQL": predicted,
            "goldSQL": gold,
        }
        if self.etype in ["all", "match"]:
            result.update(
                {
                    "hardness": hardness,
                    "exact": exact_score,
                    "partial": partial_scores,
                }
            )
        if self.etype in ["all", "exec"]:
            result["exec"] = exec_score
        return result

    def finalize(self):
        """Normalize accumulated sums into averages (mutates self.scores)."""
        scores = self.scores
        for turn in TURNS:
            if scores[turn]["count"] == 0:
                continue
            if self.etype in ["all", "exec"]:
                scores[turn]["exec"] /= scores[turn]["count"]
            if self.etype in ["all", "match"]:
                scores[turn]["exact"] /= scores[turn]["count"]

        # "joint_all" is normalized the same way as the level buckets.
        for level in LEVELS + ["joint_all"]:
            if scores[level]["count"] == 0:
                continue
            if self.etype in ["all", "exec"]:
                scores[level]["exec"] /= scores[level]["count"]
            if self.etype in ["all", "match"]:
                scores[level]["exact"] /= scores[level]["count"]
                for type_ in PARTIAL_TYPES:
                    if scores[level]["partial"][type_]["acc_count"] == 0:
                        scores[level]["partial"][type_]["acc"] = 0
                    else:
                        scores[level]["partial"][type_]["acc"] = (
                            scores[level]["partial"][type_]["acc"]
                            / scores[level]["partial"][type_]["acc_count"]
                            * 1.0
                        )
                    if scores[level]["partial"][type_]["rec_count"] == 0:
                        scores[level]["partial"][type_]["rec"] = 0
                    else:
                        scores[level]["partial"][type_]["rec"] = (
                            scores[level]["partial"][type_]["rec"]
                            / scores[level]["partial"][type_]["rec_count"]
                            * 1.0
                        )
                    # Upstream convention: f1 defaults to 1 when both acc and
                    # rec are 0 (no components of this type were present).
                    if (
                        scores[level]["partial"][type_]["acc"] == 0
                        and scores[level]["partial"][type_]["rec"] == 0
                    ):
                        scores[level]["partial"][type_]["f1"] = 1
                    else:
                        scores[level]["partial"][type_]["f1"] = (
                            2.0
                            * scores[level]["partial"][type_]["acc"]
                            * scores[level]["partial"][type_]["rec"]
                            / (
                                scores[level]["partial"][type_]["rec"]
                                + scores[level]["partial"][type_]["acc"]
                            )
                        )


def isValidSQL(sql, db):
    """Return True iff `sql` executes without error against sqlite file `db`.

    The connection is always closed, whether execution succeeds or fails.
    """
    conn = sqlite3.connect(db)
    try:
        cursor = conn.cursor()
        cursor.execute(sql)
    except Exception:
        return False
    finally:
        conn.close()
    return True


def print_formated_s(row_name, l, element_format):
    """Print one table row: a 20-char row name followed by formatted cells."""
    template = "{:20} " + " ".join([element_format] * len(l))
    print(template.format(row_name, *l))


def print_scores(scores, etype, include_turn_acc=True):
    """Pretty-print the score tables produced by Evaluator.finalize()."""
    turns = TURNS
    levels = list(LEVELS)
    if include_turn_acc:
        levels.append("joint_all")
    partial_types = PARTIAL_TYPES

    print_formated_s("", levels, "{:20}")
    counts = [scores[level]["count"] for level in levels]
    print_formated_s("count", counts, "{:<20d}")

    if etype in ["all", "exec"]:
        print("===================== EXECUTION ACCURACY =====================")
        exec_scores = [scores[level]["exec"] for level in levels]
        print_formated_s("execution", exec_scores, "{:<20.3f}")

    if etype in ["all", "match"]:
        print("\n====================== EXACT MATCHING ACCURACY =====================")
        exact_scores = [scores[level]["exact"] for level in levels]
        print_formated_s("exact match", exact_scores, "{:<20.3f}")
        print("\n---------------------PARTIAL MATCHING ACCURACY----------------------")
        for type_ in partial_types:
            this_scores = [scores[level]["partial"][type_]["acc"] for level in levels]
            print_formated_s(type_, this_scores, "{:<20.3f}")
        print("---------------------- PARTIAL MATCHING RECALL ----------------------")
        for type_ in partial_types:
            this_scores = [scores[level]["partial"][type_]["rec"] for level in levels]
            print_formated_s(type_, this_scores, "{:<20.3f}")
        print("---------------------- PARTIAL MATCHING F1 --------------------------")
        for type_ in partial_types:
            this_scores = [scores[level]["partial"][type_]["f1"] for level in levels]
            print_formated_s(type_, this_scores, "{:<20.3f}")

    if include_turn_acc:
        print()
        print()
        print_formated_s("", turns, "{:20}")
        counts = [scores[turn]["count"] for turn in turns]
        print_formated_s("count", counts, "{:<20d}")

        if etype in ["all", "exec"]:
            print(
                "===================== TURN EXECUTION ACCURACY ====================="
            )
            exec_scores = [scores[turn]["exec"] for turn in turns]
            print_formated_s("execution", exec_scores, "{:<20.3f}")

        if etype in ["all", "match"]:
            print(
                "\n====================== TURN EXACT MATCHING ACCURACY ====================="
            )
            exact_scores = [scores[turn]["exact"] for turn in turns]
            print_formated_s("exact match", exact_scores, "{:<20.3f}")


def evaluate(
    gold,
    predict,
    db_dir,
    etype,
    kmaps,
    plug_value,
    keep_distinct,
    progress_bar_for_each_datapoint,
):
    """Run the full evaluation over a gold file and a prediction file.

    Each file contains tab-separated lines; blank lines separate multi-turn
    sessions. Returns {"per_item": [...], "total_scores": {...}}.
    """
    with open(gold) as f:
        glist = []
        gseq_one = []
        for line in f.readlines():
            if len(line.strip()) == 0:
                glist.append(gseq_one)
                gseq_one = []
            else:
                gseq_one.append(line.strip().split("\t"))

        # include the last session
        # this was previously ignored in the SParC evaluation script
        # which might lead to slight differences in scores
        if len(gseq_one) != 0:
            glist.append(gseq_one)

    # spider formatting indicates that there is only one "single turn"
    # do not report "turn accuracy" for SPIDER
    include_turn_acc = len(glist) > 1

    with open(predict) as f:
        plist = []
        pseq_one = []
        for line in f.readlines():
            if len(line.strip()) == 0:
                plist.append(pseq_one)
                pseq_one = []
            else:
                pseq_one.append(line.strip().split("\t"))

        if len(pseq_one) != 0:
            plist.append(pseq_one)

    assert len(plist) == len(glist), "number of sessions must equal"

    evaluator = Evaluator(
        db_dir, kmaps, etype, plug_value, keep_distinct,
        progress_bar_for_each_datapoint,
    )
    results = []

    for i, (pred_session, gold_session) in enumerate(zip(plist, glist)):
        if (i + 1) % 10 == 0:
            print("Evaluating %dth prediction" % (i + 1))
        evaluator.scores["joint_all"]["count"] += 1
        turn_scores = {"exec": [], "exact": []}
        for idx, (pred_turn, gold_turn) in enumerate(zip(pred_session, gold_session)):
            # Predictions may contain the placeholder token "value".
            p_str = pred_turn[0].replace("value", "1")
            g_str, db_name = gold_turn
            results.append(
                evaluator.evaluate_one(
                    db_name, g_str, p_str, "", "", turn_scores, idx, ""
                )
            )

        # A session is jointly correct only when every turn is correct.
        if all(v == 1 for v in turn_scores["exec"]):
            evaluator.scores["joint_all"]["exec"] += 1
        if all(v == 1 for v in turn_scores["exact"]):
            evaluator.scores["joint_all"]["exact"] += 1

    evaluator.finalize()
    print_scores(evaluator.scores, etype, include_turn_acc=include_turn_acc)
    return {
        "per_item": results,
        "total_scores": evaluator.scores,
    }


# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
    """Blank out literal comparison values (keep nested subqueries)."""
    if cond_unit is None or not DISABLE_VALUE:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    if type(val1) is not dict:
        val1 = None
    else:
        val1 = rebuild_sql_val(val1)
    if type(val2) is not dict:
        val2 = None
    else:
        val2 = rebuild_sql_val(val2)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_val(condition):
    """Apply rebuild_cond_unit_val to every cond_unit in a condition list."""
    if condition is None or not DISABLE_VALUE:
        return condition

    res = []
    for idx, it in enumerate(condition):
        if idx % 2 == 0:
            res.append(rebuild_cond_unit_val(it))
        else:
            res.append(it)
    return res


def rebuild_sql_val(sql):
    """Recursively strip literal values from a parsed query (in place)."""
    if sql is None or not DISABLE_VALUE:
        return sql

    sql["from"]["conds"] = rebuild_condition_val(sql["from"]["conds"])
    sql["having"] = rebuild_condition_val(sql["having"])
    sql["where"] = rebuild_condition_val(sql["where"])
    sql["intersect"] = rebuild_sql_val(sql["intersect"])
    sql["except"] = rebuild_sql_val(sql["except"])
    sql["union"] = rebuild_sql_val(sql["union"])

    return sql


# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
    """Return schema column ids belonging to the tables used in FROM."""
    col_ids = [
        table_unit[1]
        for table_unit in table_units
        if table_unit[0] == TABLE_TYPE["table_unit"]
    ]
    # Table ids look like "__tablename__"; drop the trailing "__" to get the
    # prefix that column ids ("__tablename.col__") start with.
    prefixs = [col_id[:-2] for col_id in col_ids]
    valid_col_units = []
    for value in schema.idMap.values():
        if "." in value and value[: value.index(".")] in prefixs:
            valid_col_units.append(value)
    return valid_col_units


def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
    """Canonicalize a col_unit through the foreign-key map."""
    if col_unit is None:
        return col_unit

    agg_id, col_id, distinct = col_unit
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    if DISABLE_DISTINCT:
        distinct = None
    return agg_id, col_id, distinct


def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
    """Canonicalize both col_units inside a val_unit."""
    if val_unit is None:
        return val_unit

    unit_op, col_unit1, col_unit2 = val_unit
    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
    return unit_op, col_unit1, col_unit2


def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
    """Canonicalize a table_unit's col_unit (subquery units pass through)."""
    if table_unit is None:
        return table_unit

    table_type, col_unit_or_sql = table_unit
    if isinstance(col_unit_or_sql, tuple):
        col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
    return table_type, col_unit_or_sql


def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
    """Canonicalize the val_unit inside a cond_unit."""
    if cond_unit is None:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_col(valid_col_units, condition, kmap):
    """Canonicalize every cond_unit in a condition list (in place)."""
    for idx in range(len(condition)):
        if idx % 2 == 0:
            condition[idx] = rebuild_cond_unit_col(
                valid_col_units, condition[idx], kmap
            )
    return condition


def rebuild_select_col(valid_col_units, sel, kmap):
    """Canonicalize every val_unit in a SELECT clause."""
    if sel is None:
        return sel
    distinct, _list = sel
    new_list = []
    for it in _list:
        agg_id, val_unit = it
        new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
    if DISABLE_DISTINCT:
        distinct = None
    return distinct, new_list


def rebuild_from_col(valid_col_units, from_, kmap):
    """Canonicalize table_units and join conditions of a FROM clause."""
    if from_ is None:
        return from_

    from_["table_units"] = [
        rebuild_table_unit_col(valid_col_units, table_unit, kmap)
        for table_unit in from_["table_units"]
    ]
    from_["conds"] = rebuild_condition_col(valid_col_units, from_["conds"], kmap)
    return from_


def rebuild_group_by_col(valid_col_units, group_by, kmap):
    """Canonicalize every col_unit of a GROUP BY clause."""
    if group_by is None:
        return group_by

    return [
        rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by
    ]


def rebuild_order_by_col(valid_col_units, order_by, kmap):
    """Canonicalize every val_unit of an ORDER BY clause."""
    if order_by is None or len(order_by) == 0:
        return order_by

    direction, val_units = order_by
    new_val_units = [
        rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units
    ]
    return direction, new_val_units


def rebuild_sql_col(valid_col_units, sql, kmap):
    """Recursively canonicalize all column references through the FK map."""
    if sql is None:
        return sql

    sql["select"] = rebuild_select_col(valid_col_units, sql["select"], kmap)
    sql["from"] = rebuild_from_col(valid_col_units, sql["from"], kmap)
    sql["where"] = rebuild_condition_col(valid_col_units, sql["where"], kmap)
    sql["groupBy"] = rebuild_group_by_col(valid_col_units, sql["groupBy"], kmap)
    sql["orderBy"] = rebuild_order_by_col(valid_col_units, sql["orderBy"], kmap)
    sql["having"] = rebuild_condition_col(valid_col_units, sql["having"], kmap)
    sql["intersect"] = rebuild_sql_col(valid_col_units, sql["intersect"], kmap)
    sql["except"] = rebuild_sql_col(valid_col_units, sql["except"], kmap)
    sql["union"] = rebuild_sql_col(valid_col_units, sql["union"], kmap)

    return sql


def build_foreign_key_map(entry):
    """Build a map sending each FK-linked column id to a canonical member.

    Columns connected (transitively) by foreign keys form a cluster; each
    member maps to the cluster's smallest column index.
    """
    cols_orig = entry["column_names_original"]
    tables_orig = entry["table_names_original"]

    # rebuild cols corresponding to idmap in Schema
    cols = []
    for col_orig in cols_orig:
        if col_orig[0] >= 0:
            t = tables_orig[col_orig[0]]
            c = col_orig[1]
            cols.append("__" + t.lower() + "." + c.lower() + "__")
        else:
            cols.append("__all__")

    def keyset_in_list(k1, k2, k_list):
        """Find (or create) the cluster containing k1 or k2."""
        for k_set in k_list:
            if k1 in k_set or k2 in k_set:
                return k_set
        new_k_set = set()
        k_list.append(new_k_set)
        return new_k_set

    foreign_key_list = []
    foreign_keys = entry["foreign_keys"]
    for fkey in foreign_keys:
        key1, key2 = fkey
        key_set = keyset_in_list(key1, key2, foreign_key_list)
        key_set.add(key1)
        key_set.add(key2)

    foreign_key_map = {}
    for key_set in foreign_key_list:
        sorted_list = sorted(list(key_set))
        midx = sorted_list[0]
        for idx in sorted_list:
            foreign_key_map[cols[idx]] = cols[midx]

    return foreign_key_map


def build_foreign_key_map_from_json(table):
    """Load a tables.json schema file and build FK maps keyed by db_id."""
    with open(table) as f:
        data = json.load(f)
    tables = {}
    for entry in data:
        tables[entry["db_id"]] = build_foreign_key_map(entry)
    return tables


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gold", dest="gold", type=str, help="the path to the gold queries"
    )
    parser.add_argument(
        "--pred", dest="pred", type=str, help="the path to the predicted queries"
    )
    parser.add_argument(
        "--db",
        dest="db",
        type=str,
        help="the directory that contains all the databases and test suites",
    )
    parser.add_argument(
        "--table", dest="table", type=str, help="the tables.json schema file"
    )
    parser.add_argument(
        "--etype",
        dest="etype",
        type=str,
        default="exec",
        help="evaluation type, exec for test suite accuracy, match for the original exact set match accuracy",
        choices=("all", "exec", "match"),
    )
    parser.add_argument(
        "--plug_value",
        default=False,
        action="store_true",
        help="whether to plug in the gold value into the predicted query; suitable if your model does not predict values.",
    )
    parser.add_argument(
        "--keep_distinct",
        default=False,
        action="store_true",
        help="whether to keep distinct keyword during evaluation. default is false.",
    )
    parser.add_argument(
        "--progress_bar_for_each_datapoint",
        default=False,
        action="store_true",
        help="whether to print progress bar of running test inputs for each datapoint",
    )
    args = parser.parse_args()

    # only evaluting exact match needs this argument
    kmaps = None
    if args.etype in ["all", "match"]:
        assert (
            args.table is not None
        ), "table argument must be non-None if exact set match is evaluated"
        kmaps = build_foreign_key_map_from_json(args.table)

    evaluate(
        args.gold,
        args.pred,
        args.db,
        args.etype,
        kmaps,
        args.plug_value,
        args.keep_distinct,
        args.progress_bar_for_each_datapoint,
    )