################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################
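# Illustrative example (not used by the code): assuming process_sql's schema
# id mapping, a query like
#   SELECT count(*) FROM singer WHERE age > 20
# parses to roughly:
#   {
#     "select": (False, [(3, (0, (0, "__all__", False), None))]),  # 3 = count
#     "from": {"table_units": [("table_unit", "__singer__")], "conds": []},
#     "where": [(False, 3, (0, (0, "__singer.age__", False), None), 20.0, None)],
#     "groupBy": [], "having": [], "orderBy": [], "limit": None,
#     "intersect": None, "except": None, "union": None,
#   }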
import os
import json
import sqlite3
import argparse

from .process_sql import get_schema, Schema, get_sql
from .exec_eval import eval_exec_match

# "joint_all" tracks whole-session accuracy (all turns in a session correct)
LEVELS = ["easy", "medium", "hard", "duckdb", "ddl", "all", "joint_all"]
TURNS = ["turn 1", "turn 2", "turn 3", "turn 4", "turn > 4"]
PARTIAL_TYPES = [
    "select",
    "select(no AGG)",
    "where",
    "where(no OP)",
    "group(no Having)",
    "group",
    "order",
    "and/or",
    "IUEN",
    "keywords",
]

# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True

CLAUSE_KEYWORDS = (
    "select",
    "from",
    "where",
    "group",
    "order",
    "limit",
    "intersect",
    "union",
    "except",
)
JOIN_KEYWORDS = ("join", "on", "as")
WHERE_OPS = (
    "not",
    "between",
    "=",
    ">",
    "<",
    ">=",
    "<=",
    "!=",
    "in",
    "like",
    "is",
    "exists",
)
UNIT_OPS = ("none", "-", "+", "*", "/")
AGG_OPS = ("none", "max", "min", "count", "sum", "avg")
TABLE_TYPE = {
    "sql": "sql",
    "table_unit": "table_unit",
}
COND_OPS = ("and", "or")
SQL_OPS = ("intersect", "union", "except")
ORDER_OPS = ("desc", "asc")
HARDNESS = {
    "component1": ("where", "group", "order", "limit", "join", "or", "like"),
    "component2": ("except", "union", "intersect"),
}


def condition_has_or(conds):
    return "or" in conds[1::2]


def condition_has_like(conds):
    return WHERE_OPS.index("like") in [cond_unit[1] for cond_unit in conds[::2]]


def condition_has_sql(conds):
    for cond_unit in conds[::2]:
        val1, val2 = cond_unit[3], cond_unit[4]
        if val1 is not None and type(val1) is dict:
            return True
        if val2 is not None and type(val2) is dict:
            return True
    return False
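# Conditions interleave cond_units with connectors: "a = 1 OR b LIKE 'x%'" is
# stored as [cond_unit_a, "or", cond_unit_b], so conds[::2] yields the
# cond_units and conds[1::2] the "and"/"or" tokens that the helpers above
# slice on.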
def val_has_op(val_unit):
    return val_unit[0] != UNIT_OPS.index("none")


def has_agg(unit):
    return unit[0] != AGG_OPS.index("none")


def accuracy(count, total):
    if count == total:
        return 1
    return 0


def recall(count, total):
    if count == total:
        return 1
    return 0


def F1(acc, rec):
    if (acc + rec) == 0:
        return 0
    return (2.0 * acc * rec) / (acc + rec)


def get_scores(count, pred_total, label_total):
    if pred_total != label_total:
        return 0, 0, 0
    elif count == pred_total:
        return 1, 1, 1
    return 0, 0, 0
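# Clause scoring is all-or-nothing: get_scores returns (1, 1, 1) only when the
# predicted clause has the same number of components as the gold clause and
# every one of them matched, e.g. get_scores(2, 2, 2) -> (1, 1, 1) while
# get_scores(1, 2, 2) -> (0, 0, 0).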
def eval_sel(pred, label):
    pred_sel = pred["select"][1]
    label_sel = label["select"][1]
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = 0
    cnt_wo_agg = 0
    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            # remove matched items so duplicates are only credited once
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])
    return label_total, pred_total, cnt, cnt_wo_agg
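# eval_sel treats the SELECT list as a multiset: gold "SELECT max(age), name"
# matched against predicted "SELECT name, max(age)" yields cnt == 2, and
# cnt_wo_agg additionally credits items whose column matches while the
# aggregation differs (e.g. predicted min(age) vs. gold max(age)).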
def eval_where(pred, label):
    pred_conds = pred["where"][::2]
    label_conds = label["where"][::2]
    label_wo_agg = [unit[2] for unit in label_conds]
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = 0
    cnt_wo_agg = 0
    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
        if unit[2] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[2])
    return label_total, pred_total, cnt, cnt_wo_agg


def eval_group(pred, label):
    pred_cols = [unit[1] for unit in pred["groupBy"]]
    label_cols = [unit[1] for unit in label["groupBy"]]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    # compare bare column names, ignoring table prefixes
    pred_cols = [col.split(".")[1] if "." in col else col for col in pred_cols]
    label_cols = [col.split(".")[1] if "." in col else col for col in label_cols]
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return label_total, pred_total, cnt


def eval_having(pred, label):
    pred_total = label_total = cnt = 0
    if len(pred["groupBy"]) > 0:
        pred_total = 1
    if len(label["groupBy"]) > 0:
        label_total = 1
    pred_cols = [unit[1] for unit in pred["groupBy"]]
    label_cols = [unit[1] for unit in label["groupBy"]]
    if (
        pred_total == label_total == 1
        and pred_cols == label_cols
        and pred["having"] == label["having"]
    ):
        cnt = 1
    return label_total, pred_total, cnt


def eval_order(pred, label):
    pred_total = label_total = cnt = 0
    if len(pred["orderBy"]) > 0:
        pred_total = 1
    if len(label["orderBy"]) > 0:
        label_total = 1
    if (
        len(label["orderBy"]) > 0
        and pred["orderBy"] == label["orderBy"]
        and (
            (pred["limit"] is None and label["limit"] is None)
            or (pred["limit"] is not None and label["limit"] is not None)
        )
    ):
        cnt = 1
    return label_total, pred_total, cnt


def eval_and_or(pred, label):
    pred_ao = set(pred["where"][1::2])
    label_ao = set(label["where"][1::2])
    if pred_ao == label_ao:
        return 1, 1, 1
    return len(pred_ao), len(label_ao), 0


def get_nestedSQL(sql):
    nested = []
    for cond_unit in sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2]:
        if type(cond_unit[3]) is dict:
            nested.append(cond_unit[3])
        if type(cond_unit[4]) is dict:
            nested.append(cond_unit[4])
    if sql["intersect"] is not None:
        nested.append(sql["intersect"])
    if sql["except"] is not None:
        nested.append(sql["except"])
    if sql["union"] is not None:
        nested.append(sql["union"])
    return nested


def eval_nested(pred, label):
    label_total = 0
    pred_total = 0
    cnt = 0
    if pred is not None:
        pred_total += 1
    if label is not None:
        label_total += 1
    if pred is not None and label is not None:
        partial_scores = Evaluator.eval_partial_match(pred, label)
        cnt += Evaluator.eval_exact_match(pred, label, partial_scores)
    return label_total, pred_total, cnt


def eval_IUEN(pred, label):
    lt1, pt1, cnt1 = eval_nested(pred["intersect"], label["intersect"])
    lt2, pt2, cnt2 = eval_nested(pred["except"], label["except"])
    lt3, pt3, cnt3 = eval_nested(pred["union"], label["union"])
    label_total = lt1 + lt2 + lt3
    pred_total = pt1 + pt2 + pt3
    cnt = cnt1 + cnt2 + cnt3
    return label_total, pred_total, cnt


def get_keywords(sql):
    res = set()
    if len(sql["where"]) > 0:
        res.add("where")
    if len(sql["groupBy"]) > 0:
        res.add("group")
    if len(sql["having"]) > 0:
        res.add("having")
    if len(sql["orderBy"]) > 0:
        res.add(sql["orderBy"][0])
        res.add("order")
    if sql["limit"] is not None:
        res.add("limit")
    if sql["except"] is not None:
        res.add("except")
    if sql["union"] is not None:
        res.add("union")
    if sql["intersect"] is not None:
        res.add("intersect")
    # or keyword
    ao = sql["from"]["conds"][1::2] + sql["where"][1::2] + sql["having"][1::2]
    if "or" in ao:
        res.add("or")
    cond_units = sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2]
    # not keyword
    if any(cond_unit[0] for cond_unit in cond_units):
        res.add("not")
    # in keyword
    if any(cond_unit[1] == WHERE_OPS.index("in") for cond_unit in cond_units):
        res.add("in")
    # like keyword
    if any(cond_unit[1] == WHERE_OPS.index("like") for cond_unit in cond_units):
        res.add("like")
    return res
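# Example: for "SELECT name FROM t WHERE id NOT IN (SELECT id FROM s)
# ORDER BY name DESC LIMIT 1", get_keywords returns roughly
#   {"where", "not", "in", "order", "desc", "limit"}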
def eval_keywords(pred, label):
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = 0
    for k in pred_keywords:
        if k in label_keywords:
            cnt += 1
    return label_total, pred_total, cnt


def count_agg(units):
    return len([unit for unit in units if has_agg(unit)])


def count_component1(sql):
    count = 0
    if len(sql["where"]) > 0:
        count += 1
    if len(sql["groupBy"]) > 0:
        count += 1
    if len(sql["orderBy"]) > 0:
        count += 1
    if sql["limit"] is not None:
        count += 1
    if len(sql["from"]["table_units"]) > 0:  # JOIN
        count += len(sql["from"]["table_units"]) - 1
    ao = sql["from"]["conds"][1::2] + sql["where"][1::2] + sql["having"][1::2]
    count += len([token for token in ao if token == "or"])
    cond_units = sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2]
    count += len(
        [
            cond_unit
            for cond_unit in cond_units
            if cond_unit[1] == WHERE_OPS.index("like")
        ]
    )
    return count


def count_component2(sql):
    nested = get_nestedSQL(sql)
    return len(nested)


def count_others(sql):
    count = 0
    # number of aggregations
    agg_count = count_agg(sql["select"][1])
    agg_count += count_agg(sql["where"][::2])
    agg_count += count_agg(sql["groupBy"])
    if len(sql["orderBy"]) > 0:
        agg_count += count_agg(
            [unit[1] for unit in sql["orderBy"][1] if unit[1]]
            + [unit[2] for unit in sql["orderBy"][1] if unit[2]]
        )
    agg_count += count_agg(sql["having"])
    if agg_count > 1:
        count += 1
    # number of select columns
    if len(sql["select"][1]) > 1:
        count += 1
    # number of where conditions
    if len(sql["where"]) > 1:
        count += 1
    # number of group by clauses
    if len(sql["groupBy"]) > 1:
        count += 1
    return count


class Evaluator:
    """A simple evaluator"""

    def __init__(
        self,
        db_dir,
        kmaps,
        etype,
        plug_value,
        keep_distinct,
        progress_bar_for_each_datapoint,
    ):
        self.db_dir = db_dir
        self.kmaps = kmaps
        self.etype = etype
        self.plug_value = plug_value
        self.keep_distinct = keep_distinct
        self.progress_bar_for_each_datapoint = progress_bar_for_each_datapoint
        self.db_paths = {}
        self.schemas = {}
        self.scores = {}
        for turn in TURNS:
            self.scores[turn] = {"count": 0, "exact": 0.0, "exec": 0}
        for level in LEVELS:
            self.scores[level] = {"count": 0, "partial": {}, "exact": 0.0, "exec": 0}
            for type_ in PARTIAL_TYPES:
                self.scores[level]["partial"][type_] = {
                    "acc": 0.0,
                    "rec": 0.0,
                    "f1": 0.0,
                    "acc_count": 0,
                    "rec_count": 0,
                }

    def eval_hardness(self, sql):
        count_comp1_ = count_component1(sql)
        count_comp2_ = count_component2(sql)
        count_others_ = count_others(sql)
        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or (
            count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0
        ):
            return "medium"
        elif (
            (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0)
            or (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0)
            or (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1)
        ):
            return "hard"
        else:
            return "extra"
    # The next two methods are called on the class itself from eval_nested,
    # hence classmethods.
    @classmethod
    def eval_exact_match(cls, pred, label, partial_scores):
        for _, score in partial_scores.items():
            if score["f1"] != 1:
                return 0
        if len(label["from"]["table_units"]) > 0:
            label_tables = sorted(label["from"]["table_units"])
            pred_tables = sorted(pred["from"]["table_units"])
            return label_tables == pred_tables
        return 1

    @classmethod
    def eval_partial_match(cls, pred, label):
        res = {}

        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["select"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res["select(no AGG)"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["where"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res["where(no OP)"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt = eval_group(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["group(no Having)"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt = eval_having(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["group"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt = eval_order(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["order"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt = eval_and_or(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["and/or"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt = eval_IUEN(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["IUEN"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }

        label_total, pred_total, cnt = eval_keywords(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res["keywords"] = {
            "acc": acc,
            "rec": rec,
            "f1": f1,
            "label_total": label_total,
            "pred_total": pred_total,
        }
        return res

    def evaluate_one(
        self, db_name, gold, predicted, setup_sql, validate_sql,
        turn_scores, idx, category
    ):
        if db_name not in self.db_paths:
            db_path = os.path.join(self.db_dir, db_name, db_name + ".duckdb")
            self.db_paths[db_name] = db_path
            self.schemas[db_name] = Schema(get_schema(db_path))
        if idx > 3:
            idx = "> 4"
        else:
            idx += 1
        turn_id = "turn " + str(idx)
        hardness = category
        self.scores[turn_id]["count"] += 1
        # the category may be empty if the gold file carries no category
        # column; skip the per-category buckets in that case
        if hardness in self.scores:
            self.scores[hardness]["count"] += 1
        self.scores["all"]["count"] += 1

        if self.etype in ["all", "match"]:
            schema = self.schemas[db_name]
            g_sql = get_sql(schema, gold)
            try:
                p_sql = get_sql(schema, predicted)
            except Exception:
                # If the prediction is unparsable, evaluate the gold SQL
                # against an empty query
                p_sql = {
                    "except": None,
                    "from": {"conds": [], "table_units": []},
                    "groupBy": [],
                    "having": [],
                    "intersect": None,
                    "limit": None,
                    "orderBy": [],
                    "select": [False, []],
                    "union": None,
                    "where": [],
                }

        if self.etype in ["all", "exec"]:
            exec_score = eval_exec_match(
                db=self.db_paths[db_name],
                p_str=predicted,
                g_str=gold,
                setup_sql=setup_sql,
                validate_sql=validate_sql,
                plug_value=self.plug_value,
                keep_distinct=self.keep_distinct,
                progress_bar_for_each_datapoint=self.progress_bar_for_each_datapoint,
            )
            if exec_score:
                if hardness in self.scores:
                    self.scores[hardness]["exec"] += 1
                self.scores[turn_id]["exec"] += 1
                self.scores["all"]["exec"] += 1
                turn_scores["exec"].append(1)
            else:
                turn_scores["exec"].append(0)

        if self.etype in ["all", "match"]:
            # rebuild sql for value evaluation
            kmap = self.kmaps[db_name]
            g_valid_col_units = build_valid_col_units(
                g_sql["from"]["table_units"], schema
            )
            g_sql = rebuild_sql_val(g_sql)
            g_sql = rebuild_sql_col(g_valid_col_units, g_sql, kmap)
            p_valid_col_units = build_valid_col_units(
                p_sql["from"]["table_units"], schema
            )
            p_sql = rebuild_sql_val(p_sql)
            p_sql = rebuild_sql_col(p_valid_col_units, p_sql, kmap)
            partial_scores = self.eval_partial_match(p_sql, g_sql)
            exact_score = self.eval_exact_match(p_sql, g_sql, partial_scores)
            if exact_score == 0:
                turn_scores["exact"].append(0)
                print("{} pred: {}".format(hardness, predicted))
                print("{} gold: {}".format(hardness, gold))
                print("")
            else:
                turn_scores["exact"].append(1)
            self.scores[turn_id]["exact"] += exact_score
            if hardness in self.scores:
                self.scores[hardness]["exact"] += exact_score
            self.scores["all"]["exact"] += exact_score
            # accumulate partial scores into the "all" bucket and, when the
            # category is known, into its bucket as well
            buckets = ["all"]
            if hardness in self.scores:
                buckets.append(hardness)
            for bucket in buckets:
                for type_ in PARTIAL_TYPES:
                    bucket_scores = self.scores[bucket]["partial"][type_]
                    if partial_scores[type_]["pred_total"] > 0:
                        bucket_scores["acc"] += partial_scores[type_]["acc"]
                        bucket_scores["acc_count"] += 1
                    if partial_scores[type_]["label_total"] > 0:
                        bucket_scores["rec"] += partial_scores[type_]["rec"]
                        bucket_scores["rec_count"] += 1
                    bucket_scores["f1"] += partial_scores[type_]["f1"]

        result = {
            "predictSQL": predicted,
            "goldSQL": gold,
        }
        if self.etype in ["all", "match"]:
            result.update({
                "hardness": hardness,
                "exact": exact_score,
                "partial": partial_scores,
            })
        if self.etype in ["all", "exec"]:
            result["exec"] = exec_score
        return result

    def finalize(self):
        scores = self.scores
        for turn in TURNS:
            if scores[turn]["count"] == 0:
                continue
            if self.etype in ["all", "exec"]:
                scores[turn]["exec"] /= scores[turn]["count"]
            if self.etype in ["all", "match"]:
                scores[turn]["exact"] /= scores[turn]["count"]
        for level in LEVELS:
            if scores[level]["count"] == 0:
                continue
            if self.etype in ["all", "exec"]:
                scores[level]["exec"] /= scores[level]["count"]
            if self.etype in ["all", "match"]:
                scores[level]["exact"] /= scores[level]["count"]
                for type_ in PARTIAL_TYPES:
                    if scores[level]["partial"][type_]["acc_count"] == 0:
                        scores[level]["partial"][type_]["acc"] = 0
                    else:
                        scores[level]["partial"][type_]["acc"] = (
                            scores[level]["partial"][type_]["acc"]
                            / scores[level]["partial"][type_]["acc_count"]
                            * 1.0
                        )
                    if scores[level]["partial"][type_]["rec_count"] == 0:
                        scores[level]["partial"][type_]["rec"] = 0
                    else:
                        scores[level]["partial"][type_]["rec"] = (
                            scores[level]["partial"][type_]["rec"]
                            / scores[level]["partial"][type_]["rec_count"]
                            * 1.0
                        )
                    if (
                        scores[level]["partial"][type_]["acc"] == 0
                        and scores[level]["partial"][type_]["rec"] == 0
                    ):
                        scores[level]["partial"][type_]["f1"] = 1
                    else:
                        scores[level]["partial"][type_]["f1"] = (
                            2.0
                            * scores[level]["partial"][type_]["acc"]
                            * scores[level]["partial"][type_]["rec"]
                            / (
                                scores[level]["partial"][type_]["rec"]
                                + scores[level]["partial"][type_]["acc"]
                            )
                        )


def isValidSQL(sql, db):
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    try:
        cursor.execute(sql)
    except Exception:
        return False
    finally:
        conn.close()
    return True
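# Hypothetical usage of isValidSQL (it targets SQLite and is not called
# elsewhere in this script; the db path is illustrative):
#   isValidSQL("SELECT 1", "example.sqlite")                    -> True
#   isValidSQL("SELECT * FROM missing_table", "example.sqlite") -> False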
def print_formated_s(row_name, l, element_format):
    template = "{:20} " + " ".join([element_format] * len(l))
    print(template.format(row_name, *l))


def print_scores(scores, etype, include_turn_acc=True):
    turns = TURNS
    levels = [level for level in LEVELS if level != "joint_all"]
    if include_turn_acc:
        levels.append("joint_all")
    partial_types = PARTIAL_TYPES

    print_formated_s("", levels, "{:20}")
    counts = [scores[level]["count"] for level in levels]
    print_formated_s("count", counts, "{:<20d}")

    if etype in ["all", "exec"]:
        print("===================== EXECUTION ACCURACY =====================")
        exec_scores = [scores[level]["exec"] for level in levels]
        print_formated_s("execution", exec_scores, "{:<20.3f}")

    if etype in ["all", "match"]:
        print("\n====================== EXACT MATCHING ACCURACY =====================")
        exact_scores = [scores[level]["exact"] for level in levels]
        print_formated_s("exact match", exact_scores, "{:<20.3f}")
        print("\n---------------------PARTIAL MATCHING ACCURACY----------------------")
        for type_ in partial_types:
            this_scores = [scores[level]["partial"][type_]["acc"] for level in levels]
            print_formated_s(type_, this_scores, "{:<20.3f}")
        print("---------------------- PARTIAL MATCHING RECALL ----------------------")
        for type_ in partial_types:
            this_scores = [scores[level]["partial"][type_]["rec"] for level in levels]
            print_formated_s(type_, this_scores, "{:<20.3f}")
        print("---------------------- PARTIAL MATCHING F1 --------------------------")
        for type_ in partial_types:
            this_scores = [scores[level]["partial"][type_]["f1"] for level in levels]
            print_formated_s(type_, this_scores, "{:<20.3f}")

    if include_turn_acc:
        print()
        print()
        print_formated_s("", turns, "{:20}")
        counts = [scores[turn]["count"] for turn in turns]
        print_formated_s("count", counts, "{:<20d}")

        if etype in ["all", "exec"]:
            print(
                "===================== TURN EXECUTION ACCURACY ====================="
            )
            exec_scores = [scores[turn]["exec"] for turn in turns]
            print_formated_s("execution", exec_scores, "{:<20.3f}")

        if etype in ["all", "match"]:
            print(
                "\n====================== TURN EXACT MATCHING ACCURACY ====================="
            )
            exact_scores = [scores[turn]["exact"] for turn in turns]
            print_formated_s("exact match", exact_scores, "{:<20.3f}")


def evaluate(
    gold,
    predict,
    db_dir,
    etype,
    kmaps,
    plug_value,
    keep_distinct,
    progress_bar_for_each_datapoint,
):
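    # Both gold and prediction files are plain text with one query per line
    # and a blank line between interaction sessions. Gold lines are
    # tab-separated as "<SQL>\t<db_id>"; prediction lines carry only the SQL.
    # An illustrative two-turn gold session (the db id is made up):
    #   SELECT count(*) FROM singer\tconcert_singer
    #   SELECT name FROM singer ORDER BY age\tconcert_singer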
    with open(gold) as f:
        glist = []
        gseq_one = []
        for line in f:
            if len(line.strip()) == 0:
                glist.append(gseq_one)
                gseq_one = []
            else:
                gseq_one.append(line.strip().split("\t"))
        # include the last session
        # this was previously ignored in the SParC evaluation script
        # which might lead to slight differences in scores
        if len(gseq_one) != 0:
            glist.append(gseq_one)

    # spider formatting indicates that there is only one "single turn"
    # do not report "turn accuracy" for SPIDER
    include_turn_acc = len(glist) > 1

    with open(predict) as f:
        plist = []
        pseq_one = []
        for line in f:
            if len(line.strip()) == 0:
                plist.append(pseq_one)
                pseq_one = []
            else:
                pseq_one.append(line.strip().split("\t"))
        if len(pseq_one) != 0:
            plist.append(pseq_one)

    assert len(plist) == len(glist), (
        "number of predicted sessions must equal number of gold sessions"
    )

    evaluator = Evaluator(
        db_dir, kmaps, etype, plug_value, keep_distinct,
        progress_bar_for_each_datapoint,
    )

    results = []
    for i, (p, g) in enumerate(zip(plist, glist)):
        if (i + 1) % 10 == 0:
            print("Evaluating %dth prediction" % (i + 1))
        evaluator.scores["joint_all"]["count"] += 1
        turn_scores = {"exec": [], "exact": []}
        for idx, (p_turn, g_turn) in enumerate(zip(p, g)):
            p_str = p_turn[0].replace("value", "1")
            g_str, db_name = g_turn
            # this simple gold format carries no setup/validate SQL and no
            # category column, hence the empty strings
            results.append(
                evaluator.evaluate_one(
                    db_name, g_str, p_str, "", "", turn_scores, idx, ""
                )
            )
        if all(v == 1 for v in turn_scores["exec"]):
            evaluator.scores["joint_all"]["exec"] += 1
        if all(v == 1 for v in turn_scores["exact"]):
            evaluator.scores["joint_all"]["exact"] += 1

    evaluator.finalize()
    print_scores(evaluator.scores, etype, include_turn_acc=include_turn_acc)
    return {
        "per_item": results,
        "total_scores": evaluator.scores,
    }


# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
    if cond_unit is None or not DISABLE_VALUE:
        return cond_unit
    not_op, op_id, val_unit, val1, val2 = cond_unit
    if type(val1) is not dict:
        val1 = None
    else:
        val1 = rebuild_sql_val(val1)
    if type(val2) is not dict:
        val2 = None
    else:
        val2 = rebuild_sql_val(val2)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_val(condition):
    if condition is None or not DISABLE_VALUE:
        return condition
    res = []
    for idx, it in enumerate(condition):
        if idx % 2 == 0:
            res.append(rebuild_cond_unit_val(it))
        else:
            res.append(it)
    return res


def rebuild_sql_val(sql):
    if sql is None or not DISABLE_VALUE:
        return sql
    sql["from"]["conds"] = rebuild_condition_val(sql["from"]["conds"])
    sql["having"] = rebuild_condition_val(sql["having"])
    sql["where"] = rebuild_condition_val(sql["where"])
    sql["intersect"] = rebuild_sql_val(sql["intersect"])
    sql["except"] = rebuild_sql_val(sql["except"])
    sql["union"] = rebuild_sql_val(sql["union"])
    return sql
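# Example: with DISABLE_VALUE set, the condition "age > 20" whose cond_unit is
#   (False, 3, val_unit, 20.0, None)
# is rewritten to (False, 3, val_unit, None, None), so literal values never
# affect matching; nested SQL values (dicts) are recursively rebuilt instead
# of being dropped.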
# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
    col_ids = [
        table_unit[1]
        for table_unit in table_units
        if table_unit[0] == TABLE_TYPE["table_unit"]
    ]
    prefixes = [col_id[:-2] for col_id in col_ids]
    valid_col_units = []
    for value in schema.idMap.values():
        if "." in value and value[: value.index(".")] in prefixes:
            valid_col_units.append(value)
    return valid_col_units


def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
    if col_unit is None:
        return col_unit
    agg_id, col_id, distinct = col_unit
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    if DISABLE_DISTINCT:
        distinct = None
    return agg_id, col_id, distinct


def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
    if val_unit is None:
        return val_unit
    unit_op, col_unit1, col_unit2 = val_unit
    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
    return unit_op, col_unit1, col_unit2


def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
    if table_unit is None:
        return table_unit
    table_type, col_unit_or_sql = table_unit
    if isinstance(col_unit_or_sql, tuple):
        col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
    return table_type, col_unit_or_sql


def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
    if cond_unit is None:
        return cond_unit
    not_op, op_id, val_unit, val1, val2 = cond_unit
    val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_col(valid_col_units, condition, kmap):
    for idx in range(len(condition)):
        if idx % 2 == 0:
            condition[idx] = rebuild_cond_unit_col(
                valid_col_units, condition[idx], kmap
            )
    return condition


def rebuild_select_col(valid_col_units, sel, kmap):
    if sel is None:
        return sel
    distinct, _list = sel
    new_list = []
    for it in _list:
        agg_id, val_unit = it
        new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
    if DISABLE_DISTINCT:
        distinct = None
    return distinct, new_list


def rebuild_from_col(valid_col_units, from_, kmap):
    if from_ is None:
        return from_
    from_["table_units"] = [
        rebuild_table_unit_col(valid_col_units, table_unit, kmap)
        for table_unit in from_["table_units"]
    ]
    from_["conds"] = rebuild_condition_col(valid_col_units, from_["conds"], kmap)
    return from_


def rebuild_group_by_col(valid_col_units, group_by, kmap):
    if group_by is None:
        return group_by
    return [
        rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by
    ]


def rebuild_order_by_col(valid_col_units, order_by, kmap):
    if order_by is None or len(order_by) == 0:
        return order_by
    direction, val_units = order_by
    new_val_units = [
        rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units
    ]
    return direction, new_val_units


def rebuild_sql_col(valid_col_units, sql, kmap):
    if sql is None:
        return sql
    sql["select"] = rebuild_select_col(valid_col_units, sql["select"], kmap)
    sql["from"] = rebuild_from_col(valid_col_units, sql["from"], kmap)
    sql["where"] = rebuild_condition_col(valid_col_units, sql["where"], kmap)
    sql["groupBy"] = rebuild_group_by_col(valid_col_units, sql["groupBy"], kmap)
    sql["orderBy"] = rebuild_order_by_col(valid_col_units, sql["orderBy"], kmap)
    sql["having"] = rebuild_condition_col(valid_col_units, sql["having"], kmap)
    sql["intersect"] = rebuild_sql_col(valid_col_units, sql["intersect"], kmap)
    sql["except"] = rebuild_sql_col(valid_col_units, sql["except"], kmap)
    sql["union"] = rebuild_sql_col(valid_col_units, sql["union"], kmap)
    return sql


def build_foreign_key_map(entry):
    cols_orig = entry["column_names_original"]
    tables_orig = entry["table_names_original"]

    # rebuild cols corresponding to idmap in Schema
    cols = []
    for col_orig in cols_orig:
        if col_orig[0] >= 0:
            t = tables_orig[col_orig[0]]
            c = col_orig[1]
            cols.append("__" + t.lower() + "." + c.lower() + "__")
        else:
            cols.append("__all__")

    def keyset_in_list(k1, k2, k_list):
        for k_set in k_list:
            if k1 in k_set or k2 in k_set:
                return k_set
        new_k_set = set()
        k_list.append(new_k_set)
        return new_k_set

    foreign_key_list = []
    foreign_keys = entry["foreign_keys"]
    for fkey in foreign_keys:
        key1, key2 = fkey
        key_set = keyset_in_list(key1, key2, foreign_key_list)
        key_set.add(key1)
        key_set.add(key2)

    foreign_key_map = {}
    for key_set in foreign_key_list:
        sorted_list = sorted(list(key_set))
        midx = sorted_list[0]
        for idx in sorted_list:
            foreign_key_map[cols[idx]] = cols[midx]
    return foreign_key_map
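# Example: if concert.stadium_id references stadium.stadium_id, the returned
# map looks roughly like
#   {"__concert.stadium_id__": "__concert.stadium_id__",
#    "__stadium.stadium_id__": "__concert.stadium_id__"}
# i.e. every member of a foreign-key group maps to one canonical column, so
# linked columns compare equal during exact set matching.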
def build_foreign_key_map_from_json(table):
    with open(table) as f:
        data = json.load(f)
    tables = {}
    for entry in data:
        tables[entry["db_id"]] = build_foreign_key_map(entry)
    return tables


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gold", dest="gold", type=str, help="the path to the gold queries"
    )
    parser.add_argument(
        "--pred", dest="pred", type=str, help="the path to the predicted queries"
    )
    parser.add_argument(
        "--db",
        dest="db",
        type=str,
        help="the directory that contains all the databases and test suites",
    )
    parser.add_argument(
        "--table", dest="table", type=str, help="the tables.json schema file"
    )
    parser.add_argument(
        "--etype",
        dest="etype",
        type=str,
        default="exec",
        help="evaluation type: exec for test suite accuracy, match for the original exact set match accuracy, all for both",
        choices=("all", "exec", "match"),
    )
    parser.add_argument(
        "--plug_value",
        default=False,
        action="store_true",
        help="whether to plug the gold values into the predicted query; suitable if your model does not predict values",
    )
    parser.add_argument(
        "--keep_distinct",
        default=False,
        action="store_true",
        help="whether to keep the distinct keyword during evaluation; default is false",
    )
    parser.add_argument(
        "--progress_bar_for_each_datapoint",
        default=False,
        action="store_true",
        help="whether to print a progress bar while running the test inputs for each datapoint",
    )
    args = parser.parse_args()

    # only evaluating exact match needs this argument
    kmaps = None
    if args.etype in ["all", "match"]:
        assert (
            args.table is not None
        ), "table argument must be non-None if exact set match is evaluated"
        kmaps = build_foreign_key_map_from_json(args.table)

    evaluate(
        args.gold,
        args.pred,
        args.db,
        args.etype,
        kmaps,
        args.plug_value,
        args.keep_distinct,
        args.progress_bar_for_each_datapoint,
    )
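# Example invocation (file names and the module path are illustrative; because
# of the relative imports above, run this as a module from its parent package):
#   python -m eval.evaluation --gold gold.txt --pred predicted.txt \
#       --db databases/ --table tables.json --etype all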