Gonalb's picture
init commit
05e3517
raw
history blame
14.5 kB
import chainlit as cl
import pandas as pd
import time
from typing import Dict, Any
from agents.table_selection import table_selection_agent
from agents.data_retrieval import sample_data_retrieval_agent
from agents.sql_generation import sql_generation_agent
from agents.validation import query_validation_and_optimization
from agents.execution import execution_agent
from utils.bigquery_utils import init_bigquery_connection
from utils.feedback_utils import save_feedback_to_bigquery
@cl.on_chat_start
async def on_chat_start():
    """Prepare a new chat session and greet the user.

    Creates the BigQuery client, keeps it in the user session under the
    key ``"client"``, then sends three messages: a welcome line, an
    introduction to the examples, and the bulleted example questions.
    """
    # The client is created once per session and reused by later handlers.
    cl.user_session.set("client", init_bigquery_connection())

    sample_questions = (
        "What are the top 5 products by revenue?",
        "How many orders were placed in the last month?",
        "Which customers spent the most in 2023?",
        "What is the average order value by product category?",
    )
    # All examples go into one message, one bullet per question.
    bullet_block = "\n\n".join(f"• {question}" for question in sample_questions)
    bullet_block = bullet_block + "\n\n(You can copy and paste any of these examples to try them out)"

    greetings = (
        "👋 Welcome to the Natural Language to SQL Query Assistant! Ask me any question about your e-commerce data.",
        "Here are some example questions you can ask:",
        bullet_block,
    )
    for text in greetings:
        await cl.Message(content=text, author="SQL Assistant").send()
async def _say(content):
    """Send ``content`` to the chat as a message authored by "SQL Assistant"."""
    await cl.Message(content=content, author="SQL Assistant").send()


async def _set_status(status_msg, text):
    """Replace the text of the progress message and push the change to the UI."""
    status_msg.content = text
    await status_msg.update()


async def _run_agent(agent, state):
    """Run one synchronous pipeline agent off the event loop and return its result.

    The agents are plain (blocking) functions; ``cl.make_async`` runs them in
    a worker thread so the handler does not stall the event loop while they
    work.
    """
    return await cl.make_async(agent)(state)


def _save_feedback(feedback):
    """Persist ``feedback`` together with the query/SQL kept in the user session.

    Returns whatever ``save_feedback_to_bigquery`` returns; callers treat a
    truthy value as "saved".
    """
    return save_feedback_to_bigquery(
        cl.user_session.get("client"),
        cl.user_session.get("original_query"),
        cl.user_session.get("generated_sql"),
        cl.user_session.get("optimized_sql"),
        feedback,
    )


async def _thank_for_feedback(saved, saved_text):
    """Acknowledge feedback: ``saved_text`` on success, a fixed notice otherwise."""
    if saved:
        await _say(saved_text)
    else:
        await _say("Thanks for your feedback! (Note: There was an issue saving it to the database)")


async def _record_detailed_feedback(details):
    """Store the free-text follow-up to a negative rating and thank the user."""
    # Leave feedback mode *before* saving: if the save raises, the next
    # message must be treated as a new question, not as more feedback.
    cl.user_session.set("awaiting_feedback", False)
    saved = _save_feedback(f"negative: {details}")
    await _thank_for_feedback(
        saved,
        "Thanks for your detailed feedback! I've saved it to improve future responses.",
    )


def _results_to_dataframe(rows, client, sql):
    """Build a DataFrame from the non-empty execution result ``rows``.

    When the first row is a tuple, column names are taken from the schema of
    ``sql``'s result; if that lookup or the DataFrame construction fails,
    generic ``Column_<i>`` names are used instead. Any other row type is
    handed to pandas unchanged.
    """
    if not isinstance(rows[0], tuple):
        return pd.DataFrame(rows)
    try:
        # NOTE(review): this submits the query a second time only to read
        # its schema; consider having the execution step return the schema.
        schema = client.query(sql).result().schema
        return pd.DataFrame(rows, columns=[field.name for field in schema])
    except Exception:
        # Deliberate best-effort fallback: positional column names.
        fallback_names = [f"Column_{i}" for i in range(len(rows[0]))]
        return pd.DataFrame(rows, columns=fallback_names)


