nkcong206 commited on
Commit
d5ac512
·
1 Parent(s): 33afc2e
Files changed (2) hide show
  1. app.py +129 -72
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,15 +2,17 @@ import streamlit as st
2
  import os
3
  from langchain_google_genai import ChatGoogleGenerativeAI
4
  from langchain_core.prompts import ChatPromptTemplate
5
- from langchain_community.document_loaders import TextLoader
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
  from langchain.prompts import PromptTemplate
8
 
9
  from langchain_core.output_parsers import StrOutputParser
10
 
11
  from langchain_core.runnables import RunnablePassthrough
12
- from langchain_chroma import Chroma
13
  import Raptor
 
 
 
14
 
15
  page = st.title("Chat with AskUSTH")
16
 
@@ -23,6 +25,8 @@ if "rag" not in st.session_state:
23
  if "llm" not in st.session_state:
24
  st.session_state.llm = None
25
 
 
 
26
  @st.cache_resource
27
  def get_chat_google_model(api_key):
28
  os.environ["GOOGLE_API_KEY"] = api_key
@@ -50,6 +54,27 @@ def get_embedding_model():
50
  if "embd" not in st.session_state:
51
  st.session_state.embd = get_embedding_model()
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  if "model" not in st.session_state:
54
  st.session_state.model = None
55
 
@@ -77,64 +102,12 @@ if st.session_state.gemini_api is None:
77
  if st.session_state.gemini_api and st.session_state.model is None:
78
  st.session_state.model = get_chat_google_model(st.session_state.gemini_api)
79
 
80
- if st.session_state.save_dir is None:
81
- save_dir = "./Documents"
82
- if not os.path.exists(save_dir):
83
- os.makedirs(save_dir)
84
- st.session_state.save_dir = save_dir
85
-
86
- def load_txt(file_path):
87
- loader_sv = TextLoader(file_path=file_path, encoding="utf-8")
88
- doc = loader_sv.load()
89
- return doc
90
-
91
- with st.sidebar:
92
- uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
93
- if st.session_state.gemini_api:
94
- if uploaded_files:
95
- documents = []
96
- uploaded_file_names = set()
97
- new_docs = False
98
- for uploaded_file in uploaded_files:
99
- uploaded_file_names.add(uploaded_file.name)
100
- if uploaded_file.name not in st.session_state.uploaded_files:
101
- file_path = os.path.join(st.session_state.save_dir, uploaded_file.name)
102
- with open(file_path, mode='wb') as w:
103
- w.write(uploaded_file.getvalue())
104
- else:
105
- continue
106
-
107
- new_docs = True
108
-
109
- doc = load_txt(file_path)
110
-
111
- documents.extend([*doc])
112
-
113
- if new_docs:
114
- st.session_state.uploaded_files = uploaded_file_names
115
- st.session_state.rag = None
116
- else:
117
- st.session_state.uploaded_files = set()
118
- st.session_state.rag = None
119
-
120
  def format_docs(docs):
121
  return "\n\n".join(doc.page_content for doc in docs)
122
 
123
  @st.cache_resource
124
- def compute_rag_chain(_model, _embd, docs_texts):
125
- results = Raptor.recursive_embed_cluster_summarize(_model, _embd, docs_texts, level=1, n_levels=3)
126
- all_texts = docs_texts.copy()
127
- i = 0
128
- for level in sorted(results.keys()):
129
- summaries = results[level][1]["summaries"].tolist()
130
- all_texts.extend(summaries)
131
- print(f"summary {i} -------------------------------------------------")
132
- print(summaries)
133
- i += 1
134
- print("all_texts ______________________________________")
135
- print(all_texts)
136
- vectorstore = Chroma.from_texts(texts=all_texts, embedding=_embd)
137
- retriever = vectorstore.as_retriever()
138
  template = """
139
  Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. \n
140
  Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi. \n
@@ -145,24 +118,105 @@ def compute_rag_chain(_model, _embd, docs_texts):
145
  {question}
146
  """
147
  prompt = PromptTemplate(template=template, input_variables=["context", "question"])
148
- rag_chain = (
149
  {"context": retriever | format_docs, "question": RunnablePassthrough()}
150
  | prompt
151
  | _model
152
  | StrOutputParser()
153
  )
154
- return rag_chain
155
 
156
- @st.dialog("Setup RAG")
157
- def load_rag():
158
- docs_texts = [d.page_content for d in documents]
159
- st.session_state.rag = compute_rag_chain(st.session_state.model, st.session_state.embd, docs_texts)
160
- st.rerun()
161
 
