arithescientist commited on
Commit
1746d1f
·
verified ·
1 Parent(s): e84e3a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -60
app.py CHANGED
@@ -2,45 +2,40 @@ import os
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
- from langchain import OpenAI, LLMChain, PromptTemplate
6
- from transformers import LlamaForCausalLM, LlamaTokenizer
7
- import torch
8
  import sqlparse
9
  import logging
10
 
 
 
 
11
 
12
  # Initialize conversation history
13
  if 'history' not in st.session_state:
14
  st.session_state.history = []
15
 
16
- # OpenAI API key (ensure it is securely stored)
17
- openai_api_key = os.getenv("OPENAI_API_KEY")
18
-
19
- # Check if the API key is set
20
- if not openai_api_key:
21
- st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
22
- st.stop()
23
-
24
- # Load the LLaMA model and tokenizer
25
- model_name = "meta-llama/Llama-2-7b-hf" # Adjust to the LLaMA model you want
26
- device = "cuda" if torch.cuda.is_available() else "cpu"
27
-
28
- try:
29
- llama_tokenizer = LlamaTokenizer.from_pretrained(model_name)
30
- llama_model = LlamaForCausalLM.from_pretrained(model_name).to(device)
31
- except Exception as e:
32
- st.error(f"Error loading LLaMA model: {e}")
33
- llama_tokenizer = None
34
- llama_model = None
35
-
36
- # Function to generate responses using LLaMA
37
- def generate_llama_response(prompt):
38
- if llama_tokenizer and llama_model:
39
- inputs = llama_tokenizer(prompt, return_tensors="pt").to(device)
40
- outputs = llama_model.generate(inputs.input_ids, max_length=200)
41
- return llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
42
- else:
43
- return "LLaMA model is not available."
44
 
45
  # Step 1: Upload CSV data file (or use default)
46
  st.title("Natural Language to SQL Query App with Enhanced Insights")
@@ -66,10 +61,14 @@ data.to_sql(table_name, conn, index=False, if_exists='replace')
66
  valid_columns = list(data.columns)
67
  st.write(f"Valid columns: {valid_columns}")
68
 
69
- # Step 3: Set up the LLM Chains (SQL generation with OpenAI, insights with LLaMA)
70
- # SQL Generation Chain with OpenAI
71
  sql_template = """
72
- You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
 
 
 
 
73
 
74
  Ensure that:
75
 
@@ -87,24 +86,46 @@ Table name: {table_name}
87
  Valid columns: {columns}
88
 
89
  SQL Query:
 
90
  """
91
  sql_prompt = PromptTemplate(template=sql_template, input_variables=['question', 'table_name', 'columns'])
92
- sql_llm = OpenAI(temperature=0, openai_api_key=openai_api_key, max_tokens=180)
93
- sql_generation_chain = LLMChain(llm=sql_llm, prompt=sql_prompt)
 
 
 
 
 
94
 
95
- # General Insights and Recommendations Chain with LLaMA
96
- def generate_insights_llama(question, data_summary):
97
- insights_template = f"""
98
- You are an expert data scientist. Based on the user's question and the dataset summary provided below, generate concise data insights and actionable recommendations.
99
 
100
- User's Question: {question}
 
 
 
 
 
 
 
 
 
101
 
102
- Dataset Summary:
103
- {data_summary}
 
 
 
104
 
105
- Concise Insights and Recommendations:
106
- """
107
- return generate_llama_response(insights_template)
 
 
 
 
 
 
 
108
 
109
  # Optional: Clean up function to remove incorrect COLLATE NOCASE usage
110
  def clean_sql_query(query):
@@ -132,16 +153,19 @@ def clean_sql_query(query):
132
  def classify_query(question):
133
  """Classify the user query as either 'SQL' or 'INSIGHTS'."""
134
  classification_template = """
135
- You are an AI assistant that classifies user queries into two categories: 'SQL' for specific data retrieval queries and 'INSIGHTS' for general analytical or recommendation queries.
 
 
136
 
137
- Determine the appropriate category for the following user question.
138
 
139
- Question: "{question}"
140
 
141
- Category (SQL/INSIGHTS):
142
- """
 
143
  classification_prompt = PromptTemplate(template=classification_template, input_variables=['question'])
144
- classification_chain = LLMChain(llm=sql_llm, prompt=classification_prompt)
145
  category = classification_chain.run({'question': question}).strip().upper()
146
  if category.startswith('SQL'):
147
  return 'SQL'
@@ -151,7 +175,22 @@ def classify_query(question):
151
  # Function to generate dataset summary
152
  def generate_dataset_summary(data):
153
  """Generate a summary of the dataset for general insights."""
154
- summary = f"Number of records: {len(data)}, Number of columns: {len(data.columns)}, Columns: {list(data.columns)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  return summary
156
 
