sahilmayekar commited on
Commit
258afbb
·
1 Parent(s): 3f95da6
Files changed (1) hide show
  1. app.py +77 -2
app.py CHANGED
@@ -1,4 +1,79 @@
1
  import streamlit as st
 
 
 
2
 
3
- x = st.slider('Select a value2')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
+ import argparse
5
 
6
def generate_prompt(question, prompt_file="prompt.md", metadata_file="metadata.sql"):
    """Build the model prompt by filling a template with the user question
    and the table metadata read from disk.

    Args:
        question: Natural-language question to be answered with SQL.
        prompt_file: Path to a template containing ``{user_question}`` and
            ``{table_metadata_string}`` placeholders.
        metadata_file: Path to the SQL schema text substituted into the template.

    Returns:
        The fully formatted prompt string.

    Raises:
        FileNotFoundError: If either input file is missing.
        KeyError: If the template lacks one of the expected placeholders.
    """
    # Explicit encoding: without it, open() uses the locale's default codec,
    # which breaks on non-ASCII template/schema text on some platforms.
    with open(prompt_file, "r", encoding="utf-8") as f:
        prompt = f.read()

    with open(metadata_file, "r", encoding="utf-8") as f:
        table_metadata_string = f.read()

    return prompt.format(
        user_question=question, table_metadata_string=table_metadata_string
    )
17
+
18
+
19
def get_tokenizer_model(model_name):
    """Fetch a tokenizer and causal-LM model pair for *model_name*.

    Args:
        model_name: Hugging Face hub identifier of the checkpoint to load.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # model repo — keep only for repos you trust.
    load_options = dict(
        trust_remote_code=True,
        torch_dtype=torch.float16,  # load weights in half precision
        device_map="auto",          # let accelerate place layers across devices
        use_cache=True,
    )
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    return tokenizer, model
29
+
30
def run_inference(question, prompt_file="prompt.md", metadata_file="metadata.sql"):
    """Generate a single SQL query answering *question* with SQLCoder.

    Args:
        question: Natural-language question to translate into SQL.
        prompt_file: Template file passed through to ``generate_prompt``.
        metadata_file: Schema file passed through to ``generate_prompt``.

    Returns:
        The generated SQL query as a string, truncated at the first ``;`` or
        triple-backtick fence and terminated with a single ``;``.
    """
    # Bug fix: the original reloaded the 7B checkpoint from the hub on every
    # call (i.e. every Streamlit button click). Cache the pair on the function
    # so the expensive load happens once per process.
    if not hasattr(run_inference, "_cached_tm"):
        run_inference._cached_tm = get_tokenizer_model("defog/sqlcoder-7b-2")
    tokenizer, model = run_inference._cached_tm

    prompt = generate_prompt(question, prompt_file, metadata_file)

    # make sure the model stops generating at triple ticks
    # eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
    eos_token_id = tokenizer.eos_token_id
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=300,
        do_sample=False,
        return_full_text=False,  # don't echo the prompt back in the output
        num_beams=5,  # do beam search with 5 beams for high quality results
    )
    generated_query = (
        pipe(
            prompt,
            num_return_sequences=1,
            eos_token_id=eos_token_id,
            pad_token_id=eos_token_id,
        )[0]["generated_text"]
        .split(";")[0]      # keep only the first statement
        .split("```")[0]    # drop anything after a closing code fence
        .strip()
        + ";"
    )
    return generated_query
59
+
60
def main():
    """Streamlit entry point: collect a question and display the generated SQL."""
    st.title("SQLCoder App")
    st.sidebar.title("Input Question")
    question = st.sidebar.text_area("Enter your question here", height=200)
    if st.sidebar.button("Generate SQL Query"):
        # Bug fix: st.spinner is a context manager; calling it bare is a
        # no-op, so no spinner was shown while the model ran.
        with st.spinner("Generating SQL query..."):
            generated_query = run_inference(question)
        st.success("SQL query generated successfully:")
        st.code(generated_query, language="sql")
69
+
70
if __name__ == "__main__":
    # Default question used when none is supplied on the command line.
    _default_question = "Do we get more sales from customers in New York compared to customers in San Francisco? Give me the total sales for each city, and the difference between the two."
    parser = argparse.ArgumentParser(description="Run inference on a question")
    parser.add_argument("-q", "--question", type=str, default=_default_question, help="Question to run inference on")
    # Bug fix: parse_args() aborts on any extra argv entries injected by
    # launcher wrappers (e.g. `streamlit run` pass-through args); tolerate
    # and ignore unknown arguments instead.
    args, _unknown = parser.parse_known_args()
    question = args.question
    print("Loading a model and generating a SQL query for answering your question...")
    print(run_inference(question))
    main()