File size: 4,179 Bytes
2a26d3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json
import argparse
from text2sql.src.evaluation import evaluation_main
# from text2sql.src.evaluation_ves import evaluation_ves_main
from text2sql.src.gpt_request import generate_main
from text2sql.src.gpt_request_encoder import generate_main_encoder

def main(args):
    """Run text-to-SQL generation and evaluation for one benchmark split.

    Resolves the dataset paths on ``args`` from the requested
    (benchmark, mode) pair, generates predicted SQL with either the
    encoder or the default (decoder) pipeline, then scores the
    predictions against the ground truth via ``evaluation_main``.

    Args:
        args: Parsed CLI namespace. Mutated in place: ``db_root_path``,
            ``eval_data_path``, ``ground_truth_path`` and
            ``use_knowledge`` are set here.
    """
    # Map the (benchmark, mode) pair onto the on-disk evaluation assets.
    # The pairs are mutually exclusive, so an if/elif chain is equivalent
    # to the original independent ifs.
    if args.eval_data_name == "bird" and args.mode == "dev":
        args.db_root_path = "table_related_benchmarks/evalset/bird_data/dev_databases"
        args.eval_data_path = "table_related_benchmarks/evalset/bird_data/dev.json"
        args.ground_truth_path = "table_related_benchmarks/evalset/bird_data/dev.sql"
    elif args.eval_data_name == "spider" and args.mode == "test":
        args.db_root_path = "table_related_benchmarks/evalset/spider_data/test_database"
        args.eval_data_path = "table_related_benchmarks/evalset/spider_data/test.json"
        args.ground_truth_path = "table_related_benchmarks/evalset/spider_data/test_gold.sql"
    elif args.eval_data_name == "spider" and args.mode == "dev":
        args.db_root_path = "table_related_benchmarks/evalset/spider_data/dev_database"
        args.eval_data_path = "table_related_benchmarks/evalset/spider_data/dev.json"
        args.ground_truth_path = "table_related_benchmarks/evalset/spider_data/dev_gold.sql"

    # Downstream code appears to expect the strings "True"/"False" rather
    # than a bool (the CLI flag itself is a store_true bool).
    args.use_knowledge = "True" if args.is_use_knowledge else "False"

    # Fix: the original used json.load(open(...)) which leaks the file
    # handle; a context manager closes it deterministically.
    with open(args.eval_data_path, "r", encoding="utf-8") as f:
        eval_datas = json.load(f)

    if args.use_encoder:
        predicted_sql_path = generate_main_encoder(eval_datas, args)
    else:
        predicted_sql_path = generate_main(eval_datas, args)

    evaluation_main(args, eval_datas, predicted_sql_path)
 

if __name__ == '__main__':
    # Build the CLI exactly as before; only the local name and layout differ.
    parser = argparse.ArgumentParser()

    # Which benchmark/split to evaluate and how.
    parser.add_argument('--eval_type', type=str, choices=["ex"], default="ex")
    parser.add_argument('--eval_data_name', type=str, choices=["bird", "spider"], default="bird")
    parser.add_argument('--mode', type=str, choices=["dev", "test"], default="dev")
    parser.add_argument('--is_use_knowledge', default=False, action="store_true")

    # Output location and prompting behavior.
    parser.add_argument('--data_output_path', type=str, default="table_related_benchmarks/text2sql/output")
    parser.add_argument('--chain_of_thought', type=str, default="False")

    # Model and runtime resources.
    parser.add_argument('--model_path', type=str) # , required=True
    parser.add_argument('--gpus_num', type=int, default=1)
    parser.add_argument('--num_cpus', type=int, default=4)
    parser.add_argument('--meta_time_out', type=float, default=30.0)

    # Pipeline selection.
    parser.add_argument('--use_encoder', default=False, action="store_true")
    parser.add_argument('--use_gpt_api', default=False, action="store_true")

    main(parser.parse_args())


    # CUDA_VISIBLE_DEVICES=6 python table_related_benchmarks/run_text2sql_eval.py --model_path /data4/sft_output/qwen2.5-7b-ins-0929/checkpoint-3200 --eval_data_name spider