162
- if st.session_state.uploaded_files and st.session_state.model is not None:
163
- if st.session_state.rag is None:
164
- load_rag()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  if st.session_state.model is not None:
167
  if st.session_state.llm is None:
168
  mess = ChatPromptTemplate.from_messages(
@@ -193,13 +247,16 @@ if st.session_state.model is not None:
193
  st.write(prompt)
194
 
195
  with st.chat_message("assistant"):
196
- if st.session_state.rag is not None:
 
 
197
  respone = st.session_state.rag.invoke(prompt)
198
- st.write(respone)
 
199
  else:
200
- ans = st.session_state.llm.invoke(prompt)
201
- respone = ans.content
202
- st.write(respone)
203
 
204
- st.session_state.chat_history.append({"role": "assistant", "content": respone})
205
 
 
2
  import os
3
  from langchain_google_genai import ChatGoogleGenerativeAI
4
  from langchain_core.prompts import ChatPromptTemplate
 
5
  from langchain_huggingface import HuggingFaceEmbeddings
6
  from langchain.prompts import PromptTemplate
7
 
8
  from langchain_core.output_parsers import StrOutputParser
9
 
10
  from langchain_core.runnables import RunnablePassthrough
11
+ from langchain_qdrant import QdrantVectorStore
12
  import Raptor
13
+ from io import StringIO
14
+ from qdrant_client import QdrantClient
15
+ from qdrant_client.models import Distance, VectorParams
16
 
17
  page = st.title("Chat with AskUSTH")
18
 
 
25
  if "llm" not in st.session_state:
26
  st.session_state.llm = None
27
 
28
+
29
+
30
  @st.cache_resource
31
  def get_chat_google_model(api_key):
32
  os.environ["GOOGLE_API_KEY"] = api_key
 
54
  if "embd" not in st.session_state:
55
  st.session_state.embd = get_embedding_model()
56
 
57
+ @st.cache_resource
58
+ def load_chromadb(collection_name):
59
+ client = QdrantClient(
60
+ url="https://da9fadd2-dc5a-4481-ac79-4e2677a2354b.europe-west3-0.gcp.cloud.qdrant.io",
61
+ api_key="X_-IVToBM07Mot4Mmzg5xNjYzc1DlIgl0VQDUNmGhI_Z-WA5FJ2ETA"
62
+ )
63
+
64
+ client.recreate_collection(
65
+ collection_name=collection_name,
66
+ vectors_config=VectorParams(size=768, distance=Distance.COSINE)
67
+ )
68
+ db = QdrantVectorStore(
69
+ client=client,
70
+ collection_name=collection_name,
71
+ embedding=st.session_state.embd,
72
+ )
73
+ return db
74
+
75
+ if "vector_store" not in st.session_state:
76
+ st.session_state.vector_store = load_chromadb("data")
77
+
78
  if "model" not in st.session_state:
79
  st.session_state.model = None
80
 
 
102
  if st.session_state.gemini_api and st.session_state.model is None:
103
  st.session_state.model = get_chat_google_model(st.session_state.gemini_api)
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def format_docs(docs):
106
  return "\n\n".join(doc.page_content for doc in docs)
107
 
108
  @st.cache_resource
109
+ def rag_chain(_model, _vectorstore):
110
+ retriever = _vectorstore.as_retriever()
 
 
 
 
 
 
 
 
 
 
 
 
111
  template = """
112
  Bạn là một trợ lí AI hỗ trợ tuyển sinh và sinh viên. \n
113
  Hãy trả lời câu hỏi chính xác, tập trung vào thông tin liên quan đến câu hỏi. \n
 
118
  {question}
119
  """
120
  prompt = PromptTemplate(template=template, input_variables=["context", "question"])
121
+ rag = (
122
  {"context": retriever | format_docs, "question": RunnablePassthrough()}
123
  | prompt
124
  | _model
125
  | StrOutputParser()
126
  )
127
+ return rag
128
 
129
+ if st.session_state.model is not None and st.session_state.vector_store is not None:
130
+ st.session_state.rag = rag_chain(st.session_state.model, st.session_state.vector_store)
 
 
 
131
 
132
+ # if st.session_state.save_dir is None:
133
+ # save_dir = "./Documents"
134
+ # if not os.path.exists(save_dir):
135
+ # os.makedirs(save_dir)
136
+ # st.session_state.save_dir = save_dir
137
+
138
+ # def load_txt(file_path):
139
+ # loader_sv = TextLoader(file_path=file_path, encoding="utf-8")
140
+ # doc = loader_sv.load()
141
+ # return doc
142
+
143
+ if "new_docs" not in st.session_state:
144
+ st.session_state.new_docs = False
145
+
146
+ with st.sidebar:
147
+ uploaded_files = st.file_uploader("Chọn file txt", accept_multiple_files=True, type=["txt"])
148
+ if st.session_state.model:
149
+ documents = []
150
+ uploaded_file_names = set()
151
+ if uploaded_files:
152
+ for uploaded_file in uploaded_files:
153
+ uploaded_file_names.add(uploaded_file.name)
154
+ if uploaded_file_names != st.session_state.uploaded_files and not st.session_state.new_docs:
155
+ st.session_state.uploaded_files = uploaded_file_names
156
+ st.session_state.new_docs = True
157
+ if uploaded_files:
158
+ for uploaded_file in uploaded_files:
159
+ stringio=StringIO(uploaded_file.getvalue().decode('utf-8'))
160
+ read_data=str(stringio.read())
161
+ documents.append(read_data)
162
+
163
+ def update_rag_chain(_model, _embd, _vectorstore, docs_texts):
164
+ results = Raptor.recursive_embed_cluster_summarize(_model, _embd, docs_texts, level=1, n_levels=3)
165
+ all_texts = docs_texts.copy()
166
+ for level in sorted(results.keys()):
167
+ summaries = results[level][1]["summaries"].tolist()
168
+ all_texts.extend(summaries)
169
+ _vectorstore.reset_collection()
170
+ _vectorstore.add_texts(texts=all_texts)
171
+ rag = rag_chain(_model, _vectorstore)
172
+ return rag
173
+
174
+ def reset_rag_chain(_model, _vectorstore):
175
+ _vectorstore.reset_collection()
176
+ rag = rag_chain(_model, _vectorstore)
177
+ return rag
178
+
179
+ if "query_router" not in st.session_state:
180
+ st.session_state.query_router = None
181
+
182
+ @st.cache_resource
183
+ def query_router(_model):
184
+ mess = ChatPromptTemplate.from_messages(
185
+ [
186
+ (
187
+ "system",
188
+ """Bạn là một chatbot hỗ trợ giải đáp về đại học, nhiệm vụ của bạn là phân loại câu hỏi.
189
+ Nếu câu hỏi về đại học thì trả về 'university', nếu không liên quan tới tuyển sinh và sinh viên thì trả về 'other'.
190
+ Bắt buộc Kết quả chỉ trả về với một trong hai lựa chọn trên.
191
+ Không được trả lời thêm bất kỳ thông tin nào khác.""",
192
+ ),
193
+ ("human", "{input}"),
194
+ ]
195
+ )
196
+ chain = mess | _model
197
+ return chain
198
 
199
+ if st.session_state.model is not None:
200
+ st.session_state.query_router = query_router(st.session_state.model)
201
+
202
+ @st.dialog("Update DB")
203
+ def update_vectorstore(_model, _embd, _vectorstore, docs):
204
+ docs_texts = [d for d in docs]
205
+ st.session_state.rag = update_rag_chain(_model, _embd, _vectorstore, docs_texts)
206
+ st.rerun()
207
+
208
+ @st.dialog("Reset DB")
209
+ def reset_vectorstore(_model, _vectorstore):
210
+ st.session_state.rag = reset_rag_chain(_model, _vectorstore)
211
+ st.rerun()
212
+
213
+ if st.session_state.new_docs:
214
+ st.session_state.new_docs = False
215
+ if st.session_state.uploaded_files:
216
+ update_vectorstore(st.session_state.model, st.session_state.embd, st.session_state.vector_store, documents)
217
+ else:
218
+ reset_vectorstore(st.session_state.model, st.session_state.vector_store)
219
+
220
  if st.session_state.model is not None:
221
  if st.session_state.llm is None:
222
  mess = ChatPromptTemplate.from_messages(
 
247
  st.write(prompt)
248
 
249
  with st.chat_message("assistant"):
250
+ router = st.session_state.query_router.invoke(prompt)
251
+ switch = router.content
252
+ if "university" in switch:
253
  respone = st.session_state.rag.invoke(prompt)
254
+ f_response = f"RAG: {respone}"
255
+ st.write(f_response)
256
  else:
257
+ respone = st.session_state.llm.invoke(prompt)
258
+ f_response = f"other: {respone.content}"
259
+ st.write(f_response)
260
 
261
+ st.session_state.chat_history.append({"role": "assistant", "content": f_response})
262
 
requirements.txt CHANGED
@@ -4,4 +4,5 @@ langchain-community
4
  langchain-huggingface
5
  umap-learn
6
  scikit-learn
7
- langchain-chroma
 
 
4
  langchain-huggingface
5
  umap-learn
6
  scikit-learn
7
+ langchain-qdrant
8
+ qdrant-client