from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import sqlite3
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import re
import gradio as gr

# Load the Llama model and tokenizer
model_name = "meta-llama/Llama-3.3-70B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
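# NOTE: meta-llama/Llama-3.3-70B-Instruct is a gated ~70B-parameter model;
# loading it requires accepting the license on the Hugging Face Hub and a
# large amount of GPU memory. device_map="auto" shards the weights across
# whatever devices accelerate can find.
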
# Initialize the database connection.
# check_same_thread=False lets Gradio's worker threads reuse this connection.
db_path = "Spring_2025_courses.db"
connection = sqlite3.connect(db_path, check_same_thread=False)

def get_schema():
    """Retrieve the database schema as {table_name: [column, ...]}."""
    cursor = connection.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    schema = {}
    for table_name in tables:
        table_name = table_name[0]
        cursor.execute(f"PRAGMA table_info({table_name});")
        columns = cursor.fetchall()
        schema[table_name] = [column[1] for column in columns]
    return schema
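# Illustrative shape of get_schema()'s return value (table and column names
# are hypothetical; the real ones depend on Spring_2025_courses.db):
#   {"courses": ["course_id", "title", "instructor", "days", "start_time"]}
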
def run_query(query):
    """Execute a SQL query and return all rows."""
    cursor = connection.cursor()
    cursor.execute(query)
    return cursor.fetchall()

# Prompt templates
system_prompt = """
You are a SQLite expert. Given an input question, create one syntactically correct SQLite query to run. Generate only one query. No preamble.
Here is the relevant table information:
schema: {schema}

Tips:
- Use LIKE instead of = in the queries

Write only one SQLite query that would answer the user's question.
"""

human_prompt = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""

prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    ("human", human_prompt),
])
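# prompt.invoke({"schema": ..., "question": ...}) yields a ChatPromptValue with
# one system and one human message; call_model below flattens it to a single
# "System: ...\nHuman: ..." string before handing it to the pipeline.
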
# Adapter so the raw text-generation pipeline can sit inside the LangChain chain
def call_model(prompt_value):
    """Run the Hugging Face pipeline on the formatted prompt; return only the new text."""
    output = generator(prompt_value.to_string(), max_new_tokens=256,
                       num_return_sequences=1, return_full_text=False)
    return output[0]["generated_text"]

# Build query generation chain: inject schema -> format prompt -> model -> parse string
sql_generator = (
    RunnablePassthrough.assign(schema=lambda _: get_schema())
    | prompt
    | call_model
    | StrOutputParser()
)

def generate_sql(question):
    """Generate a SQL query for a natural language question."""
    raw = sql_generator.invoke({"question": question})
    # Strip markdown code fences in case the model wraps its answer in them
    return re.sub(r"```(?:sql)?", "", raw).strip()
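# Example (illustrative only; the actual SQL depends on the model and schema):
#   generate_sql("Which courses meet on Fridays?")
#   -> 'SELECT * FROM courses WHERE days LIKE "%Friday%"'
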
def execute_safe_query(question):
    """Safely execute a natural language query."""
    try:
        # Generate SQL query
        sql_query = generate_sql(question)
        # Validate SQL query: only read-only SELECT statements are allowed
        if not sql_query.strip().lower().startswith("select"):
            return {"error": "Only SELECT queries are allowed.", "query": sql_query, "result": None}
        # Execute query
        result = run_query(sql_query)
        return {"error": None, "query": sql_query, "result": result}
    except Exception as e:
        return {"error": str(e), "query": None, "result": None}
# Deploy using Gradio
def query_interface(question):
    response = execute_safe_query(question)
    if response['error']:
        return f"Error: {response['error']}\nGenerated Query: {response['query']}"
    return f"Query: {response['query']}\nResult: {response['result']}"

iface = gr.Interface(
    fn=query_interface,
    inputs="text",
    outputs="text",
    title="SQLite Query Generator with Llama 3.3",
    description="Ask a natural language question about the Spring 2025 courses database and get the SQL query and results.",
)

if __name__ == "__main__":
    iface.launch()
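
# Usage note: running this script starts a local Gradio server (by default at
# http://127.0.0.1:7860). Pass share=True to iface.launch() to get a temporary
# public link, e.g. iface.launch(share=True).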