import os
import streamlit as st
import pandas as pd
import sqlite3
from langchain import OpenAI, LLMChain, PromptTemplate
import sqlparse
import logging
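
# Logging is imported above but never configured; a minimal, assumed setup so the
# logging.info / logging.error calls further down are actually emitted (the level and
# format are illustrative choices, not part of the original app).
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")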

# Initialize conversation history
if 'history' not in st.session_state:
    st.session_state.history = []

# OpenAI API key (ensure it is securely stored)
# You can set the API key in your environment variables or a .env file
openai_api_key = os.getenv("OPENAI_API_KEY")

# Check if the API key is set
if not openai_api_key:
    st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
    st.stop()
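
# For local runs the key can be exported in the shell before launching the app, e.g.:
#   export OPENAI_API_KEY="sk-..."   # illustrative placeholder value
# A .env file plus python-dotenv would also work, but dotenv loading is not wired up here.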

# Step 1: Upload CSV data file (or use default)
st.title("Natural Language to SQL Query App with Enhanced Insights")
st.write("Upload a CSV file to get started, or use the default dataset.")

csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = pd.read_csv("default_data.csv")  # Ensure this file exists in your working directory
    st.write("Using default_data.csv file.")
    table_name = "default_table"
else:
    data = pd.read_csv(csv_file)
    table_name = csv_file.name.split('.')[0]
    st.write(f"Data Preview ({csv_file.name}):")
    st.dataframe(data.head())
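
# Note (assumption): the table name above is taken straight from the uploaded file name,
# so a name containing spaces or punctuation may not be a valid SQL identifier.
# A minimal sanitization sketch, if that becomes an issue:
#   import re
#   table_name = re.sub(r'\W+', '_', table_name)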

# Step 2: Load CSV data into a persistent SQLite database
db_file = 'my_database.db'
conn = sqlite3.connect(db_file)
data.to_sql(table_name, conn, index=False, if_exists='replace')

# SQL table metadata (for validation and schema)
valid_columns = list(data.columns)
st.write(f"Valid columns: {valid_columns}")
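
# Optional sanity check (assumed addition, not part of the original flow): confirm the
# table was written by counting its rows through the same connection used later.
row_count = pd.read_sql_query(f'SELECT COUNT(*) AS n FROM "{table_name}"', conn).iloc[0]['n']
logging.info(f"Loaded {row_count} rows into SQLite table '{table_name}'")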

# Step 3: Set up the LLM Chains

# SQL Generation Chain
sql_template = """
You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
Ensure that:
- You only use the columns provided.
- When performing string comparisons in the WHERE clause, make them case-insensitive by using 'COLLATE NOCASE' or the LOWER() function.
- Do not use 'COLLATE NOCASE' in ORDER BY clauses unless sorting a string column.
- Do not apply 'COLLATE NOCASE' to numeric columns.
If the question is vague or open-ended and does not pertain to specific data retrieval, respond with "NO_SQL" to indicate that a SQL query should not be generated.

Question: {question}
Table name: {table_name}
Valid columns: {columns}

SQL Query:
"""
sql_prompt = PromptTemplate(template=sql_template, input_variables=['question', 'table_name', 'columns'])

llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
sql_generation_chain = LLMChain(llm=llm, prompt=sql_prompt)

# Insights Generation Chain
insights_template = """
You are an expert data scientist. Based on the user's question and the SQL query result provided below, generate a concise and informative analysis that includes data insights and actionable recommendations.

User's Question: {question}

SQL Query Result:
{result}

Analysis and Recommendations:
"""
insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
insights_chain = LLMChain(llm=llm, prompt=insights_prompt)

# General Insights and Recommendations Chain
general_insights_template = """
You are an expert data scientist. Based on the entire dataset provided below, generate a comprehensive analysis that includes key insights and actionable recommendations.

Dataset Summary:
{dataset_summary}

Analysis and Recommendations:
"""
general_insights_prompt = PromptTemplate(template=general_insights_template, input_variables=['dataset_summary'])
general_insights_chain = LLMChain(llm=llm, prompt=general_insights_prompt)

# Optional: Clean-up function to remove incorrect COLLATE NOCASE usage
def clean_sql_query(query):
    """Removes incorrect usage of COLLATE NOCASE from the SQL query."""
    parsed = sqlparse.parse(query)
    statements = []
    for stmt in parsed:
        tokens = []
        idx = 0
        while idx < len(stmt.tokens):
            token = stmt.tokens[idx]
            if token.ttype is sqlparse.tokens.Keyword and token.value.upper() == 'COLLATE':
                # Check whether the token after the intervening whitespace is 'NOCASE'
                next_token = stmt.tokens[idx + 2] if idx + 2 < len(stmt.tokens) else None
                if next_token and next_token.value.upper() == 'NOCASE':
                    # Skip 'COLLATE', the whitespace, and 'NOCASE'
                    idx += 3
                    continue
            tokens.append(token)
            idx += 1
        statements.append(''.join(str(t) for t in tokens))
    return ' '.join(statements)
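
# Usage note (assumption): process_input() below passes the model's raw SQL through this
# cleaner before execution, e.g. cleaned = clean_sql_query(generated_sql). Because only a
# statement's top-level tokens are scanned, COLLATE NOCASE pairs nested inside grouped
# tokens (such as a parsed WHERE clause) pass through unchanged.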

# Function to classify user query
def classify_query(question):
    """Classify the user query as either 'SQL' or 'INSIGHTS'."""
    classification_template = """
    You are an AI assistant that classifies user queries into two categories: 'SQL' for specific data retrieval queries and 'INSIGHTS' for general analytical or recommendation queries.
    Determine the appropriate category for the following user question.

    Question: "{question}"

    Category (SQL/INSIGHTS):
    """
    classification_prompt = PromptTemplate(template=classification_template, input_variables=['question'])
    classification_chain = LLMChain(llm=llm, prompt=classification_prompt)
    category = classification_chain.run({'question': question}).strip().upper()
    if category.startswith('SQL'):
        return 'SQL'
    else:
        return 'INSIGHTS'
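
# Illustrative calls (assumed examples; the actual label depends on the model's reply):
#   classify_query("What is the average sales amount per region?")    # typically 'SQL'
#   classify_query("What should we focus on to improve next month?")  # typically 'INSIGHTS'
# Any response that does not start with 'SQL' falls back to 'INSIGHTS'.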

# Function to generate dataset summary
def generate_dataset_summary(data):
    """Generate a summary of the dataset for general insights."""
    summary_template = """
    You are an expert data scientist. Based on the dataset provided below, generate a concise summary that includes the number of records, number of columns, data types, and any notable features.

    Dataset:
    {data}

    Dataset Summary:
    """
    summary_prompt = PromptTemplate(template=summary_template, input_variables=['data'])
    summary_chain = LLMChain(llm=llm, prompt=summary_prompt)
    summary = summary_chain.run({'data': data.head().to_string(index=False)})
    return summary
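
# Note: only data.head() (the first five rows) is sent to the model, so the generated
# summary reflects a sample rather than the full dataset. Appending len(data) and
# data.dtypes.to_string() to the prompt input would ground the record count and data
# types explicitly (an optional tweak, not part of the original chain).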

# Define the callback function
def process_input():
    user_prompt = st.session_state['user_input']

    if user_prompt:
        try:
            # Append user message to history
            st.session_state.history.append({"role": "user", "content": user_prompt})

            # Classify the user query
            category = classify_query(user_prompt)
            logging.info(f"User query classified as: {category}")

            if "COLUMNS" in user_prompt.upper():
                assistant_response = f"The columns are: {', '.join(valid_columns)}"
                st.session_state.history.append({"role": "assistant", "content": assistant_response})
            elif category == 'SQL':
                columns = ', '.join(valid_columns)
                generated_sql = sql_generation_chain.run({
                    'question': user_prompt,
                    'table_name': table_name,
                    'columns': columns
                }).strip()

                if generated_sql.upper() == "NO_SQL":
                    # Handle cases where no SQL should be generated
                    assistant_response = "Sure, let's discuss some general insights and recommendations based on the data."
                    st.session_state.history.append({"role": "assistant", "content": assistant_response})

                    # Generate dataset summary
                    dataset_summary = generate_dataset_summary(data)

                    # Generate general insights and recommendations
                    general_insights = general_insights_chain.run({
                        'dataset_summary': dataset_summary
                    })

                    # Append the assistant's insights to the history
                    st.session_state.history.append({"role": "assistant", "content": general_insights})
                else:
                    # Clean the SQL query
                    cleaned_sql = clean_sql_query(generated_sql)
                    logging.info(f"Generated SQL Query: {cleaned_sql}")

                    # Attempt to execute SQL query and handle exceptions
                    try:
                        result = pd.read_sql_query(cleaned_sql, conn)

                        if result.empty:
                            assistant_response = "The query returned no results. Please try a different question."
                            st.session_state.history.append({"role": "assistant", "content": assistant_response})
                        else:
                            # Convert the result to a string for the insights prompt (limit to first 10 rows)
                            result_str = result.head(10).to_string(index=False)

                            # Generate insights and recommendations based on the query result
                            insights = insights_chain.run({
                                'question': user_prompt,
                                'result': result_str
                            })

                            # Append the assistant's insights to the history
                            st.session_state.history.append({"role": "assistant", "content": insights})
                            # Append the result DataFrame to the history
                            st.session_state.history.append({"role": "assistant", "content": result})
                    except Exception as e:
                        logging.error(f"An error occurred during SQL execution: {e}")
                        assistant_response = f"Error executing SQL query: {e}"
                        st.session_state.history.append({"role": "assistant", "content": assistant_response})
            else:  # INSIGHTS category
                # Generate dataset summary
                dataset_summary = generate_dataset_summary(data)

                # Generate general insights and recommendations
                general_insights = general_insights_chain.run({
                    'dataset_summary': dataset_summary
                })

                # Append the assistant's insights to the history
                st.session_state.history.append({"role": "assistant", "content": general_insights})
        except Exception as e:
            logging.error(f"An error occurred: {e}")
            assistant_response = f"Error: {e}"
            st.session_state.history.append({"role": "assistant", "content": assistant_response})

        # Reset the user_input in session state
        st.session_state['user_input'] = ''

# Display the conversation history
for message in st.session_state.history:
    if message['role'] == 'user':
        st.markdown(f"**User:** {message['content']}")
    elif message['role'] == 'assistant':
        if isinstance(message['content'], pd.DataFrame):
            st.markdown("**Assistant:** Query Results:")
            st.dataframe(message['content'])
        else:
            st.markdown(f"**Assistant:** {message['content']}")

# Place the input field at the bottom with the callback
st.text_input("Enter your message:", key='user_input', on_change=process_input)
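
# To run locally (assuming this script is saved as app.py and streamlit, pandas,
# langchain, openai, and sqlparse are installed):
#   streamlit run app.py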