lfoppiano committed on
Commit
99f35d8
·
1 Parent(s): 137e5e2

add conversational memory

Browse files
Files changed (1) hide show
  1. streamlit_app.py +22 -2
streamlit_app.py CHANGED
@@ -6,6 +6,7 @@ from tempfile import NamedTemporaryFile
6
  import dotenv
7
  from grobid_quantities.quantities import QuantitiesAPI
8
  from langchain.llms.huggingface_hub import HuggingFaceHub
 
9
 
10
  dotenv.load_dotenv(override=True)
11
 
@@ -51,6 +52,9 @@ if 'ner_processing' not in st.session_state:
51
  if 'uploaded' not in st.session_state:
52
  st.session_state['uploaded'] = False
53
 
 
 
 
54
  st.set_page_config(
55
  page_title="Scientific Document Insights Q/A",
56
  page_icon="📝",
@@ -67,6 +71,11 @@ def new_file():
67
  st.session_state['loaded_embeddings'] = None
68
  st.session_state['doc_id'] = None
69
  st.session_state['uploaded'] = True
 
 
 
 
 
70
 
71
 
72
  # @st.cache_resource
@@ -169,7 +178,7 @@ with st.sidebar:
169
  disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])
170
 
171
  st.markdown(
172
- ":warning: Mistral and Zephyr are free to use, however requests might hit limits of the huggingface free API and fail. :warning: ")
173
 
174
  if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
175
  if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
@@ -206,6 +215,11 @@ with st.sidebar:
206
  # else:
207
  # is_api_key_provided = st.session_state['api_key']
208
 
 
 
 
 
 
209
  st.title("📝 Scientific Document Insights Q/A")
210
  st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
211
 
@@ -298,7 +312,8 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
298
  elif mode == "LLM":
299
  with st.spinner("Generating response..."):
300
  _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
301
- context_size=context_size)
 
302
 
303
  if not text_response:
304
  st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
@@ -317,5 +332,10 @@ if st.session_state.loaded_embeddings and question and len(question) > 0 and st.
317
  st.write(text_response)
318
  st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
319
 
 
 
 
 
 
320
  elif st.session_state.loaded_embeddings and st.session_state.doc_id:
321
  play_old_messages()
 
6
  import dotenv
7
  from grobid_quantities.quantities import QuantitiesAPI
8
  from langchain.llms.huggingface_hub import HuggingFaceHub
9
+ from langchain.memory import ConversationBufferWindowMemory
10
 
11
  dotenv.load_dotenv(override=True)
12
 
 
52
  if 'uploaded' not in st.session_state:
53
  st.session_state['uploaded'] = False
54
 
55
+ if 'memory' not in st.session_state:
56
+ st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
57
+
58
  st.set_page_config(
59
  page_title="Scientific Document Insights Q/A",
60
  page_icon="📝",
 
71
  st.session_state['loaded_embeddings'] = None
72
  st.session_state['doc_id'] = None
73
  st.session_state['uploaded'] = True
74
+ st.session_state['memory'].clear()
75
+
76
+
77
+ # Empties the conversational memory object stored under st.session_state['memory'].
+ def clear_memory():
78
+ st.session_state['memory'].clear()
79
 
80
 
81
  # @st.cache_resource
 
178
  disabled=st.session_state['doc_id'] is not None or st.session_state['uploaded'])
179
 
180
  st.markdown(
181
+ ":warning: Mistral and Zephyr are **FREE** to use. Requests might fail anytime. Use at your own risk. :warning: ")
182
 
183
  if (model == 'mistral-7b-instruct-v0.1' or model == 'zephyr-7b-beta') and model not in st.session_state['api_keys']:
184
  if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ:
 
215
  # else:
216
  # is_api_key_provided = st.session_state['api_key']
217
 
218
+ st.button(
219
+ 'Reset chat memory.',
220
+ on_click=clear_memory(),
221
+ help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.")
222
+
223
  st.title("📝 Scientific Document Insights Q/A")
224
  st.subheader("Upload a scientific article in PDF, ask questions, get insights.")
225
 
 
312
  elif mode == "LLM":
313
  with st.spinner("Generating response..."):
314
  _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
315
+ context_size=context_size,
316
+ memory=st.session_state.memory)
317
 
318
  if not text_response:
319
  st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
 
332
  st.write(text_response)
333
  st.session_state.messages.append({"role": "assistant", "mode": mode, "content": text_response})
334
 
335
+ for id in range(0, len(st.session_state.messages), 2):
336
+ question = st.session_state.messages[id]['content']
337
+ answer = st.session_state.messages[id + 1]['content']
338
+ st.session_state.memory.save_context({"input": question}, {"output": answer})
339
+
340
  elif st.session_state.loaded_embeddings and st.session_state.doc_id:
341
  play_old_messages()