rishdotblog commited on
Commit
3741f67
·
1 Parent(s): ed1b417

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +65 -0
inference.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
+ import argparse
4
+
5
+ def generate_prompt(question, prompt_file="prompt.md", metadata_file="metadata.sql"):
6
+ with open(prompt_file, "r") as f:
7
+ prompt = f.read()
8
+
9
+ with open(metadata_file, "r") as f:
10
+ table_metadata_string = f.read()
11
+
12
+ prompt = prompt.format(
13
+ user_question=question, table_metadata_string=table_metadata_string
14
+ )
15
+ return prompt
16
+
17
+
18
+ def get_tokenizer_model(model_name):
19
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
20
+ model = AutoModelForCausalLM.from_pretrained(
21
+ model_name,
22
+ trust_remote_code=True,
23
+ torch_dtype=torch.float16,
24
+ device_map="auto",
25
+ use_cache=True,
26
+ )
27
+ return tokenizer, model
28
+
29
+ def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql"):
30
+ tokenizer, model = get_tokenizer_model("defog/sqlcoder")
31
+ prompt = generate_prompt(question, prompt_file, metadata_file)
32
+
33
+ # make sure the model stops generating at triple ticks
34
+ eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
35
+ pipe = pipeline(
36
+ "text-generation",
37
+ model=model,
38
+ tokenizer=tokenizer,
39
+ max_new_tokens=300,
40
+ do_sample=False,
41
+ num_beams=5, # do beam search with 5 beams for high quality results
42
+ )
43
+ generated_query = (
44
+ pipe(
45
+ prompt,
46
+ num_return_sequences=1,
47
+ eos_token_id=eos_token_id,
48
+ pad_token_id=eos_token_id,
49
+ )[0]["generated_text"]
50
+ .split("```sql")[-1]
51
+ .split("```")[0]
52
+ .split(";")[0]
53
+ .strip()
54
+ + ";"
55
+ )
56
+ return generated_query
57
+
58
+ if __name__ == "__main__":
59
+ # Parse arguments
60
+ parser = argparse.ArgumentParser(description="Run inference on a question")
61
+ parser.add_argument("-q","--question", type=str, help="Question to run inference on")
62
+ args = parser.parse_args()
63
+ question = args.question
64
+ print("Loading a model and generating a SQL query for answering your question...")
65
+ print(run_inference(question))