Ari committed on
Commit
5189e45
·
verified ·
1 Parent(s): 0bb1965

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -9
app.py CHANGED
@@ -27,15 +27,16 @@ csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
27
  if csv_file is None:
28
  data = pd.read_csv("default_data.csv") # Ensure this file exists in your working directory
29
  st.write("Using default_data.csv file.")
 
30
  else:
31
  data = pd.read_csv(csv_file)
 
32
  st.write(f"Data Preview ({csv_file.name}):")
33
  st.dataframe(data.head())
34
 
35
  # Step 2: Load CSV data into a persistent SQLite database
36
  db_file = 'my_database.db'
37
  conn = sqlite3.connect(db_file)
38
- table_name = csv_file.name.split('.')[0] if csv_file else "default_table"
39
  data.to_sql(table_name, conn, index=False, if_exists='replace')
40
 
41
  # SQL table metadata (for validation and schema)
@@ -43,7 +44,7 @@ valid_columns = list(data.columns)
43
  st.write(f"Valid columns: {valid_columns}")
44
 
45
  # Step 3: Set up the LLM Chain to generate SQL queries
46
- template = """
47
  You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
48
 
49
  Ensure that:
@@ -61,9 +62,9 @@ Valid columns: {columns}
61
 
62
  SQL Query:
63
  """
64
- prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
65
  llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
66
- sql_generation_chain = LLMChain(llm=llm, prompt=prompt)
67
 
68
  # Optional: Clean up function to remove incorrect COLLATE NOCASE usage
69
  def clean_sql_query(query):
@@ -107,17 +108,34 @@ def process_input():
107
  'columns': columns
108
  })
109
 
110
- # Debug: Display generated SQL query for inspection
111
- st.write(f"Generated SQL Query:\n{generated_sql}")
112
-
113
  # Clean the SQL query
114
  generated_sql = clean_sql_query(generated_sql)
115
 
116
  # Attempt to execute SQL query and handle exceptions
117
  try:
118
  result = pd.read_sql_query(generated_sql, conn)
119
- assistant_response = f"Generated SQL Query:\n{generated_sql}"
120
- st.session_state.history.append({"role": "assistant", "content": assistant_response})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  st.session_state.history.append({"role": "assistant", "content": result})
122
  except Exception as e:
123
  logging.error(f"An error occurred during SQL execution: {e}")
 
27
  if csv_file is None:
28
  data = pd.read_csv("default_data.csv") # Ensure this file exists in your working directory
29
  st.write("Using default_data.csv file.")
30
+ table_name = "default_table"
31
  else:
32
  data = pd.read_csv(csv_file)
33
+ table_name = csv_file.name.split('.')[0]
34
  st.write(f"Data Preview ({csv_file.name}):")
35
  st.dataframe(data.head())
36
 
37
  # Step 2: Load CSV data into a persistent SQLite database
38
  db_file = 'my_database.db'
39
  conn = sqlite3.connect(db_file)
 
40
  data.to_sql(table_name, conn, index=False, if_exists='replace')
41
 
42
  # SQL table metadata (for validation and schema)
 
44
  st.write(f"Valid columns: {valid_columns}")
45
 
46
  # Step 3: Set up the LLM Chain to generate SQL queries
47
+ sql_template = """
48
  You are an expert data scientist. Given a natural language question, the name of the table, and a list of valid columns, generate a valid SQL query that answers the question.
49
 
50
  Ensure that:
 
62
 
63
  SQL Query:
64
  """
65
+ sql_prompt = PromptTemplate(template=sql_template, input_variables=['question', 'table_name', 'columns'])
66
  llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
67
+ sql_generation_chain = LLMChain(llm=llm, prompt=sql_prompt)
68
 
69
  # Optional: Clean up function to remove incorrect COLLATE NOCASE usage
70
  def clean_sql_query(query):
 
108
  'columns': columns
109
  })
110
 
 
 
 
111
  # Clean the SQL query
112
  generated_sql = clean_sql_query(generated_sql)
113
 
114
  # Attempt to execute SQL query and handle exceptions
115
  try:
116
  result = pd.read_sql_query(generated_sql, conn)
117
+
118
+ # Limit the result to first 5 rows for brevity
119
+ result_limited = result.head(5)
120
+ result_str = result_limited.to_string(index=False)
121
+
122
+ # Generate natural language answer
123
+ answer_template = """
124
+ Given the user's question and the SQL query result, provide a concise and informative answer to the question using the data from the query result.
125
+
126
+ User's question: {question}
127
+ Query result:
128
+ {result}
129
+
130
+ Answer:
131
+ """
132
+ answer_prompt = PromptTemplate(template=answer_template, input_variables=['question', 'result'])
133
+ answer_chain = LLMChain(llm=llm, prompt=answer_prompt)
134
+ assistant_answer = answer_chain.run({'question': user_prompt, 'result': result_str})
135
+
136
+ # Append the assistant's answer to the history
137
+ st.session_state.history.append({"role": "assistant", "content": assistant_answer})
138
+ # Append the result DataFrame to the history
139
  st.session_state.history.append({"role": "assistant", "content": result})
140
  except Exception as e:
141
  logging.error(f"An error occurred during SQL execution: {e}")