157
  # Define the callback function
@@ -179,9 +218,21 @@ def process_input():
179
  }).strip()
180
 
181
  if generated_sql.upper() == "NO_SQL":
182
- assistant_response = "No SQL query could be generated."
183
- st.session_state.history.append({"role": "assistant", "content": assistant_response})
 
 
 
 
 
 
 
 
 
 
 
184
  else:
 
185
  cleaned_sql = clean_sql_query(generated_sql)
186
  logging.info(f"Generated SQL Query: {cleaned_sql}")
187
 
@@ -193,17 +244,34 @@ def process_input():
193
  assistant_response = "The query returned no results. Please try a different question."
194
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
195
  else:
196
- # Display query results
 
 
 
 
 
 
 
 
 
 
 
197
  st.session_state.history.append({"role": "assistant", "content": result})
198
-
199
  except Exception as e:
200
  logging.error(f"An error occurred during SQL execution: {e}")
201
  assistant_response = f"Error executing SQL query: {e}"
202
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
203
  else: # INSIGHTS category
 
204
  dataset_summary = generate_dataset_summary(data)
205
- insights = generate_insights_llama(user_prompt, dataset_summary)
206
- st.session_state.history.append({"role": "assistant", "content": insights})
 
 
 
 
 
 
207
 
208
  except Exception as e:
209
  logging.error(f"An error occurred: {e}")
@@ -213,7 +281,6 @@ def process_input():
213
  # Reset the user_input in session state
214
  st.session_state['user_input'] = ''
215
 
216
-
217
  # Display the conversation history
218
  for message in st.session_state.history:
219
  if message['role'] == 'user':
 
2
  import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
+ from langchain import LLMChain, PromptTemplate
 
 
6
  import sqlparse
7
  import logging
8
 
9
+ # Import necessary modules from transformers and langchain
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
11
+ from langchain.llms import HuggingFacePipeline
12
 
13
  # Initialize conversation history
14
  if 'history' not in st.session_state:
15
  st.session_state.history = []
16
 
17
+ # Set up the Llama-2-7b-chat-hf model
18
+ model_id = "meta-llama/Llama-2-7b-chat-hf"
19
+
20
+ # Load the tokenizer and model
21
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
22
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto', torch_dtype='auto') # Adjust device_map and torch_dtype as needed
23
+
24
+ # Create the text-generation pipeline with appropriate parameters
25
+ pipe = pipeline(
26
+ "text-generation",
27
+ model=model,
28
+ tokenizer=tokenizer,
29
+ max_new_tokens=512,
30
+ temperature=0.1,
31
+ repetition_penalty=1.1,
32
+ do_sample=True, # Use sampling to introduce some randomness
33
+ eos_token_id=tokenizer.eos_token_id,
34
+ pad_token_id=tokenizer.eos_token_id
35
+ )
36
+
37
+ # Wrap the pipeline with HuggingFacePipeline for use in LangChain
38
+ llm = HuggingFacePipeline(pipeline=pipe)
 
 
 
 
 
 
39
 
40
  # Step 1: Upload CSV data file (or use default)
41
  st.title("Natural Language to SQL Query App with Enhanced Insights")
 
61
  valid_columns = list(data.columns)
62
  st.write(f"Valid columns: {valid_columns}")
63
 
64
+ # Step 3: Set up the LLM Chains with adjusted prompts
65
+ # SQL Generation Chain
66
  sql_template = """
67
+ [INST] <<SYS>>
68
+ You are an expert data scientist.
69
+ <</SYS>>
70
+
71
+ Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
72
 
73
  Ensure that:
74
 
 
86
  Valid columns: {columns}
87
 
88
  SQL Query:
89
+ [/INST]
90
  """
91
  sql_prompt = PromptTemplate(template=sql_template, input_variables=['question', 'table_name', 'columns'])
92
+ sql_generation_chain = LLMChain(llm=llm, prompt=sql_prompt)
93
+
94
+ # Insights Generation Chain
95
+ insights_template = """
96
+ [INST] <<SYS>>
97
+ You are an expert data scientist.
98
+ <</SYS>>
99
 
100
+ Based on the user's question and the SQL query result provided below, generate a concise analysis that includes key data insights and actionable recommendations. Limit the response to a maximum of 150 words.
 
 
 
101
 
102
+ User's Question: {question}
103
+
104
+ SQL Query Result:
105
+ {result}
106
+
107
+ Concise Analysis (max 200 words):
108
+ [/INST]
109
+ """
110
+ insights_prompt = PromptTemplate(template=insights_template, input_variables=['question', 'result'])
111
+ insights_chain = LLMChain(llm=llm, prompt=insights_prompt)
112
 
