Shreyas094 commited on
Commit
46953d2
·
verified ·
1 Parent(s): 32fb8f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -40
app.py CHANGED
@@ -13,10 +13,12 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
13
  from langchain_community.llms import HuggingFaceHub
14
  from langchain_core.runnables import RunnableParallel, RunnablePassthrough
15
  from langchain_core.documents import Document
 
16
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
17
 
18
  # Memory database to store question-answer pairs
19
  memory_database = {}
 
20
 
21
  def load_and_split_document_basic(file):
22
  """Loads and splits the document into pages."""
@@ -57,8 +59,13 @@ def clear_cache():
57
  return "No cache to clear."
58
 
59
  prompt = """
60
- Answer the question based only on the following context:
 
 
 
 
61
  {context}
 
62
  Question: {question}
63
 
64
  Provide a concise and direct answer to the question:
@@ -81,21 +88,46 @@ def generate_chunked_response(model, prompt, max_tokens=1000, max_chunks=5):
81
  for i in range(max_chunks):
82
  chunk = model(prompt + full_response, max_new_tokens=max_tokens)
83
  chunk = chunk.strip()
84
- # Check for final sentence endings
85
  if chunk.endswith((".", "!", "?")):
86
  full_response += chunk
87
  break
88
  full_response += chunk
89
  return full_response.strip()
90
 
91
- def response(database, model, question):
92
- prompt_val = ChatPromptTemplate.from_template(prompt)
93
- retriever = database.as_retriever()
94
- context = retriever.get_relevant_documents(question)
95
- context_str = "\n".join([doc.page_content for doc in context])
96
- formatted_prompt = prompt_val.format(context=context_str, question=question)
97
- ans = generate_chunked_response(model, formatted_prompt)
98
- return ans.split("Question:")[-1].strip() # Return only the answer part
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  def update_vectors(files, use_recursive_splitter):
101
  if not files:
@@ -114,26 +146,6 @@ def update_vectors(files, use_recursive_splitter):
114
 
115
  return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."
116
 
117
- def ask_question(question, temperature, top_p, repetition_penalty):
118
- if not question:
119
- return "Please enter a question."
120
-
121
- # Check if the question exists in the memory database
122
- if question in memory_database:
123
- return memory_database[question]
124
-
125
- embed = get_embeddings()
126
- database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
127
- model = get_model(temperature, top_p, repetition_penalty)
128
-
129
- # Generate response from document database
130
- answer = response(database, model, question)
131
-
132
- # Store the question and answer in the memory database
133
- memory_database[question] = answer
134
-
135
- return answer
136
-
137
  def extract_db_to_excel():
138
  embed = get_embeddings()
139
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
@@ -150,11 +162,16 @@ def extract_db_to_excel():
150
 
151
  def export_memory_db_to_excel():
152
  data = [{"question": question, "answer": answer} for question, answer in memory_database.items()]
153
- df = pd.DataFrame(data)
 
 
 
154
 
155
  with NamedTemporaryFile(delete=False, suffix='.xlsx') as tmp:
156
  excel_path = tmp.name
157
- df.to_excel(excel_path, index=False)
 
 
158
 
159
  return excel_path
160
 
@@ -171,14 +188,21 @@ with gr.Blocks() as demo:
171
  update_button.click(update_vectors, inputs=[file_input, use_recursive_splitter], outputs=update_output)
172
 
173
  with gr.Row():
174
- question_input = gr.Textbox(label="Ask a question about your documents")
175
- temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
176
- top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
177
- repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
178
- submit_button = gr.Button("Submit")
179
-
180
- answer_output = gr.Textbox(label="Answer")
181
- submit_button.click(ask_question, inputs=[question_input, temperature_slider, top_p_slider, repetition_penalty_slider], outputs=answer_output)
 
 
 
 
 
 
 
182
 
183
  extract_button = gr.Button("Extract Database to Excel")
184
  excel_output = gr.File(label="Download Excel File")
 
13
  from langchain_community.llms import HuggingFaceHub
14
  from langchain_core.runnables import RunnableParallel, RunnablePassthrough
15
  from langchain_core.documents import Document
16
+
17
  huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
18
 
19
  # Memory database to store question-answer pairs
20
  memory_database = {}
21
+ conversation_history = []
22
 
23
  def load_and_split_document_basic(file):
24
  """Loads and splits the document into pages."""
 
59
  return "No cache to clear."
60
 
61
  prompt = """