async def _ask_for_feedback(df):
    """Show the result size, ask for a thumbs up/down and act on the answer."""
    res = await cl.AskActionMessage(
        content=f"The query returned {len(df)} rows and {len(df.columns)} columns.\n\nWas this result helpful?",
        actions=[
            cl.Action(name="feedback", payload={"value": "positive"}, label="👍 Good results"),
            cl.Action(name="feedback", payload={"value": "negative"}, label="👎 Not what I wanted"),
        ],
    ).send()
    if not res:
        # No answer was returned; nothing to record.
        return

    feedback_value = res.get("payload", {}).get("value")
    if feedback_value == "positive":
        await _thank_for_feedback(
            _save_feedback("positive"),
            "Thanks for your positive feedback! I've saved it to improve future responses.",
        )
    elif feedback_value == "negative":
        await _say("I'm sorry the results weren't what you expected. Please type your feedback about what was wrong.")
        # The user's next message is routed to _record_detailed_feedback.
        cl.user_session.set("awaiting_feedback", True)
        # Record the bare rating now in case no details ever arrive.
        _save_feedback("negative")


async def _present_results(execution_result, client, optimized_sql):
    """Report the execution outcome: error, empty result, or a result table."""
    if isinstance(execution_result, dict) and "error" in execution_result:
        error_msg = execution_result.get("error", "Unknown error occurred")
        await _say(f"❌ Error executing query: {error_msg}")
        return
    if not execution_result:
        await _say("✅ Query executed successfully but returned no results.")
        return

    try:
        # The schema lookup inside is blocking, so it runs off the event loop.
        df = await cl.make_async(_results_to_dataframe)(execution_result, client, optimized_sql)
        await _say("✅ Query executed successfully! Here are the results:")
        await cl.Message(content="", elements=[cl.Dataframe(data=df)], author="SQL Assistant").send()
        await _ask_for_feedback(df)
    except Exception as e:
        await _say(f"❌ Error formatting results: {str(e)}")


@cl.on_message
async def on_message(message: cl.Message):
    """Handle one user message.

    If the session is waiting for detailed feedback (after a thumbs-down),
    the message text is stored as that feedback. Otherwise the text is
    treated as a natural-language question and run through the pipeline:
    table selection -> sample data -> SQL generation -> validation and
    optimization -> execution, with progress shown in a status message and
    intermediate artifacts (tables, SQL) posted to the chat.

    Args:
        message: The incoming Chainlit message; only ``content`` is used.
    """
    query = message.content

    if cl.user_session.get("awaiting_feedback", False):
        await _record_detailed_feedback(query)
        return

    client = cl.user_session.get("client")
    # Kept in the session so later feedback can be tied to this question.
    cl.user_session.set("original_query", query)

    thinking_msg = await cl.Message(content="🤔 Thinking...", author="SQL Assistant").send()
    try:
        # Step 1: analyze relevant tables
        await _set_status(thinking_msg, "🔍 Analyzing relevant tables...")
        state = {"sql_query": query, "client": client}
        tables_state = await _run_agent(table_selection_agent, state)
        relevant_tables = tables_state.get("relevant_tables", [])
        # Short pause between steps for a smoother UX.
        await cl.sleep(1)
        if relevant_tables:
            tables_text = "I've identified these relevant tables for your query:\n\n"
            tables_text += "\n".join([f"- `{table}`" for table in relevant_tables])
            await _say(tables_text)

        # Step 2: retrieve sample data
        await _set_status(thinking_msg, "📊 Retrieving sample data...")
        await cl.sleep(1)
        state.update(tables_state)
        sample_data_state = await _run_agent(sample_data_retrieval_agent, state)

        # Step 3: generate SQL
        await _set_status(thinking_msg, "💻 Generating SQL query...")
        await cl.sleep(1)
        state.update(sample_data_state)
        sql_state = await _run_agent(sql_generation_agent, state)
        generated_sql = sql_state.get("generated_sql", "No SQL generated")
        cl.user_session.set("generated_sql", generated_sql)
        await _say(f"Here's the SQL query I generated:\n\n```sql\n{generated_sql}\n```")

        # Step 4: validate and optimize
        await _set_status(thinking_msg, "🔧 Optimizing the query...")
        await cl.sleep(1)
        state.update(sql_state)
        optimization_state = await _run_agent(query_validation_and_optimization, state)
        optimized_sql = optimization_state.get("optimized_sql", "No optimized SQL")
        cl.user_session.set("optimized_sql", optimized_sql)
        await _say(f"Here's the optimized version of the query:\n\n```sql\n{optimized_sql}\n```")

        # Step 5: execute
        await _set_status(thinking_msg, "⚙️ Executing query...")
        await cl.sleep(1)
        state.update(optimization_state)
        execution_state = await _run_agent(execution_agent, state)
        execution_result = execution_state.get("execution_result", {})

        await _present_results(execution_result, client, optimized_sql)
    except Exception as e:
        # Top-level boundary: surface the failure in both the status line
        # and a regular chat message instead of letting the handler die.
        await _set_status(thinking_msg, f"❌ Error: {str(e)}")
        await _say(f"I encountered an error while processing your query: {str(e)}")
