|
--- |
|
license: apache-2.0 |
|
--- |
|
|
|
## 简介 |
|
|
|
该款自然语言生成 SQL 的模型(NL2SQL/Text2SQL)是以 [replit-code-v1-3b](https://huggingface.co/replit/replit-code-v1-3b) 代码续写预训练模型为基础进行 LoRA 微调的,这里仅提供 LoRA 权重(大概 11M),推理时需要结合原始预训练模型一起使用,具体参考下文示例。 |
|
|
|
## 用法 |
|
|
|
NL2SQL 任务中输入参数含有用户查询文本+数据库表信息,目前按照以下格式拼接模型的输入文本: |
|
|
|
``` |
|
# Table Allergy_Type , columns = [ Allergy , AllergyType ] |
|
# Table Has_Allergy , columns = [ StuID , Allergy ] |
|
# Table Student , columns = [ StuID , LName , Fname , Age , Sex , Major , Advisor , city_code ] |
|
# primary keys: [ Allergy_Type.Allergy , Student.StuID ] |
|
# foreign keys: [ Has_Allergy.Allergy = Allergy_Type.Allergy , Has_Allergy.StuID = Student.StuID ] |
|
# Create a query for question: 显示所有男生的学生ID。 |
|
query = |
|
``` |
|
|
|
具体使用方法参考以下示例: |
|
|
|
```python |
|
import sqlparse
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline

device = 'cuda'
base_model_path = 'replit/replit-code-v1-3b'
lora_model_path = 'DMetaSoul/nl2sql-chinese-standard-3b-lora'
sampling = False

# Left padding so every prompt in the batch ends right where generation starts.
tokenizer = AutoTokenizer.from_pretrained(base_model_path,
                                          trust_remote_code=True, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(base_model_path,
                                             trust_remote_code=True, torch_dtype=torch.float16)
if lora_model_path:
    # Attach the LoRA adapter (~11M) on top of the frozen base model.
    model = PeftModel.from_pretrained(model, lora_model_path,
                                      torch_dtype=torch.float16)
model.eval()
model.to(device)

# Prompt format: table schemas + primary/foreign keys + the question, ending in "query =".
input_texts = [
    "# Table Allergy_Type , columns = [ Allergy , AllergyType ]\n# Table Has_Allergy , columns = [ StuID , Allergy ]\n# Table Student , columns = [ StuID , LName , Fname , Age , Sex , Major , Advisor , city_code ]\n# primary keys: [ Allergy_Type.Allergy , Student.StuID ]\n# foreign keys: [ Has_Allergy.Allergy = Allergy_Type.Allergy , Has_Allergy.StuID = Student.StuID ]\n# Create a query for question: 显示所有女学生的名、 姓氏、年龄。他们的性别是“女”.\nquery =",
    "# Table Allergy_Type , columns = [ Allergy , AllergyType ]\n# Table Has_Allergy , columns = [ StuID , Allergy ]\n# Table Student , columns = [ StuID , LName , Fname , Age , Sex , Major , Advisor , city_code ]\n# primary keys: [ Allergy_Type.Allergy , Student.StuID ]\n# foreign keys: [ Has_Allergy.Allergy = Allergy_Type.Allergy , Has_Allergy.StuID = Student.StuID ]\n# Create a query for question: 显示所有男生的学生ID。\nquery =",
]
inputs = tokenizer(input_texts, max_length=512, return_tensors="pt",
                   padding=True, truncation=True)
inputs = {k: v.to(device) for k, v in inputs.items()}

# NOTE: `return_full_text` is a TextGenerationPipeline argument, not a valid
# `model.generate()` kwarg (recent transformers raises on unused model kwargs).
# It has been removed here; the prompt is stripped below by slicing instead.
with torch.no_grad():
    if sampling:
        outputs = model.generate(**inputs, do_sample=True, top_k=50, top_p=0.95,
                                 temperature=1.0, num_return_sequences=1,
                                 max_length=512, return_dict_in_generate=True,
                                 output_scores=True)
    else:
        outputs = model.generate(**inputs, num_beams=4, num_return_sequences=1,
                                 max_length=512, return_dict_in_generate=True,
                                 output_scores=True)

# Decoder-only models return prompt + completion; keep only the generated part
# (with left padding all prompts share the same length, input_ids.shape[1]).
prompt_len = inputs['input_ids'].shape[1]
output_ids = outputs.sequences[:, prompt_len:]
results = tokenizer.batch_decode(output_ids, skip_special_tokens=True,
                                 clean_up_tokenization_spaces=True)

for question, sql in zip(input_texts, results):
    print(question)
    print('SQL: {}'.format(sqlparse.format(sql, reindent=True)))
|
|
|
``` |
|
|