Ari committed on
Commit
1d00adc
·
verified ·
1 Parent(s): b21f6bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -30
app.py CHANGED
@@ -20,7 +20,7 @@ if not openai_api_key:
20
  st.stop()
21
 
22
  # Step 1: Upload CSV data file (or use default)
23
- st.title("Natural Language to SQL Query App with Data Insights")
24
  st.write("Upload a CSV file to get started, or use the default dataset.")
25
 
26
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
@@ -55,6 +55,8 @@ Ensure that:
55
  - Do not use 'COLLATE NOCASE' in ORDER BY clauses unless sorting a string column.
56
  - Do not apply 'COLLATE NOCASE' to numeric columns.
57
 
 
 
58
  Question: {question}
59
 
60
  Table name: {table_name}
@@ -67,7 +69,7 @@ sql_prompt = PromptTemplate(template=sql_template, input_variables=['question',
67
  llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
68
  sql_generation_chain = LLMChain(llm=llm, prompt=sql_prompt)
69
 
70
- # AnswerScript for generating insights based on query results
71
  insights_template = """
72
  You are an expert data scientist. Based on the user's question and the SQL query result provided below, generate a concise and informative analysis that includes data insights and actionable recommendations.
73
 
@@ -81,6 +83,18 @@ Analysis and Recommendations:
81
  insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
82
  insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
83
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  # Optional: Clean up function to remove incorrect COLLATE NOCASE usage
85
  def clean_sql_query(query):
86
  """Removes incorrect usage of COLLATE NOCASE from the SQL query."""
@@ -103,6 +117,42 @@ def clean_sql_query(query):
103
  statements.append(''.join([str(t) for t in tokens]))
104
  return ' '.join(statements)
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  # Define the callback function
107
  def process_input():
108
  user_prompt = st.session_state['user_input']
@@ -112,46 +162,77 @@ def process_input():
112
  # Append user message to history
113
  st.session_state.history.append({"role": "user", "content": user_prompt})
114
 
115
- if "columns" in user_prompt.lower():
 
 
 
 
116
  assistant_response = f"The columns are: {', '.join(valid_columns)}"
117
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
118
- else:
119
  columns = ', '.join(valid_columns)
120
  generated_sql = sql_generation_chain.run({
121
  'question': user_prompt,
122
  'table_name': table_name,
123
  'columns': columns
124
- })
125
 
126
- # Clean the SQL query
127
- generated_sql = clean_sql_query(generated_sql)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
- # Attempt to execute SQL query and handle exceptions
130
- try:
131
- result = pd.read_sql_query(generated_sql, conn)
132
 
133
- if result.empty:
134
- assistant_response = "The query returned no results. Please try a different question."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
136
- else:
137
- # Convert the result to a string for the insights prompt
138
- result_str = result.head(10).to_string(index=False) # Limit to first 10 rows
139
-
140
- # Generate insights and recommendations
141
- insights = insights_chain.run({
142
- 'question': user_prompt,
143
- 'result': result_str
144
- })
145
-
146
- # Append the assistant's insights to the history
147
- st.session_state.history.append({"role": "assistant", "content": insights})
148
- # Append the result DataFrame to the history
149
- st.session_state.history.append({"role": "assistant", "content": result})
150
- except Exception as e:
151
- logging.error(f"An error occurred during SQL execution: {e}")
152
- assistant_response = f"Error executing SQL query: {e}"
153
- st.session_state.history.append({"role": "assistant", "content": assistant_response})
154
 
 
 
 
155
  except Exception as e:
156
  logging.error(f"An error occurred: {e}")
157
  assistant_response = f"Error: {e}"
 
20
  st.stop()
21
 
22
  # Step 1: Upload CSV data file (or use default)
23
+ st.title("Natural Language to SQL Query App with Enhanced Insights")
24
  st.write("Upload a CSV file to get started, or use the default dataset.")
25
 
26
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
 
55
  - Do not use 'COLLATE NOCASE' in ORDER BY clauses unless sorting a string column.
56
  - Do not apply 'COLLATE NOCASE' to numeric columns.
57
 
58
+ If the question is vague or open-ended and does not pertain to specific data retrieval, respond with "NO_SQL" to indicate that a SQL query should not be generated.
59
+
60
  Question: {question}
61
 
62
  Table name: {table_name}
 
69
  llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
70
  sql_generation_chain = LLMChain(llm=llm, prompt=sql_prompt)
71
 
72
+ # Insights Generation Chain
73
  insights_template = """
74
  You are an expert data scientist. Based on the user's question and the SQL query result provided below, generate a concise and informative analysis that includes data insights and actionable recommendations.
75
 
 
83
  insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
84
  insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
85
 
86
+ # General Insights and Recommendations Chain
87
+ general_insights_template = """
88
+ You are an expert data scientist. Based on the entire dataset provided below, generate a comprehensive analysis that includes key insights and actionable recommendations.
89
+
90
+ Dataset Summary:
91
+ {dataset_summary}
92
+
93
+ Analysis and Recommendations:
94
+ """
95
+ general_insights_prompt = PromptTemplate(template=general_insights_template, input_variables=['dataset_summary'])
96
+ general_insights_chain = LLMChain(llm=llm, prompt=general_insights_prompt)
97
+
98
  # Optional: Clean up function to remove incorrect COLLATE NOCASE usage
99
  def clean_sql_query(query):
100
  """Removes incorrect usage of COLLATE NOCASE from the SQL query."""
 
117
  statements.append(''.join([str(t) for t in tokens]))
118
  return ' '.join(statements)
119
 
120
+ # Function to classify user query
121
+ def classify_query(question):
122
+ """Classify the user query as either 'SQL' or 'INSIGHTS'."""
123
+ classification_template = """
124
+ You are an AI assistant that classifies user queries into two categories: 'SQL' for specific data retrieval queries and 'INSIGHTS' for general analytical or recommendation queries.
125
+
126
+ Determine the appropriate category for the following user question.
127
+
128
+ Question: "{question}"
129
+
130
+ Category (SQL/INSIGHTS):
131
+ """
132
+ classification_prompt = PromptTemplate(template=classification_template, input_variables=['question'])
133
+ classification_chain = LLMChain(llm=llm, prompt=classification_prompt)
134
+ category = classification_chain.run({'question': question}).strip().upper()
135
+ if category.startswith('SQL'):
136
+ return 'SQL'
137
+ else:
138
+ return 'INSIGHTS'
139
+
140
+ # Function to generate dataset summary
141
+ def generate_dataset_summary(data):
142
+ """Generate a summary of the dataset for general insights."""
143
+ summary_template = """
144
+ You are an expert data scientist. Based on the dataset provided below, generate a concise summary that includes the number of records, number of columns, data types, and any notable features.
145
+
146
+ Dataset:
147
+ {data}
148
+
149
+ Dataset Summary:
150
+ """
151
+ summary_prompt = PromptTemplate(template=summary_template, input_variables=['data'])
152
+ summary_chain = LLMChain(llm=llm, prompt=summary_prompt)
153
+ summary = summary_chain.run({'data': data.head().to_string(index=False)})
154
+ return summary
155
+
156
  # Define the callback function
157
  def process_input():
158
  user_prompt = st.session_state['user_input']
 
162
  # Append user message to history
163
  st.session_state.history.append({"role": "user", "content": user_prompt})
164
 
165
+ # Classify the user query
166
+ category = classify_query(user_prompt)
167
+ logging.info(f"User query classified as: {category}")
168
+
169
+ if "COLUMNS" in user_prompt.upper():
170
  assistant_response = f"The columns are: {', '.join(valid_columns)}"
171
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
172
+ elif category == 'SQL':
173
  columns = ', '.join(valid_columns)
174
  generated_sql = sql_generation_chain.run({
175
  'question': user_prompt,
176
  'table_name': table_name,
177
  'columns': columns
178
+ }).strip()
179
 
180
+ if generated_sql.upper() == "NO_SQL":
181
+ # Handle cases where no SQL should be generated
182
+ assistant_response = "Sure, let's discuss some general insights and recommendations based on the data."
183
+
184
+ # Generate dataset summary
185
+ dataset_summary = generate_dataset_summary(data)
186
+
187
+ # Generate general insights and recommendations
188
+ general_insights = general_insights_chain.run({
189
+ 'dataset_summary': dataset_summary
190
+ })
191
+
192
+ # Append the assistant's insights to the history
193
+ st.session_state.history.append({"role": "assistant", "content": general_insights})
194
+ else:
195
+ # Clean the SQL query
196
+ cleaned_sql = clean_sql_query(generated_sql)
197
+ logging.info(f"Generated SQL Query: {cleaned_sql}")
198
 
199
+ # Attempt to execute SQL query and handle exceptions
200
+ try:
201
+ result = pd.read_sql_query(cleaned_sql, conn)
202
 
203
+ if result.empty:
204
+ assistant_response = "The query returned no results. Please try a different question."
205
+ st.session_state.history.append({"role": "assistant", "content": assistant_response})
206
+ else:
207
+ # Convert the result to a string for the insights prompt
208
+ result_str = result.head(10).to_string(index=False) # Limit to first 10 rows
209
+
210
+ # Generate insights and recommendations based on the query result
211
+ insights = insights_chain.run({
212
+ 'question': user_prompt,
213
+ 'result': result_str
214
+ })
215
+
216
+ # Append the assistant's insights to the history
217
+ st.session_state.history.append({"role": "assistant", "content": insights})
218
+ # Append the result DataFrame to the history
219
+ st.session_state.history.append({"role": "assistant", "content": result})
220
+ except Exception as e:
221
+ logging.error(f"An error occurred during SQL execution: {e}")
222
+ assistant_response = f"Error executing SQL query: {e}"
223
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
224
+ else: # INSIGHTS category
225
+ # Generate dataset summary
226
+ dataset_summary = generate_dataset_summary(data)
227
+
228
+ # Generate general insights and recommendations
229
+ general_insights = general_insights_chain.run({
230
+ 'dataset_summary': dataset_summary
231
+ })
 
 
 
 
 
 
 
 
 
 
232
 
233
+ # Append the assistant's insights to the history
234
+ st.session_state.history.append({"role": "assistant", "content": general_insights})
235
+
236
  except Exception as e:
237
  logging.error(f"An error occurred: {e}")
238
  assistant_response = f"Error: {e}"