"""Streamlit app: answer pizza-sales questions by generating SQL with
premai-io/prem-1B-SQL and running it against a MySQL database."""

import json
import logging
import os
import re
import sys
import threading
import time
from datetime import date, datetime, time as dt_time, timedelta
from decimal import Decimal
from functools import lru_cache

import mysql.connector
import streamlit as st
import torch
from mysql.connector import Error
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Enable GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Database configuration.
# NOTE(security): these credentials were committed in plain text. Every field can
# now be overridden through the environment; the committed password should be
# rotated and the fallback values removed. Ideally the account is also read-only.
# pool_size / pool_reset_session make mysql.connector.connect() hand out pooled
# connections, whose close() returns them to the pool.
DB_CONFIG = {
    'host': os.environ.get('DB_HOST', 'sql12.freemysqlhosting.net'),
    'database': os.environ.get('DB_NAME', 'sql12740625'),
    'user': os.environ.get('DB_USER', 'sql12740625'),
    'password': os.environ.get('DB_PASSWORD', 'QGG9kdrE4g'),
    'port': int(os.environ.get('DB_PORT', '3306')),
    'pool_size': 5,
    'pool_reset_session': True,
}

MODEL_NAME_SQL = "premai-io/prem-1B-SQL"

# Schema given to the model so it knows which table/columns exist.
SCHEMA_INFO = """
CREATE TABLE sales (
    pizza_id DECIMAL(8,2) PRIMARY KEY,
    order_id DECIMAL(8,2),
    pizza_name_id VARCHAR(14),
    quantity DECIMAL(4,2),
    order_date DATE,
    order_time VARCHAR(8),
    unit_price DECIMAL(5,2),
    total_price DECIMAL(5,2),
    pizza_size VARCHAR(3),
    pizza_category VARCHAR(7),
    pizza_ingredients VARCHAR(97),
    pizza_name VARCHAR(42)
);
"""

# Prefix of the string generate_sql() returns on failure; main() checks it so an
# error message is never sent to the database as if it were SQL.
SQL_GENERATION_ERROR_PREFIX = "Error generating SQL query:"

# First keywords accepted by _is_read_only_query().
_READ_ONLY_FIRST_KEYWORDS = frozenset({"select", "with", "show", "describe", "desc", "explain"})
# Keywords that modify data/schema or touch server files; any occurrence rejects the query.
_FORBIDDEN_SQL_KEYWORDS_RE = re.compile(
    r"\b(insert|update|delete|drop|alter|create|truncate|grant|revoke|rename|"
    r"outfile|dumpfile|load_file)\b",
    re.IGNORECASE,
)
_FIRST_WORD_RE = re.compile(r"\s*([A-Za-z]+)")

# Global variables for model and tokenizer (populated by initialize_model()).
GLOBAL_MODEL = None
GLOBAL_TOKENIZER = None
# Result of the most recent test_db_connection() call in this module namespace.
db_connection_status = False


@st.cache_resource(show_spinner="Initializing model and tokenizer...")
def _load_model_and_tokenizer(model_name=MODEL_NAME_SQL):
    """Load the tokenizer and causal-LM once per Streamlit server process.

    Streamlit re-executes the script on every interaction, so without this cache
    the model weights would be reloaded on every button click.

    Returns:
        (model, tokenizer), with the model moved to ``device`` and in eval mode.
    """
    logging.info("Initializing model and tokenizer...")
    start_time = time.perf_counter()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # float32 so the same code path works on CPU
    ).to(device)
    # Set model to evaluation mode (disables dropout etc.).
    model.eval()
    logging.info(f"Model initialization took {time.perf_counter() - start_time:.2f} seconds")
    return model, tokenizer


def initialize_model():
    """Populate GLOBAL_MODEL and GLOBAL_TOKENIZER from the process-wide cache."""
    global GLOBAL_MODEL, GLOBAL_TOKENIZER
    GLOBAL_MODEL, GLOBAL_TOKENIZER = _load_model_and_tokenizer()


def _close_quietly(resource, label):
    """Close a cursor/connection, logging (not raising) MySQL errors during cleanup."""
    if resource is None:
        return
    try:
        resource.close()
    except Error as e:
        logging.warning(f"Error closing database {label}: {e}")


def test_db_connection():
    """Check that the database is reachable and report which server/database it is.

    Also sets the module-level ``db_connection_status`` flag.

    Returns:
        (True, success message) when a connection is made and queried, otherwise
        (False, error message).
    """
    global db_connection_status
    logging.info("Testing database connection...")
    connection = None
    cursor = None
    try:
        connection = mysql.connector.connect(**DB_CONFIG, connect_timeout=10)
        if not connection.is_connected():
            db_connection_status = False
            return False, "Unable to establish database connection"
        db_info = connection.get_server_info()
        cursor = connection.cursor()
        cursor.execute("SELECT DATABASE();")
        db_name = cursor.fetchone()[0]
    except Error as e:
        db_connection_status = False
        logging.error(f"Error connecting to MySQL database: {e}")
        return False, f"Error connecting to MySQL database: {e}"
    finally:
        # Always release both, even on the failure paths; a pooled connection that
        # is never closed never goes back to the pool.
        _close_quietly(cursor, "cursor")
        _close_quietly(connection, "connection")

    db_connection_status = True
    logging.info(f"Successfully connected to MySQL Server version {db_info} - Database: {db_name}")
    return True, f"Successfully connected to MySQL Server version {db_info}\nDatabase: {db_name}"


