|
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Toy e-commerce schema: table name -> list of column names.
db_schema = {
    "products": ["product_id", "name", "price", "description", "type"],
    "orders": ["order_id", "product_id", "quantity", "order_date"],
    "customers": ["customer_id", "name", "email", "phone_number"],
}
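
# A possible refinement (an assumption, not part of the original script):
# text-to-SQL prompts often behave better when the schema is rendered as
# CREATE TABLE statements rather than raw JSON. A minimal sketch;
# schema_to_ddl is a hypothetical helper name.
def schema_to_ddl(schema):
    """Render a {table: [columns]} mapping as bare CREATE TABLE statements."""
    return "\n".join(
        f"CREATE TABLE {table} ({', '.join(columns)});"
        for table, columns in schema.items()
    )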


# Load GPT-Neo 2.7B; device_map="auto" places the weights on the GPU when one
# is available, and float16 halves the memory footprint.
model_name = "EleutherAI/gpt-neo-2.7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.float16
)

# GPT-Neo ships without a pad token; reuse EOS so generate() does not complain.
tokenizer.pad_token = tokenizer.eos_token

def generate_sql_query(context, question):
    """
    Generate a SQL query for a natural-language question about the database.

    Args:
        context (str): Description of the database schema or table relationships.
        question (str): The user's natural-language question.

    Returns:
        str: The generated SQL query.
    """
prompt = f""" |
|
Context: {context} |
|
|
|
Question: {question} |
|
Query: |
|
""" |
|
|
|

    # Tokenize the prompt and move it to the same device as the model.
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=512
    ).to(model.device)

    print("Prompt sent to model:")
    print(prompt)

    # max_new_tokens bounds only the newly generated text; the original
    # max_length=512 also counted the prompt, so a full-length prompt left no
    # room to generate. Passing **inputs supplies the attention mask as well.
    output = model.generate(
        **inputs,
        max_new_tokens=128,
        num_beams=5,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    query = tokenizer.decode(output[0], skip_special_tokens=True)
sql_query = query.split("Query:")[-1].strip() |
|
return sql_query |
|
|
|
|
|

# The pretty-printed schema doubles as the prompt context.
schema_description = json.dumps(db_schema, indent=4)

questions = [
    "Describe the products table for me: what kind of data is it storing?",
]

for user_question in questions:
    print(f"Question: {user_question}")
    sql_query = generate_sql_query(schema_description, user_question)
    print(f"Generated SQL Query:\n{sql_query}\n")