qianxiao1111's picture
upgrade: add benchmarks eval
2a26d3b
raw
history blame
4.18 kB
import json
import argparse
from text2sql.src.evaluation import evaluation_main
# from text2sql.src.evaluation_ves import evaluation_ves_main
from text2sql.src.gpt_request import generate_main
from text2sql.src.gpt_request_encoder import generate_main_encoder
def main(args):
if args.eval_data_name == "bird" and args.mode == "dev":
args.db_root_path = "table_related_benchmarks/evalset/bird_data/dev_databases"
args.eval_data_path = "table_related_benchmarks/evalset/bird_data/dev.json"
args.ground_truth_path = "table_related_benchmarks/evalset/bird_data/dev.sql"
# args.mode = "dev"
# args.use_knowledge = "True"
if args.eval_data_name == "spider" and args.mode == "test":
args.db_root_path = "table_related_benchmarks/evalset/spider_data/test_database"
args.eval_data_path = "table_related_benchmarks/evalset/spider_data/test.json"
args.ground_truth_path = "table_related_benchmarks/evalset/spider_data/test_gold.sql"
# args.mode = "test"
# args.use_knowledge = "False"
if args.eval_data_name == "spider" and args.mode == "dev":
args.db_root_path = "table_related_benchmarks/evalset/spider_data/dev_database"
args.eval_data_path = "table_related_benchmarks/evalset/spider_data/dev.json"
args.ground_truth_path = "table_related_benchmarks/evalset/spider_data/dev_gold.sql"
# args.mode = "dev"
# # args.use_knowledge = "False"
# #
if args.is_use_knowledge:
args.use_knowledge = "True"
else:
args.use_knowledge = "False"
eval_datas = json.load(open(args.eval_data_path, 'r'))
# eval_datas = eval_datas[:3]
# '''for debug'''
# datas = []
# simple_num = 0
# moderate_num = 0
# challenging_num = 0
# filter_num = 5
# for i,data in enumerate(eval_datas):
# if simple_num== moderate_num==challenging_num== filter_num:
# break
# if eval(f"{data['difficulty']}_num") >= filter_num:
# continue
# if data['difficulty'] == 'simple':
# datas.append(data)
# simple_num += 1
# if data['difficulty'] == 'moderate':
# datas.append(data)
# moderate_num += 1
# if data['difficulty'] == 'challenging':
# datas.append(data)
# challenging_num += 1
# import copy
# eval_datas = copy.deepcopy(datas)
# '''for debug'''
if args.use_encoder:
predicted_sql_path = generate_main_encoder(eval_datas, args)
# predicted_sql_path = "table_related_benchmarks/text2sql/output/predict_dev.json"
else:
# predicted_sql_path = "table_related_benchmarks/text2sql/output/predict_dev.json"
predicted_sql_path = generate_main(eval_datas, args)
evaluation_main(args, eval_datas, predicted_sql_path)
if __name__ == '__main__':
args_parser = argparse.ArgumentParser()
args_parser.add_argument('--eval_type', type=str, choices=["ex"], default="ex")
args_parser.add_argument('--eval_data_name', type=str, choices=["bird", "spider"], default="bird")
args_parser.add_argument('--mode', type=str, choices=["dev", "test"], default="dev")
args_parser.add_argument('--is_use_knowledge', default=False, action="store_true")
args_parser.add_argument('--data_output_path', type=str, default="table_related_benchmarks/text2sql/output")
args_parser.add_argument('--chain_of_thought', type=str, default="False")
args_parser.add_argument('--model_path', type=str) # , required=True
args_parser.add_argument('--gpus_num', type=int, default=1)
args_parser.add_argument('--num_cpus', type=int, default=4)
args_parser.add_argument('--meta_time_out', type=float, default=30.0)
args_parser.add_argument('--use_encoder', default=False, action="store_true")
args_parser.add_argument('--use_gpt_api', default=False, action="store_true")
args = args_parser.parse_args()
main(args)
# CUDA_VISIBLE_DEVICES=6 python table_related_benchmarks/run_text2sql_eval.py --model_path /data4/sft_output/qwen2.5-7b-ins-0929/checkpoint-3200 --eval_data_name spider