# Callback handlers for actions
@cl.action_callback("feedback")
async def on_feedback_action(action):
    """React to a click on an action named ``"feedback"``.

    Only a payload value of ``"positive"`` is acted on: the rating is saved
    together with the question and SQL held in the user session, and the
    user is told whether the save succeeded. Any other value is ignored.
    """
    verdict = action.payload.get("value")

    session = cl.user_session
    bq_client = session.get("client")
    question = session.get("original_query")
    first_sql = session.get("generated_sql")
    tuned_sql = session.get("optimized_sql")

    if verdict != "positive":
        return

    saved = save_feedback_to_bigquery(bq_client, question, first_sql, tuned_sql, "positive")
    if saved:
        reply = "Thanks for your positive feedback! I've saved it to improve future responses."
    else:
        reply = "Thanks for your feedback! (Note: There was an issue saving it to the database)"
    await cl.Message(content=reply, author="SQL Assistant").send()
@cl.action_callback("feedback_bad")
async def on_feedback_bad(action):
    """Ask for details after a ``"feedback_bad"`` action and store the answer.

    The user is prompted for free text (up to 300 seconds). Whatever they
    type is saved as ``"negative: <text>"``; if no reply arrives, a plain
    ``"negative"`` rating is saved instead. The user is then told whether
    the save succeeded.
    """
    # AskUserMessage collects a plain text reply by itself; it accepts no
    # form elements, so the prompt is sent without any.
    res = await cl.AskUserMessage(
        content="I'm sorry the results weren't what you expected. Could you please provide more details about what was wrong?",
        author="SQL Assistant",
        timeout=300,
    ).send()

    feedback_details = "negative"
    if res:
        # The typed text is under "output" in current Chainlit releases;
        # "content" is checked as a fallback for older ones.
        typed_text = res.get("output") or res.get("content")
        if typed_text:
            feedback_details = f"negative: {typed_text}"

    client = cl.user_session.get("client")
    original_query = cl.user_session.get("original_query")
    generated_sql = cl.user_session.get("generated_sql")
    optimized_sql = cl.user_session.get("optimized_sql")

    # Save the feedback to BigQuery
    success = save_feedback_to_bigquery(
        client,
        original_query,
        generated_sql,
        optimized_sql,
        feedback_details,
    )
    if success:
        await cl.Message(content="Thanks for your detailed feedback! I've saved it to improve future responses.", author="SQL Assistant").send()
    else:
        await cl.Message(content="Thanks for your feedback! (Note: There was an issue saving it to the database)", author="SQL Assistant").send()
# Running this file directly does nothing: the guard body is only `pass`.
# The handlers above are registered through the `cl` decorators.
if __name__ == "__main__":
    # Note: Chainlit uses its own CLI command to run the app
    # You'll run this with: chainlit run new_app.py -w
    pass