113
+ # General Insights and Recommendations Chain
114
+ general_insights_template = """
115
+ [INST] <<SYS>>
116
+ You are an expert data scientist.
117
+ <</SYS>>
118
 
119
+ Based on the entire dataset provided below, generate a concise analysis with key insights and recommendations. Limit the response to 150 words.
120
+
121
+ Dataset Summary:
122
+ {dataset_summary}
123
+
124
+ Concise Analysis and Recommendations (max 150 words):
125
+ [/INST]
126
+ """
127
+ general_insights_prompt = PromptTemplate(template=general_insights_template, input_variables=['dataset_summary'])
128
+ general_insights_chain = LLMChain(llm=llm, prompt=general_insights_prompt)
129
 
130
  # Optional: Clean up function to remove incorrect COLLATE NOCASE usage
131
  def clean_sql_query(query):
 
153
  def classify_query(question):
154
  """Classify the user query as either 'SQL' or 'INSIGHTS'."""
155
  classification_template = """
156
+ [INST] <<SYS>>
157
+ You are an AI assistant that classifies user queries into two categories: 'SQL' for specific data retrieval queries and 'INSIGHTS' for general analytical or recommendation queries.
158
+ <</SYS>>
159
 
160
+ Determine the appropriate category for the following user question.
161
 
162
+ Question: "{question}"
163
 
164
+ Category (SQL/INSIGHTS):
165
+ [/INST]
166
+ """
167
  classification_prompt = PromptTemplate(template=classification_template, input_variables=['question'])
168
+ classification_chain = LLMChain(llm=llm, prompt=classification_prompt)
169
  category = classification_chain.run({'question': question}).strip().upper()
170
  if category.startswith('SQL'):
171
  return 'SQL'
 
175
  # Function to generate dataset summary
176
  def generate_dataset_summary(data):
177
  """Generate a summary of the dataset for general insights."""
178
+ summary_template = """
179
+ [INST] <<SYS>>
180
+ You are an expert data scientist.
181
+ <</SYS>>
182
+
183
+ Based on the dataset provided below, generate a concise summary that includes the number of records, number of columns, data types, and any notable features.
184
+
185
+ Dataset:
186
+ {data}
187
+
188
+ Dataset Summary:
189
+ [/INST]
190
+ """
191
+ summary_prompt = PromptTemplate(template=summary_template, input_variables=['data'])
192
+ summary_chain = LLMChain(llm=llm, prompt=summary_prompt)
193
+ summary = summary_chain.run({'data': data.head().to_string(index=False)})
194
  return summary
195
 
196
  # Define the callback function
 
218
  }).strip()
219
 
220
  if generated_sql.upper() == "NO_SQL":
221
+ # Handle cases where no SQL should be generated
222
+ assistant_response = "Sure, let's discuss some general insights and recommendations based on the data."
223
+
224
+ # Generate dataset summary
225
+ dataset_summary = generate_dataset_summary(data)
226
+
227
+ # Generate general insights and recommendations
228
+ general_insights = general_insights_chain.run({
229
+ 'dataset_summary': dataset_summary
230
+ })
231
+
232
+ # Append the assistant's insights to the history
233
+ st.session_state.history.append({"role": "assistant", "content": general_insights})
234
  else:
235
+ # Clean the SQL query
236
  cleaned_sql = clean_sql_query(generated_sql)
237
  logging.info(f"Generated SQL Query: {cleaned_sql}")
238
 
 
244
  assistant_response = "The query returned no results. Please try a different question."
245
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
246
  else:
247
+ # Convert the result to a string for the insights prompt
248
+ result_str = result.head(10).to_string(index=False) # Limit to first 10 rows
249
+
250
+ # Generate insights and recommendations based on the query result
251
+ insights = insights_chain.run({
252
+ 'question': user_prompt,
253
+ 'result': result_str
254
+ })
255
+
256
+ # Append the assistant's insights to the history
257
+ st.session_state.history.append({"role": "assistant", "content": insights})
258
+ # Append the result DataFrame to the history
259
  st.session_state.history.append({"role": "assistant", "content": result})
 
260
  except Exception as e:
261
  logging.error(f"An error occurred during SQL execution: {e}")
262
  assistant_response = f"Error executing SQL query: {e}"
263
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
264
  else: # INSIGHTS category
265
+ # Generate dataset summary
266
  dataset_summary = generate_dataset_summary(data)
267
+
268
+ # Generate general insights and recommendations
269
+ general_insights = general_insights_chain.run({
270
+ 'dataset_summary': dataset_summary
271
+ })
272
+
273
+ # Append the assistant's insights to the history
274
+ st.session_state.history.append({"role": "assistant", "content": general_insights})
275
 
276
  except Exception as e:
277
  logging.error(f"An error occurred: {e}")
 
281
  # Reset the user_input in session state
282
  st.session_state['user_input'] = ''
283
 
 
284
  # Display the conversation history
285
  for message in st.session_state.history:
286
  if message['role'] == 'user':