qianxiao1111's picture
upgrade: add benchmarks eval
2a26d3b
raw
history blame
13.1 kB
# 1. git clone -b v0.5.5-tablegpt-merged https://github.com/zTaoplus/vllm.git
# install tablegpt vllm
## apply diff file (recommended in case of use only)
# 1. pip install vllm==0.5.5
# 2. cd vllm
# 3. git diff 09c7792610ada9f88bbf87d32b472dd44bf23cc2 HEAD -- vllm | patch -p1 -d "$(pip show vllm | grep Location | awk '{print $2}')"
## build from source (dev recommended)
## Note: Building from source may take 10-30 minutes and requires access to GitHub or other repositories. Make sure to configure an HTTP/HTTPS proxy.
## cd vllm && pip install -e . [-v]. The -v flag is optional and can be used to display verbose logs.
# see https://github.com/zTaoplus/TableGPT-hf to view the model-related configs.
import os
import json
from tqdm import tqdm
from vllm import LLM
from vllm.sampling_params import SamplingParams
from transformers import AutoTokenizer
from text2sql.src.gpt_request import (
generate_schema_prompt,
generate_comment_prompt,
cot_wizard,
decouple_question_schema,
generate_sql_file,
parser_sql
) # text2sql.src.
import warnings
import random
# 忽略所有警告
warnings.filterwarnings("ignore")
# import pandas as pd
# from typing import Literal, List, Optional
# from io import StringIO
import sqlite3
# DEFAULT_SYS_MSG = "You are a helpful assistant."
# ENCODER_TYPE = "contrastive"
def get_table_info(db_path, enum_num=None):
# extract create ddls
'''
:param root_place:
:param db_name:
:return:
'''
full_schema_prompt_list = []
conn = sqlite3.connect(db_path)
# Create a cursor object
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
tables = cursor.fetchall()
all_tables_info = []
# 表截断
# if len(tables) > 16:
# tables = random.sample(tables, 16)
for table in tables:
if table == 'sqlite_sequence':
continue
# 前几行枚举值
cur_table = "`{}`".format(table[0])
cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, enum_num))
row_ls = cursor.fetchall()
cursor.execute(f"PRAGMA table_info({cur_table});")
column_name_tp_ls = cursor.fetchall()
all_columns_info = []
for column_name_tp in column_name_tp_ls:
pos_id = column_name_tp[0] # 字段位置
col_name = column_name_tp[1] # 字段名
col_type = column_name_tp[2] # 字段类型
# 字段枚举值
contains_nan = False
enum_values = []
for row in row_ls:
value = row[pos_id]
if value is None:
contains_nan = True
enum_values.append(str(value))
if len(enum_values) == 0:
enum_values = ["None"]
single_columns_info = {
"name": col_name,
"dtype": col_type,
"values": enum_values,
"contains_nan": contains_nan,
"is_unique": False
}
all_columns_info.append(single_columns_info)
# 列截断
# if len(all_columns_info) > 32:
# all_columns_info = random.sample(all_columns_info, 32)
single_table_info = {"columns": all_columns_info}
all_tables_info.append(single_table_info)
return all_tables_info
def generate_combined_prompts_one_encoder(db_path, question, knowledge=None):
schema_prompt = generate_schema_prompt(db_path, num_rows=None) # This is the entry to collect values
comment_prompt = generate_comment_prompt(question, knowledge)
# encoder_prompt = get_encoder_prompt(table_info)
combined_prompts = schema_prompt + '\n\n' + comment_prompt + cot_wizard() + '\nSELECT '
return combined_prompts
def get_encoder_prompt(table_info):
encode_prompt = "".join(
f"table_{i} as follow:\n<TABLE_CONTENT>\n"
for i in range(len(table_info)))
return encode_prompt
def get_messages_one(db_path, question, knowledge=None):
table_info = get_table_info(db_path, enum_num=3) # 采用几行枚举值
prompt = generate_combined_prompts_one_encoder(db_path, question, knowledge=knowledge)
messages = [
{
"role": "system",
"content": "You are a helpul assistant."
}
]
content = []
for i in range(len(table_info)):
if i == len(table_info) -1:
# 最后一个
content.extend([
{
"type": "text",
"text": f"table_{i} as follow: \n",
},
{"type": "table", "table": table_info[i]},
{
"type": "text",
"text": prompt,
}
]
)
else:
content.extend([
{
"type": "text",
"text": f"table_{i} as follow: \n",
},
{"type": "table", "table": table_info[i]},
]
)
messages.append(
{
"role": "user",
"content": content
}
)
# print("*"*100)
# print(json.dumps(messages, ensure_ascii=False))
# print("*"*100)
# exit()
return messages
def calculate_table_num():
import os
db_dir = "/home/jyuan/LLM/evaluation/table_related_benchmarks/evalset/bird_data/dev_databases"
table_nums = []
cols_nums = []
for file in os.listdir(db_dir):
if "." in file:
continue
db_path = os.path.join(db_dir, file, f"{file}.sqlite")
# print(db_path)
conn = sqlite3.connect(db_path)
# Create a cursor object
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
tables = cursor.fetchall()
table_num = len(tables)
table_nums.append(table_num)
for table in tables:
if table == 'sqlite_sequence':
continue
# cursor.execute("SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(table[0]))
# create_prompt = cursor.fetchone()[0]
# schemas[table[0]] = create_prompt
# 前几行枚举值
cur_table = "`{}`".format(table[0])
cursor.execute(f"PRAGMA table_info({cur_table});")
column_name_tp_ls = cursor.fetchall()
if len(column_name_tp_ls) == 115:
print(db_path)
cols_nums.append(len(column_name_tp_ls))
cols_nums = sorted(cols_nums, reverse=True)
print(f"max table: {max(table_nums)}, max columns: {max(cols_nums)}")
print(cols_nums[:10])
def llm_generate_result_encoder(model_name_or_path, gpus_num, messages_ls):
# 批量推理
print("model", model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
print(
"load tokenizer {} from {} over.".format(
tokenizer.__class__, model_name_or_path
)
)
model = LLM(
model=model_name_or_path,
max_model_len=12000,
max_num_seqs=1,
dtype="bfloat16",
limit_mm_per_prompt={"table": 16},
gpu_memory_utilization=0.9,
tensor_parallel_size=gpus_num,
)
p = SamplingParams(temperature=0, max_tokens=1024)
outputs = model.chat(messages=messages_ls, sampling_params=p)
generated_res = []
for i, output in enumerate(tqdm(outputs)):
text = output.outputs[0].text
sql = parser_sql(text)
generated_res.append(sql)
return generated_res
def col_nums_max(message):
content = message[1]["content"]
table_nums = 0
col_nums_ls = []
for dic in content:
if dic["type"] == "table":
table_nums += 1
col_num = len(dic["table"]["columns"])
col_nums_ls.append(col_num)
return int(max(col_nums_ls) + 1), table_nums
def llm_generate_result_encoder_one(model_name_or_path, gpus_num, messages_ls):
# 单条推理
print("model", model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
print(
"load tokenizer {} from {} over.".format(
tokenizer.__class__, model_name_or_path
)
)
model = LLM(
model=model_name_or_path,
max_model_len=12000,
max_num_seqs=1,
dtype="bfloat16",
limit_mm_per_prompt={"table": 20},
gpu_memory_utilization=0.9
)
p = SamplingParams(temperature=0, max_tokens=1024)
error_ls = []
generated_res = []
for i, messages in enumerate(messages_ls):
# if i != 89:
# continue
try:
# max_col, table_nums = col_nums_max(messages)
# # print("="*100, max_col, table_nums)
# model.llm_engine.model_config.hf_config.encoder_config.max_cols = max_col
# model.llm_engine.model_config.multimodal_config.limit_per_prompt = {"table": table_nums}
outputs = model.chat(messages=messages, sampling_params=p)
text = outputs[0].outputs[0].text
sql = parser_sql(text)
except Exception as e:
error_ls.append(i)
sql = ""
generated_res.append(sql)
if len(error_ls) != 0:
json.dump({"error_id": error_ls}, open("table_related_benchmarks/text2sql/output/error_ls.json", 'w'), indent=4)
return generated_res
def collect_response_from_gpt_encoder(model_path, gpus_num, db_path_list, question_list, knowledge_list=None):
'''
:param db_path: str
:param question_list: []
:return: dict of responses collected from llm
'''
responses_dict = {}
response_list = []
messages_ls = []
for i in tqdm(range(len(question_list)), desc="get prompt"):
# print('--------------------- processing {}th question ---------------------'.format(i))
# print('the question is: {}'.format(question))
question = question_list[i]
db_path = db_path_list[i]
if knowledge_list:
messages = get_messages_one(db_path, question, knowledge=knowledge_list[i])
else:
messages = get_messages_one(db_path, question)
messages_ls.append(messages)
outputs_sql = llm_generate_result_encoder(model_path, gpus_num, messages_ls)
for i in tqdm(range(len(outputs_sql)), desc="postprocess result"):
question = question_list[i]
sql = outputs_sql[i]
db_id = db_path_list[i].split('/')[-1].split('.sqlite')[0]
sql = sql + '\t----- bird -----\t' + db_id # to avoid unpredicted \t appearing in codex results
response_list.append(sql)
return response_list
def generate_main_encoder(eval_data, args):
question_list, db_path_list, knowledge_list = decouple_question_schema(
datasets=eval_data, db_root_path=args.db_root_path)
assert len(question_list) == len(db_path_list) == len(knowledge_list)
if args.use_knowledge == 'True':
responses = collect_response_from_gpt_encoder(model_path=args.model_path, gpus_num=args.gpus_num, db_path_list=db_path_list, question_list=question_list, knowledge_list=knowledge_list)
else:
responses = collect_response_from_gpt_encoder(model_path=args.model_path, gpus_num=args.gpus_num, db_path_list=db_path_list, question_list=question_list, knowledge_list=None)
if args.chain_of_thought == 'True':
output_name = os.path.join(args.data_output_path, f'predict_{args.mode}_cot.json')
else:
output_name = os.path.join(args.data_output_path, f'predict_{args.mode}.json')
# pdb.set_trace()
generate_sql_file(sql_lst=responses, output_path=output_name)
print('successfully collect results from {} for {} evaluation; Use knowledge: {}; Use COT: {}; Use encoder: {}'.format(args.model_path, args.mode, args.use_knowledge, args.chain_of_thought, args.use_encoder))
print(f'output: {output_name}')
# 返回推理数据保存路径
return output_name
def test_single():
db_path = "/home/jyuan/LLM/evaluation/table_related_benchmarks/evalset/spider_data/test_database/aan_1/aan_1.sqlite"
question = "How many authors do we have?"
messages = get_messages_one(db_path, question, knowledge=None)
# print("*"*100)
# print(messages)
# print("*"*100)
model_name_or_path = "/data4/workspace/yss/models/longlin_encoder_model/contrastive"
model = LLM(
model=model_name_or_path,
max_model_len=8192,
max_num_seqs=16,
dtype="bfloat16",
limit_mm_per_prompt={"table": 10},
gpu_memory_utilization=0.9,
tensor_parallel_size=1,
)
p = SamplingParams(temperature=0, max_tokens=1024)
res = model.chat(messages=messages, sampling_params=p)
print(res[0].outputs[0].text)
if __name__ == "__main__":
# test_single()
calculate_table_num()