ariankhalfani commited on
Commit
5e8012a
1 Parent(s): d6251a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -43
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import sqlite3
3
  import requests
4
- import PyPDF2
5
  import faiss
6
  import numpy as np
7
  from sentence_transformers import SentenceTransformer
@@ -19,11 +19,11 @@ def query_huggingface(payload):
19
 
20
  # Function to extract text from PDF
21
  def extract_text_from_pdf(pdf_file):
22
- pdf_reader = PyPDF2.PdfReader(pdf_file)
23
  text = ""
24
- for page_num in range(len(pdf_reader.pages)):
25
- page = pdf_reader.pages[page_num]
26
- text += page.extract_text()
 
27
  return text
28
 
29
  # Initialize SQLite database
@@ -60,6 +60,9 @@ def get_context():
60
  # Function to create or update the FAISS index
61
  def update_faiss_index():
62
  contexts = get_context()
 
 
 
63
  embeddings = model.encode(contexts, convert_to_tensor=True)
64
  index = faiss.IndexFlatL2(embeddings.shape[1])
65
  index.add(embeddings.cpu().numpy())
@@ -67,6 +70,9 @@ def update_faiss_index():
67
 
68
  # Retrieve relevant context from the FAISS index
69
  def retrieve_relevant_context(index, contexts, query, top_k=5):
 
 
 
70
  query_embedding = model.encode([query], convert_to_tensor=True).cpu().numpy()
71
  distances, indices = index.search(query_embedding, top_k)
72
  relevant_contexts = [contexts[i] for i in indices[0]]
@@ -77,49 +83,25 @@ init_db()
77
  model = SentenceTransformer('all-MiniLM-L6-v2')
78
  faiss_index, context_list = update_faiss_index()
79
 
80
- # Function to handle chatbot responses
81
- def chatbot_response(question):
82
  relevant_contexts = retrieve_relevant_context(faiss_index, context_list, question)
83
  user_input = f"question: {question} context: {' '.join(relevant_contexts)}"
84
  response = query_huggingface({"inputs": user_input})
85
  response_text = response.get("generated_text", "Sorry, I couldn't generate a response.")
86
  return response_text
87
 
88
- # Function to handle PDF uploads
89
- def handle_pdf_upload(pdf_file):
90
- context = extract_text_from_pdf(pdf_file)
91
- add_context(pdf_file.name, context)
92
- faiss_index, context_list = update_faiss_index() # Update FAISS index
93
- return f"Context from {pdf_file.name} added to the database."
94
-
95
- # Gradio UI
96
- with gr.Blocks() as demo:
97
- gr.Markdown("# Storage Warehouse Customer Service Chatbot")
98
-
99
- with gr.Row():
100
- with gr.Column(scale=4):
101
- with gr.Box():
102
- pdf_upload = gr.File(label="Upload PDF", file_types=["pdf"], interactive=True)
103
- upload_button = gr.Button("Upload")
104
- upload_status = gr.Textbox(label="Upload Status")
105
-
106
- def handle_upload(files):
107
- for file in files:
108
- result = handle_pdf_upload(file.name)
109
- upload_status.value = result
110
-
111
- upload_button.click(fn=handle_upload, inputs=pdf_upload, outputs=upload_status)
112
-
113
- with gr.Column(scale=8):
114
- chatbot = gr.Chatbot(label="Chatbot")
115
- question = gr.Textbox(label="Your question here:")
116
- submit_button = gr.Button("Submit")
117
-
118
- def handle_chat(user_input):
119
- bot_response = chatbot_response(user_input)
120
- return gr.Chatbot.update([[user_input, bot_response]])
121
 
122
- submit_button.click(fn=handle_chat, inputs=question, outputs=chatbot)
 
 
123
 
124
- if __name__ == "__main__":
125
- demo.launch()
 
1
  import os
2
  import sqlite3
3
  import requests
4
+ import fitz # PyMuPDF
5
  import faiss
6
  import numpy as np
7
  from sentence_transformers import SentenceTransformer
 
19
 
20
  # Function to extract text from PDF
21
  def extract_text_from_pdf(pdf_file):
 
22
  text = ""
23
+ pdf_document = fitz.open(stream=pdf_file.read(), filetype="pdf")
24
+ for page_num in range(len(pdf_document)):
25
+ page = pdf_document.load_page(page_num)
26
+ text += page.get_text()
27
  return text
28
 
29
  # Initialize SQLite database
 
60
  # Function to create or update the FAISS index
61
  def update_faiss_index():
62
  contexts = get_context()
63
+ if len(contexts) == 0:
64
+ return None, contexts
65
+
66
  embeddings = model.encode(contexts, convert_to_tensor=True)
67
  index = faiss.IndexFlatL2(embeddings.shape[1])
68
  index.add(embeddings.cpu().numpy())
 
70
 
71
  # Retrieve relevant context from the FAISS index
72
  def retrieve_relevant_context(index, contexts, query, top_k=5):
73
+ if index is None or len(contexts) == 0:
74
+ return []
75
+
76
  query_embedding = model.encode([query], convert_to_tensor=True).cpu().numpy()
77
  distances, indices = index.search(query_embedding, top_k)
78
  relevant_contexts = [contexts[i] for i in indices[0]]
 
83
  model = SentenceTransformer('all-MiniLM-L6-v2')
84
  faiss_index, context_list = update_faiss_index()
85
 
86
+ # Gradio interface
87
+ def chatbot(question):
88
  relevant_contexts = retrieve_relevant_context(faiss_index, context_list, question)
89
  user_input = f"question: {question} context: {' '.join(relevant_contexts)}"
90
  response = query_huggingface({"inputs": user_input})
91
  response_text = response.get("generated_text", "Sorry, I couldn't generate a response.")
92
  return response_text
93
 
94
+ # File upload function
95
+ def upload_pdf(file):
96
+ context = extract_text_from_pdf(file)
97
+ add_context(file.name, context)
98
+ global faiss_index, context_list
99
+ faiss_index, context_list = update_faiss_index()
100
+ return "PDF content added to context."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+ # Gradio interface
103
+ iface = gr.Interface(fn=chatbot, inputs="text", outputs="text", title="Storage Warehouse Customer Service Chatbot")
104
+ file_upload = gr.Interface(fn=upload_pdf, inputs="file", outputs="text", title="Upload PDF for Context")
105
 
106
+ app = gr.TabbedInterface([iface, file_upload], ["Chatbot", "Upload PDF"])
107
+ app.launch()