from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Describe the database schema: each table's columns and any foreign-key relation.
schema = {
    "products": {
        "columns": ["product_id", "name", "price", "category_id"],
        "relations": "category_id -> categories.id",
    },
    "categories": {
        "columns": ["id", "category_name"],
        "relations": None,
    },
    "orders": {
        "columns": ["order_id", "customer_name", "product_id", "order_date"],
        "relations": "product_id -> products.product_id",
    },
}


def generate_context(schema):
    """Turn the schema dict into a natural-language description the model can condition on."""
    context_lines = []
    for table, details in schema.items():
        # List the table's columns.
        columns = ", ".join(details["columns"])
        context_lines.append(f"The {table} table has the following columns: {columns}.")
        # Mention the foreign-key relationship, if the table declares one.
        if details["relations"]:
            context_lines.append(
                f"The {table} table has the following relationship: {details['relations']}."
            )
    return "\n".join(context_lines)


schema_context = generate_context(schema)
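# For reference (derived directly from generate_context above), schema_context
# for this schema reads:
#
#   The products table has the following columns: product_id, name, price, category_id.
#   The products table has the following relationship: category_id -> categories.id.
#   The categories table has the following columns: id, category_name.
#   The orders table has the following columns: order_id, customer_name, product_id, order_date.
#   The orders table has the following relationship: product_id -> products.product_id.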

# Load the pretrained text-to-SQL model and its tokenizer from the Hugging Face Hub.
model_name = "suriya7/t5-base-text-to-sql"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# The natural-language question to translate into SQL.
user_query = "List all orders where the product price is greater than 50."

# Build the prompt: instruction first, then the schema context, then the question.
input_text = (
    "Convert the following question into an SQL query:\n"
    f"Schema:\n{schema_context}\n\n"
    f"Question:\n{user_query}"
)
inputs = tokenizer.encode(input_text, return_tensors="pt")

# Generate the SQL with beam search and decode the best sequence back to text.
outputs = model.generate(inputs, max_length=128, num_beams=4, early_stopping=True)
generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("User Query:", user_query)
print("Generated SQL Query:", generated_sql)