chronos / db.py
Manoj Kumar
Mark Phase 1
e6f4fec
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Example schema: maps each table name to its column list and an optional
# relationship description ("local_column -> other_table.column", or None).
schema = {
    "products": {
        "columns": ["product_id", "name", "price", "category_id"],
        "relations": "category_id -> categories.id",
    },
    "categories": {
        "columns": ["id", "category_name"],
        # No relationship is recorded for this table.
        "relations": None,
    },
    "orders": {
        "columns": ["order_id", "customer_name", "product_id", "order_date"],
        "relations": "product_id -> products.product_id",
    },
}
# Step 1: Generate context dynamically from schema
def generate_context(schema):
    """Render a schema mapping as plain-English sentences for the model prompt.

    Args:
        schema: Mapping of table name to a details mapping. Each details
            mapping must provide "columns" (an iterable of column-name
            strings) and may provide "relations" (a description string;
            None, empty, or absent when there is nothing to report).

    Returns:
        The generated sentences joined with newlines: one column sentence
        per table, followed by a relationship sentence for each table whose
        "relations" value is truthy. An empty schema yields "".

    Raises:
        KeyError: If a table's details mapping has no "columns" key.
    """
    context_lines = []
    for table, details in schema.items():
        # List table columns
        columns = ", ".join(details["columns"])
        context_lines.append(f"The {table} table has the following columns: {columns}.")
        # "relations" is optional: .get() makes a missing key behave like an
        # explicit None instead of raising KeyError.
        relations = details.get("relations")
        if relations:
            context_lines.append(f"The {table} table has the following relationship: {relations}.")
    return "\n".join(context_lines)
# Render the example schema as sentences so it can be embedded in the prompt
schema_context = generate_context(schema)

# Step 2: Load the tokenizer and seq2seq model for the text-to-SQL checkpoint
model_name = "suriya7/t5-base-text-to-sql"  # A model fine-tuned for SQL generation
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Step 3: The natural language question to translate
user_query = "List all orders where the product price is greater than 50."

# Assemble the prompt line by line: instruction, schema description,
# a blank separator line, then the question itself.
input_text = "\n".join(
    [
        "Convert the following question into an SQL query:",
        "Schema:",
        schema_context,
        "",
        "Question:",
        user_query,
    ]
)
# Tokenize the prompt into a PyTorch tensor of input ids
inputs = tokenizer.encode(input_text, return_tensors="pt")

# Step 4: Generate the SQL with beam search (4 beams, length capped at 128)
outputs = model.generate(
    inputs,
    max_length=128,
    num_beams=4,
    early_stopping=True,
)
# Decode the first returned sequence, dropping special tokens
generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Step 5: Show the question next to the generated SQL
print(f"User Query: {user_query}")
print(f"Generated SQL Query: {generated_sql}")