62
+ Answer the question based on the following context and conversation history:
63
+ Conversation History:
64
+ {history}
65
+
66
+ Context from documents:
67
  {context}
68
+
69
  Question: {question}
70
 
71
  Provide a concise and direct answer to the question:
 
88
  for i in range(max_chunks):
89
  chunk = model(prompt + full_response, max_new_tokens=max_tokens)
90
  chunk = chunk.strip()
 
91
  if chunk.endswith((".", "!", "?")):
92
  full_response += chunk
93
  break
94
  full_response += chunk
95
  return full_response.strip()
96
 
97
+ def manage_conversation_history(question, answer, history, max_history=5):
98
+ history.append({"question": question, "answer": answer})
99
+ if len(history) > max_history:
100
+ history.pop(0)
101
+ return history
102
+
103
+ def ask_question(question, temperature, top_p, repetition_penalty):
104
+ global conversation_history
105
+
106
+ if not question:
107
+ return "Please enter a question."
108
+
109
+ if question in memory_database:
110
+ answer = memory_database[question]
111
+ else:
112
+ embed = get_embeddings()
113
+ database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
114
+ model = get_model(temperature, top_p, repetition_penalty)
115
+
116
+ history_str = "\n".join([f"Q: {item['question']}\nA: {item['answer']}" for item in conversation_history])
117
+ prompt_val = ChatPromptTemplate.from_template(prompt)
118
+ retriever = database.as_retriever()
119
+ relevant_docs = retriever.get_relevant_documents(question)
120
+ context_str = "\n".join([doc.page_content for doc in relevant_docs])
121
+ formatted_prompt = prompt_val.format(history=history_str, context=context_str, question=question)
122
+
123
+ answer = generate_chunked_response(model, formatted_prompt)
124
+ answer = answer.split("Question:")[-1].strip()
125
+
126
+ memory_database[question] = answer
127
+
128
+ conversation_history = manage_conversation_history(question, answer, conversation_history)
129
+
130
+ return answer
131
 
132
  def update_vectors(files, use_recursive_splitter):
133
  if not files:
 
146
 
147
  return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def extract_db_to_excel():
150
  embed = get_embeddings()
151
  database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
 
162
 
163
  def export_memory_db_to_excel():
164
  data = [{"question": question, "answer": answer} for question, answer in memory_database.items()]
165
+ df_memory = pd.DataFrame(data)
166
+
167
+ data_history = [{"question": item["question"], "answer": item["answer"]} for item in conversation_history]
168
+ df_history = pd.DataFrame(data_history)
169
 
170
  with NamedTemporaryFile(delete=False, suffix='.xlsx') as tmp:
171
  excel_path = tmp.name
172
+ with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
173
+ df_memory.to_excel(writer, sheet_name='Memory Database', index=False)
174
+ df_history.to_excel(writer, sheet_name='Conversation History', index=False)
175
 
176
  return excel_path
177
 
 
188
  update_button.click(update_vectors, inputs=[file_input, use_recursive_splitter], outputs=update_output)
189
 
190
  with gr.Row():
191
+ with gr.Column(scale=2):
192
+ chatbot = gr.Chatbot(label="Conversation")
193
+ question_input = gr.Textbox(label="Ask a question about your documents")
194
+ submit_button = gr.Button("Submit")
195
+ with gr.Column(scale=1):
196
+ temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
197
+ top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
198
+ repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
199
+
200
+ def chat(question, history):
201
+ answer = ask_question(question, temperature_slider.value, top_p_slider.value, repetition_penalty_slider.value)
202
+ history.append((question, answer))
203
+ return "", history
204
+
205
+ submit_button.click(chat, inputs=[question_input, chatbot], outputs=[question_input, chatbot])
206
 
207
  extract_button = gr.Button("Extract Database to Excel")
208
  excel_output = gr.File(label="Download Excel File")