Shreyas094 committed
Commit ee5661b · verified · 1 Parent(s): 8f325c3

Update app.py

Files changed (1):
  1. app.py +24 -7
app.py CHANGED
@@ -230,19 +230,24 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
     model = get_model(temperature, top_p, repetition_penalty)
     embed = get_embeddings()
 
-    # Check if the FAISS database exists, if not create an empty one
+    # Check if the FAISS database exists
     if os.path.exists("faiss_database"):
         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
     else:
-        database = FAISS.from_documents([], embed)
-        database.save_local("faiss_database")
+        # If no database exists, we'll create it with the first web search or document upload
+        database = None
 
     if web_search:
         search_results = google_search(question)
         web_docs = [Document(page_content=result["text"], metadata={"source": result["link"]}) for result in search_results if result["text"]]
 
-        # Add web search results to the existing database
-        database.add_documents(web_docs)
+        if database is None:
+            # Create the database with web search results if it doesn't exist
+            database = FAISS.from_documents(web_docs, embed)
+        else:
+            # Add web search results to the existing database
+            database.add_documents(web_docs)
+
         database.save_local("faiss_database")
 
         context_str = "\n".join([doc.page_content for doc in web_docs])
@@ -258,6 +263,9 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
         prompt_val = ChatPromptTemplate.from_template(prompt_template)
         formatted_prompt = prompt_val.format(context=context_str, question=question)
     else:
+        if database is None:
+            return "No documents or web search results available. Please upload documents or enable web search."
+
         history_str = "\n".join([f"Q: {item['question']}\nA: {item['answer']}" for item in conversation_history])
 
         if is_related_to_history(question, conversation_history):
@@ -290,15 +298,24 @@ def update_vectors(files, use_recursive_splitter):
     embed = get_embeddings()
     total_chunks = 0
 
+    all_data = []
     for file in files:
         if use_recursive_splitter:
             data = load_and_split_document_recursive(file)
         else:
             data = load_and_split_document_basic(file)
-        create_or_update_database(data, embed)
+        all_data.extend(data)
         total_chunks += len(data)
 
-    return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."
+    if os.path.exists("faiss_database"):
+        database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+        database.add_documents(all_data)
+    else:
+        database = FAISS.from_documents(all_data, embed)
+
+    database.save_local("faiss_database")
+
+    return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."
 
 def extract_db_to_excel():
     embed = get_embeddings()
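Note on the `ask_question` change: the old code bootstrapped an empty index with `FAISS.from_documents([], embed)`, which typically fails in LangChain because the FAISS wrapper infers the vector dimension from the first embedded document, so a zero-document index cannot be created up front. Deferring creation (`database = None`, then `FAISS.from_documents(web_docs, embed)` on the first real batch) sidesteps that. Below is a minimal, self-contained sketch of the same create-or-append pattern; `FakeEmbeddings` and the `__main__` demo are illustrative stand-ins for the app's `get_embeddings()`, not part of the commit.

```python
# Requires: pip install langchain-community faiss-cpu
import os

from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

DB_PATH = "faiss_database"  # same on-disk folder the app uses

def add_or_create_db(docs, embed):
    """Append docs to the saved FAISS index, creating it on first use."""
    if os.path.exists(DB_PATH):
        db = FAISS.load_local(DB_PATH, embed, allow_dangerous_deserialization=True)
        db.add_documents(docs)
    else:
        # Deferred creation: FAISS needs at least one document to infer
        # the embedding dimension, so there is no "empty index" to save.
        db = FAISS.from_documents(docs, embed)
    db.save_local(DB_PATH)
    return db

if __name__ == "__main__":
    embed = FakeEmbeddings(size=64)  # stand-in for the app's real embeddings
    docs = [Document(page_content="example text", metadata={"source": "https://example.com"})]
    db = add_or_create_db(docs, embed)
    print(db.similarity_search("example", k=1)[0].page_content)
```

Running it twice exercises both branches: the first run creates and saves the index, the second loads it and appends.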
 
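The `update_vectors` rewrite accumulates every file's chunks into `all_data` and touches the index once after the loop, instead of calling the per-file `create_or_update_database` helper: one load/create/save cycle rather than one per file. A sketch of that batching shape follows; `load_and_split_stub` is a hypothetical stand-in for the app's `load_and_split_document_recursive`/`_basic` splitters.

```python
import os

from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

DB_PATH = "faiss_database"

def load_and_split_stub(path):
    """Hypothetical loader: naive fixed-width chunking, purely illustrative."""
    with open(path, encoding="utf-8") as fh:
        text = fh.read()
    return [Document(page_content=text[i:i + 500], metadata={"source": path})
            for i in range(0, len(text), 500)]

def update_vectors_sketch(files, embed):
    all_data = []
    total_chunks = 0
    for path in files:
        data = load_and_split_stub(path)
        all_data.extend(data)  # accumulate; index once after the loop
        total_chunks += len(data)

    # Single load-or-create + save, mirroring the committed update_vectors.
    if os.path.exists(DB_PATH):
        database = FAISS.load_local(DB_PATH, embed, allow_dangerous_deserialization=True)
        database.add_documents(all_data)
    else:
        database = FAISS.from_documents(all_data, embed)
    database.save_local(DB_PATH)

    return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."
```

Like the committed version, this still assumes at least one non-empty file: an empty `all_data` would hit the same zero-document `FAISS.from_documents` limitation the `ask_question` change works around.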