wikty
update
2cb208a
metadata
license: apache-2.0

简介

该款自然语言生成 SQL 的模型(NL2SQL/Text2SQL)是以 replit-code-v1-3b 代码续写预训练模型为基础进行 LoRA 微调的,这里仅提供 LoRA 权重(大概 11M),推理时需要结合原始预训练模型一起使用,具体参考下文示例。

用法

NL2SQL 任务的输入由用户查询文本和数据库表信息两部分组成,目前按照以下格式拼接成模型的输入文本:

# Table Allergy_Type , columns = [ Allergy , AllergyType ]
# Table Has_Allergy , columns = [ StuID , Allergy ]
# Table Student , columns = [ StuID , LName , Fname , Age , Sex , Major , Advisor , city_code ]
# primary keys: [ Allergy_Type.Allergy , Student.StuID ]
# foreign keys: [ Has_Allergy.Allergy = Allergy_Type.Allergy , Has_Allergy.StuID = Student.StuID ]
# Create a query for question: 显示所有男生的学生ID。
query =

具体使用方法参考以下示例:

import sqlparse
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline

# --- Runtime configuration -------------------------------------------------
device = 'cuda'
base_model_path = 'replit/replit-code-v1-3b'
lora_model_path = 'DMetaSoul/nl2sql-chinese-standard-3b-lora'
sampling = False  # True: sample from the model; False: beam search.

# --- Tokenizer and model ---------------------------------------------------
# The tokenizer is configured to pad on the left side of each prompt.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_path,
    padding_side='left',
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)

# Wrap the base model with the LoRA weights when a LoRA path is configured.
if lora_model_path:
    model = PeftModel.from_pretrained(
        model,
        lora_model_path,
        torch_dtype=torch.float16,
    )

# Switch to inference mode and move the weights to the target device.
model.eval().to(device)

# Each prompt lists the table schemas, primary keys and foreign keys as
# comment lines, followed by the natural-language question and a trailing
# "query =" line.
input_texts = [
    (
        "# Table Allergy_Type , columns = [ Allergy , AllergyType ]\n"
        "# Table Has_Allergy , columns = [ StuID , Allergy ]\n"
        "# Table Student , columns = [ StuID , LName , Fname , Age , Sex , Major , Advisor , city_code ]\n"
        "# primary keys: [ Allergy_Type.Allergy , Student.StuID ]\n"
        "# foreign keys: [ Has_Allergy.Allergy = Allergy_Type.Allergy , Has_Allergy.StuID = Student.StuID ]\n"
        "# Create a query for question: 显示所有女学生的名、 姓氏、年龄。他们的性别是“女”.\n"
        "query ="
    ),
    (
        "# Table Allergy_Type , columns = [ Allergy , AllergyType ]\n"
        "# Table Has_Allergy , columns = [ StuID , Allergy ]\n"
        "# Table Student , columns = [ StuID , LName , Fname , Age , Sex , Major , Advisor , city_code ]\n"
        "# primary keys: [ Allergy_Type.Allergy , Student.StuID ]\n"
        "# foreign keys: [ Has_Allergy.Allergy = Allergy_Type.Allergy , Has_Allergy.StuID = Student.StuID ]\n"
        "# Create a query for question: 显示所有男生的学生ID。\n"
        "query ="
    ),
]

# Tokenize the whole batch at once: pad to a common length and truncate
# anything beyond 512 tokens.
inputs = tokenizer(
    input_texts,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)
# Move every returned tensor onto the configured device.
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

# Generation options shared by both decoding strategies.
# NOTE: `return_full_text` is intentionally not passed here. It is an option
# of TextGenerationPipeline, not of `model.generate()`; the prompt is removed
# from the output explicitly below instead.
generate_kwargs = dict(
    num_return_sequences=1,
    max_length=512,
    return_dict_in_generate=True,
    output_scores=True,
)
if sampling:
    # Top-k / nucleus sampling.
    generate_kwargs.update(do_sample=True, top_k=50, top_p=0.95, temperature=1.0)
else:
    # Beam search.
    generate_kwargs.update(num_beams=4)

# No gradients are needed for inference.
with torch.no_grad():
    outputs = model.generate(**inputs, **generate_kwargs)

# The generated sequences start with the prompt tokens. Inputs are
# left-padded to a common length, so dropping the first `prompt_length`
# positions of every row leaves only the newly generated tokens.
prompt_length = inputs['input_ids'].shape[1]
output_ids = outputs.sequences[:, prompt_length:]
results = tokenizer.batch_decode(
    output_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)

# Print each prompt followed by its pretty-printed SQL.
for question, sql in zip(input_texts, results):
    print(question)
    print(f'SQL: {sqlparse.format(sql, reindent=True)}')