clementsan commited on
Commit
00bd139
·
1 Parent(s): ceae871

Update qa_chain to gradio session state

Browse files
Files changed (1) hide show
  1. app.py +12 -15
app.py CHANGED
@@ -107,7 +107,6 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
107
  # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
108
  retriever=vector_db.as_retriever()
109
  progress(0.8, desc="Defining retrieval chain...")
110
- global qa_chain
111
  qa_chain = ConversationalRetrievalChain.from_llm(
112
  llm,
113
  retriever=retriever,
@@ -119,10 +118,10 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
119
  # verbose=True,
120
  )
121
  progress(0.9, desc="Done!")
122
- # return qa_chain
123
 
124
 
125
- # Initialize all elements
126
  def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
127
  # Create list of documents (when valid)
128
  #file_path = file_obj.name
@@ -137,16 +136,14 @@ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Pr
137
  vector_db = create_db(doc_splits)
138
  progress(0.9, desc="Done!")
139
  return vector_db, "Complete!"
140
- #return qa_chain
141
 
142
 
143
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
144
  print("llm_option",llm_option)
145
  llm_name = list_llm[llm_option]
146
  print("llm_name",llm_name)
147
- initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
148
- return "Complete!"
149
- #return qa_chain
150
 
151
 
152
  def format_chat_history(message, chat_history):
@@ -157,7 +154,7 @@ def format_chat_history(message, chat_history):
157
  return formatted_chat_history
158
 
159
 
160
- def conversation(message, history):
161
  formatted_chat_history = format_chat_history(message, history)
162
  #print("formatted_chat_history",formatted_chat_history)
163
 
@@ -176,7 +173,7 @@ def conversation(message, history):
176
  # Append user message and response to chat history
177
  new_history = history + [(message, response_answer)]
178
  # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
179
- return gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page
180
 
181
 
182
  def upload_file(file_obj):
@@ -192,7 +189,7 @@ def upload_file(file_obj):
192
  def demo():
193
  with gr.Blocks(theme="base") as demo:
194
  vector_db = gr.State()
195
- # qa_chain = gr.Variable()
196
 
197
  gr.Markdown(
198
  """<center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</center></h2>
@@ -252,19 +249,19 @@ def demo():
252
  outputs=[vector_db, db_progress])
253
  qachain_btn.click(initialize_LLM, \
254
  inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
255
- outputs=[llm_progress]).then(lambda:[None,"",0,"",0], \
256
  inputs=None, \
257
  outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page], \
258
  queue=False)
259
 
260
  # Chatbot events
261
  msg.submit(conversation, \
262
- inputs=[msg, chatbot], \
263
- outputs=[msg, chatbot, doc_source1, source1_page, doc_source2, source2_page], \
264
  queue=False)
265
  submit_btn.click(conversation, \
266
- inputs=[msg, chatbot], \
267
- outputs=[msg, chatbot, doc_source1, source1_page, doc_source2, source2_page], \
268
  queue=False)
269
  clear_btn.click(lambda:[None,"",0,"",0], \
270
  inputs=None, \
 
107
  # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
108
  retriever=vector_db.as_retriever()
109
  progress(0.8, desc="Defining retrieval chain...")
 
110
  qa_chain = ConversationalRetrievalChain.from_llm(
111
  llm,
112
  retriever=retriever,
 
118
  # verbose=True,
119
  )
120
  progress(0.9, desc="Done!")
121
+ return qa_chain
122
 
123
 
124
+ # Initialize database
125
  def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
126
  # Create list of documents (when valid)
127
  #file_path = file_obj.name
 
136
  vector_db = create_db(doc_splits)
137
  progress(0.9, desc="Done!")
138
  return vector_db, "Complete!"
 
139
 
140
 
141
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
142
  print("llm_option",llm_option)
143
  llm_name = list_llm[llm_option]
144
  print("llm_name",llm_name)
145
+ qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
146
+ return qa_chain, "Complete!"
 
147
 
148
 
149
  def format_chat_history(message, chat_history):
 
154
  return formatted_chat_history
155
 
156
 
157
+ def conversation(qa_chain, message, history):
158
  formatted_chat_history = format_chat_history(message, history)
159
  #print("formatted_chat_history",formatted_chat_history)
160
 
 
173
  # Append user message and response to chat history
174
  new_history = history + [(message, response_answer)]
175
  # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
176
+ return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page
177
 
178
 
179
  def upload_file(file_obj):
 
189
  def demo():
190
  with gr.Blocks(theme="base") as demo:
191
  vector_db = gr.State()
192
+ qa_chain = gr.State()
193
 
194
  gr.Markdown(
195
  """<center><h2>PDF-based chatbot (powered by LangChain and open-source LLMs)</center></h2>
 
249
  outputs=[vector_db, db_progress])
250
  qachain_btn.click(initialize_LLM, \
251
  inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
252
+ outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0], \
253
  inputs=None, \
254
  outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page], \
255
  queue=False)
256
 
257
  # Chatbot events
258
  msg.submit(conversation, \
259
+ inputs=[qa_chain, msg, chatbot], \
260
+ outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page], \
261
  queue=False)
262
  submit_btn.click(conversation, \
263
+ inputs=[qa_chain, msg, chatbot], \
264
+ outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page], \
265
  queue=False)
266
  clear_btn.click(lambda:[None,"",0,"",0], \
267
  inputs=None, \