def get_db_connection():
    """Get a database connection from the pool configured in DB_CONFIG."""
    logging.info("Getting database connection from pool...")
    return mysql.connector.connect(**DB_CONFIG)


def execute_query(query):
    """Run *query* on a pooled connection and fetch every row.

    Returns:
        A list of dicts (column name -> value) on success, or a string starting
        with "Error executing query:" when MySQL reports an error.
    """
    logging.info(f"Executing query: {query}")
    connection = None
    cursor = None
    try:
        connection = get_db_connection()
        cursor = connection.cursor(dictionary=True, buffered=True)
        cursor.execute(query)
        results = cursor.fetchall()
    except Error as e:
        logging.error(f"Error executing query: {e}")
        return f"Error executing query: {e}"
    finally:
        # cursor may still be None if the connection or cursor creation failed.
        _close_quietly(cursor, "cursor")
        if connection is not None:
            _close_quietly(connection, "connection")
            logging.info("Database connection closed.")

    logging.info(f"Query executed successfully, retrieved {len(results)} records.")
    return results


def _extract_sql(generated_text):
    """Return the first SQL statement in the model's continuation text.

    Drops anything from the next "###" section marker onward, strips a Markdown
    code fence if present, and keeps only text up to and including the first
    ";" (a ";" inside a string literal would also end the statement). Returns ""
    when no statement text remains.
    """
    text = generated_text.split("###", 1)[0].strip()
    if text.startswith("```"):
        text = text[3:]
        if text[:3].lower() == "sql":
            text = text[3:]
    text = text.split("```", 1)[0]
    statement, semicolon, _ = text.partition(";")
    statement = statement.strip()
    return f"{statement}{semicolon}" if statement else ""


def generate_sql(natural_language_query, max_new_tokens=256):
    """Generate a SQL query for *natural_language_query* with the loaded model.

    Args:
        natural_language_query: the user's question in plain English.
        max_new_tokens: cap on generated tokens, excluding the prompt.

    Returns:
        The extracted SQL text (possibly ""), or a string starting with
        SQL_GENERATION_ERROR_PREFIX if anything fails.
    """
    logging.info(f"Generating SQL for query: {natural_language_query}")
    try:
        if GLOBAL_MODEL is None or GLOBAL_TOKENIZER is None:
            initialize_model()
        start_time = time.perf_counter()

        prompt = (
            "### Task: Generate a SQL query to answer the following question.\n"
            f"### Database Schema:\n{SCHEMA_INFO}\n"
            f"### Question:\n{natural_language_query}\n"
            "### SQL Query:"
        )
        # Single prompt, so no padding is needed (padding would also fail on a
        # tokenizer without a pad token).
        inputs = GLOBAL_TOKENIZER(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            return_attention_mask=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = GLOBAL_MODEL.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                # max_new_tokens, not max_length: max_length counts the (long)
                # schema prompt too and left little or no room for the SQL.
                max_new_tokens=max_new_tokens,
                temperature=0.1,
                do_sample=True,
                top_p=0.95,
                num_return_sequences=1,
                pad_token_id=GLOBAL_TOKENIZER.eos_token_id,
            )

        # generate() returns prompt + continuation; decode only the new tokens.
        prompt_length = inputs["input_ids"].shape[1]
        continuation = GLOBAL_TOKENIZER.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        sql_query = _extract_sql(continuation)
        logging.info(f"SQL generation took {time.perf_counter() - start_time:.2f} seconds")
        return sql_query
    except Exception as e:  # UI boundary: report as text instead of crashing the page
        logging.exception(f"Error generating SQL query: {e}")
        return f"{SQL_GENERATION_ERROR_PREFIX} {e}"


def _is_read_only_query(sql):
    """Return True if *sql* looks like one read-only statement.

    This is a defensive filter for model-generated SQL, not a parser. The query is
    accepted only if all of the following hold:

    - it is a single statement (trailing ";" allowed);
    - its first word is in _READ_ONLY_FIRST_KEYWORDS;
    - it contains none of _FORBIDDEN_SQL_KEYWORDS_RE.

    It may reject harmless queries, for example a string literal containing
    "update". The real guarantee should come from a read-only database account.
    """
    statement = sql.strip().rstrip(";").strip()
    if not statement or ";" in statement:
        return False
    first_word = _FIRST_WORD_RE.match(statement)
    if first_word is None or first_word.group(1).lower() not in _READ_ONLY_FIRST_KEYWORDS:
        return False
    return _FORBIDDEN_SQL_KEYWORDS_RE.search(statement) is None


