Ari committed on
Commit
c263e75
·
verified ·
1 Parent(s): 9ea7f8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -6
app.py CHANGED
@@ -41,11 +41,14 @@ st.write(f"Valid columns: {valid_columns}")
41
 
42
  # Step 3: Set up the LLM Chain to generate SQL queries
43
  template = """
44
- You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
45
 
46
- Ensure that:
47
- - You only use the columns provided.
48
- - String comparisons in the WHERE clause are case-insensitive by using 'COLLATE NOCASE' or the LOWER() function.
 
 
 
49
 
50
  Question: {question}
51
 
@@ -53,8 +56,9 @@ Table name: {table_name}
53
 
54
  Valid columns: {columns}
55
 
56
- SQL Query:
57
  """
 
58
  prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
59
  sql_generation_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
60
 
@@ -86,6 +90,7 @@ def process_input():
86
  # It's a SQL query
87
  st.write(f"Generated SQL Query:\n{code}")
88
  try:
 
89
  result = pd.read_sql_query(code, conn)
90
  assistant_response = f"Generated SQL Query:\n{code}"
91
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
@@ -134,6 +139,7 @@ def process_input():
134
  # Reset the user_input in session state
135
  st.session_state['user_input'] = ''
136
 
 
137
  def extract_code(response):
138
  """Extracts code enclosed between <CODE> and </CODE> tags."""
139
  import re
@@ -144,7 +150,6 @@ def extract_code(response):
144
  else:
145
  return None
146
 
147
- # Display the conversation history
148
  # Display the conversation history
149
  for message in st.session_state.history:
150
  if message['role'] == 'user':
@@ -160,5 +165,6 @@ for message in st.session_state.history:
160
  st.markdown(f"**Assistant:** {content}")
161
 
162
 
 
163
  # Place the input field at the bottom with the callback
164
  st.text_input("Enter your message:", key='user_input', on_change=process_input)
 
41
 
42
  # Step 3: Set up the LLM Chain to generate SQL queries
43
  template = """
44
+ You are an expert data scientist assistant. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query or Python code that answers the question.
45
 
46
+ Instructions:
47
+ - If the question involves data retrieval or simple aggregations, generate a SQL query.
48
+ - Ensure that you only use the columns provided.
49
+ - For case-insensitive string comparisons, use either 'LOWER(column) = LOWER(value)' or 'column = value COLLATE NOCASE', but do not use both together.
50
+ - Do not include any import statements in the code.
51
+ - Provide the code between <CODE> and </CODE> tags.
52
 
53
  Question: {question}
54
 
 
56
 
57
  Valid columns: {columns}
58
 
59
+ Response:
60
  """
61
+
62
  prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
63
  sql_generation_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
64
 
 
90
  # It's a SQL query
91
  st.write(f"Generated SQL Query:\n{code}")
92
  try:
93
+ # Execute the SQL query
94
  result = pd.read_sql_query(code, conn)
95
  assistant_response = f"Generated SQL Query:\n{code}"
96
  st.session_state.history.append({"role": "assistant", "content": assistant_response})
 
139
  # Reset the user_input in session state
140
  st.session_state['user_input'] = ''
141
 
142
+
143
  def extract_code(response):
144
  """Extracts code enclosed between <CODE> and </CODE> tags."""
145
  import re
 
150
  else:
151
  return None
152
 
 
153
  # Display the conversation history
154
  for message in st.session_state.history:
155
  if message['role'] == 'user':
 
165
  st.markdown(f"**Assistant:** {content}")
166
 
167
 
168
+
169
  # Place the input field at the bottom with the callback
170
  st.text_input("Enter your message:", key='user_input', on_change=process_input)