def _json_default(value):
    """json.dumps fallback for values MySQL rows contain (display only).

    Decimal is rendered as its exact text, dates/times in ISO format, and
    timedelta via str(). Bytes are decoded as UTF-8 with replacement. Anything
    else falls back to str() because the output is only shown to the user.
    """
    if isinstance(value, Decimal):
        return str(value)
    if isinstance(value, (date, dt_time)):  # datetime is a subclass of date
        return value.isoformat()
    if isinstance(value, timedelta):
        return str(value)
    if isinstance(value, (bytes, bytearray)):
        return value.decode("utf-8", errors="replace")
    return str(value)


def format_result(query_result, max_rows=5):
    """Turn execute_query() output into plain text for display.

    Args:
        query_result: list of row dicts from execute_query(), or an error string.
        max_rows: how many rows to spell out when there are several.

    Returns:
        One of the following:

        - a non-empty string input, unchanged (e.g. an error message);
        - "No results found." for an empty result;
        - "column: value" lines for a single row;
        - otherwise a numbered listing of at most *max_rows* rows, plus a note
          when rows were omitted.
    """
    if isinstance(query_result, str) and query_result:
        if "Error" in query_result:
            logging.warning(f"Query result contains an error: {query_result}")
        return query_result

    if not query_result:
        logging.info("No results found.")
        return "No results found."

    if len(query_result) == 1:
        return "\n".join(f"{key}: {value}" for key, value in query_result[0].items())

    lines = [f"Found {len(query_result)} results:\n"]
    for index, row in enumerate(query_result[:max_rows], 1):
        lines.append(f"Result {index}:")
        lines.extend(f"{key}: {value}" for key, value in row.items())
        lines.append("")

    if len(query_result) > max_rows:
        lines.append(f"(Showing first {max_rows} of {len(query_result)} results)")

    return "\n".join(lines)


def check_live_connection(interval_seconds=10, status=None):
    """Re-check the database connection forever, every *interval_seconds*.

    Args:
        interval_seconds: pause between checks.
        status: optional dict; after each check its "last" key is set to the
            (ok, message) tuple from test_db_connection(). A single assignment is
            used so readers in other threads never see a half-updated value.
    """
    while True:
        try:
            result = test_db_connection()
        except Exception:  # thread boundary: log and keep monitoring
            logging.exception("Unexpected error while checking database connection")
            result = (False, "Unexpected error while checking database connection")
        if status is not None:
            status["last"] = result
        time.sleep(interval_seconds)


@st.cache_resource(show_spinner="Checking database connection...")
def _start_connection_monitor():
    """Start the background connection monitor once per server process.

    Returns:
        A dict whose "last" key holds the latest (ok, message) result. It is
        seeded with a synchronous check so the first page render shows a real
        status, and refreshed by the monitor thread afterwards.
    """
    status = {"last": test_db_connection()}
    threading.Thread(
        target=check_live_connection,
        kwargs={"status": status},
        daemon=True,
        name="db-connection-monitor",
    ).start()
    return status


def main():
    """Render the Streamlit UI and handle one question -> SQL -> result round trip.

    Streamlit re-executes this script on every interaction, so the model and the
    monitor thread come from st.cache_resource-backed helpers instead of being
    recreated on each run.
    """
    st.title("Natural Language to SQL Query")
    st.write("Ask questions about pizza sales data in plain English.")

    # Display database connection status (kept fresh by the single monitor thread).
    is_live, _ = _start_connection_monitor()["last"]
    if is_live:
        st.success("Database connection is live.")
    else:
        st.error("Database connection is not live.")

    initialize_model()

    natural_language_query = st.text_input(
        "Enter your question",
        placeholder="e.g., What were the total sales for each pizza category?",
    )

    if not st.button("Generate and Execute Query"):
        return

    if not natural_language_query or not natural_language_query.strip():
        logging.warning("User did not enter a query.")
        st.write("Please enter a query.")
        return

    sql_query = generate_sql(natural_language_query.strip())
    st.write("Generated SQL Query:", sql_query)

    # Never send a generation error message or empty output to the database.
    if sql_query.startswith(SQL_GENERATION_ERROR_PREFIX):
        st.error(sql_query)
        return
    if not sql_query:
        st.error("The model did not return a SQL query.")
        return
    # The SQL is derived from untrusted user text; only run read-only statements.
    if not _is_read_only_query(sql_query):
        logging.warning(f"Refusing to execute non-read-only SQL: {sql_query}")
        st.error(
            "Refusing to execute the generated SQL: only a single read-only "
            "statement (SELECT/WITH, SHOW, DESCRIBE, EXPLAIN) is allowed."
        )
        return

    query_result = execute_query(sql_query)
    formatted_result = format_result(query_result)

    st.write("Query Result:")
    # Rows for this schema contain Decimal/date values that json.dumps cannot
    # serialize by itself.
    st.code(json.dumps(query_result, indent=2, default=_json_default))
    st.write("Human-Readable Response:")
    st.text(formatted_result)


if __name__ == "__main__":
    main()