nnngoc committed on
Commit 97f536f
1 Parent(s): f7739ba
Files changed (1)
  1. rag.py +183 -130
rag.py CHANGED
@@ -1,63 +1,12 @@
  from langchain.embeddings import HuggingFaceEmbeddings
- from langchain.vectorstores import Chroma
- from langchain.document_loaders import TextLoader, DirectoryLoader
- import os
- import re
- from sentence_transformers.cross_encoder import CrossEncoder
- import numpy as np
- from langchain.schema.retriever import BaseRetriever, Document
- from typing import List
- from langchain.callbacks.manager import CallbackManagerForRetrieverRun
- from langchain.vectorstores import VectorStore
- from llm import URALLM
  from langchain.prompts import PromptTemplate

- # Get the role for a passage document
- def get_role(document):
-     """
-     Get the student role for a document.
-     """
-     # Search the document's source path for keywords describing the student's level.
-     keywords = [
-         "sinh viên",         # undergraduate student
-         "đại học",           # university / undergraduate
-         "học viên",          # graduate student
-         "thạc sĩ",           # master's
-         "nghiên cứu sinh",   # PhD candidate
-         "tiến sĩ",           # doctorate
-     ]
-     role = []
-     for keyword in keywords:
-         if keyword in document.metadata['source'].lower():
-             role.append(keyword)
-     return ", ".join(role)
-
- def processing_data(data_path):
-     folders = os.listdir(data_path)
-
-     dir_loaders = []
-
-     # Create one directory loader per folder
-     for folder in folders:
-         dir_loader = DirectoryLoader(os.path.join(data_path, folder), loader_cls=TextLoader)
-         dir_loaders.append(dir_loader)
-
-     # Load the text files
-     loaded_documents = []
-     for dir_loader in dir_loaders:
-         loaded_documents.append(dir_loader.load())
-
-     # Flatten the per-folder lists into a single list of documents
-     data = []
-     for i in range(len(loaded_documents)):
-         for j in range(len(loaded_documents[i])):
-             data.append(loaded_documents[i][j])
-
-     # Final data prepared for the vector database
-     for document in data:
-         role = get_role(document)
-         document.metadata['role'] = role
-
-     return data

  # Embedding model
  embedding = HuggingFaceEmbeddings(
@@ -65,96 +14,200 @@ embedding = HuggingFaceEmbeddings(
      model_kwargs={"device": "cpu"}
  )

- # embedding = HuggingFaceEmbeddings(
- #     model_name="sentence-transformers/all-MiniLM-L6-v2",
- #     model_kwargs={"device": "cpu"}
- # )
-
- # Vector database
- data_path = 'raw_data'
- persist_directory = 'vector_db'
- vectordb = Chroma.from_documents(
-     documents=processing_data(data_path),
-     embedding=embedding,
-     persist_directory=persist_directory
- )
-
- class CustomRetriever(BaseRetriever):
-     vectorstores: Chroma
-     retriever: vectordb.as_retriever()
-
-     def _get_relevant_documents(
-         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-     ) -> List[Document]:
-         # Use the base retriever to fetch candidate documents
-         documents = self.retriever.get_relevant_documents(query, callbacks=run_manager.get_child())
-
-         # Collect the page content of each candidate
-         docs_content = []
-         for i in range(len(documents)):
-             docs_content.append(documents[i].page_content)
-
-         model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
-
-         # Create the (query, passage) pairs to score
-         sentence_combinations = [[query, document] for document in docs_content]
-
-         # Compute the similarity scores for these combinations
-         similarity_scores = model.predict(sentence_combinations)
-
-         # Sort the scores in decreasing order
-         sim_scores_argsort = reversed(np.argsort(similarity_scores))
-
-         # Store the reranked documents in a new list
-         docs = []
-         for idx in sim_scores_argsort:
-             docs.append(documents[idx])
-
-         # Keep only the four highest-scoring documents
-         docs_top_4 = docs[0:4]
-
-         return docs_top_4
-
- llm = URALLM()
- custom_retriever = CustomRetriever(vectorstores=vectordb, retriever=vectordb.as_retriever(search_kwargs={"k": 50}))
-
- # Build prompt (in Vietnamese: the system prompt says the model is a chatbot
- # for the academic regulations of Ho Chi Minh City University of Technology,
- # must answer only from the supplied text, keep answers complete but short,
- # mention no proper names, and avoid the characters "-/", "]", "/", "-")
- template = """[INST] <<SYS>>
-
- Bạn là một chatbot hỗ trợ các quy định học vụ của trường Đại học Bách Khoa - ĐHQG TP.HCM.
- Bạn sử dụng văn bản được cung cấp để trả lời câu hỏi cho người dùng.
- Không sử dụng bất kỳ thông tin nào khác ngoài văn bản đã cho.
- Trả lời đầy đủ và ngắn gọn nhất có thể.
- Không đề cập tên riêng trong câu trả lời.
- Không chứa các ký tự: "-/", "]", "/", "-" trong câu trả lời.
-
- <</SYS>>
-
- Văn bản: {context}
  Câu hỏi: {question}
  Câu trả lời:
- [/INST]"""
- QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
-
- # Run chain
- from langchain.chains import RetrievalQA
-
- qa_chain = RetrievalQA.from_chain_type(llm,
-                                        verbose=False,
-                                        # retriever=vectordb.as_retriever(),
-                                        retriever=custom_retriever,
-                                        return_source_documents=True,
-                                        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})
-
- def remove_special_characters(text):
-     text = text.replace('].', '')
-     text = text.replace('/.', '')
-     text = text.replace('/.-', '')
-     text = text.replace('-', '')
-     return text

  def rag(question: str) -> str:
-     # call the QA chain
-     response = qa_chain({"query": question})

-     return remove_special_characters(response["result"])
  from langchain.embeddings import HuggingFaceEmbeddings
  from langchain.prompts import PromptTemplate
+ from utility import load_data, process_data, CustomRetriever
+
+ # Load the three regulation corpora: undergraduate (sv), master's (thacsi),
+ # and doctoral (tiensi)
+ data1 = load_data('raw_data/sv')
+ data2 = load_data('raw_data/thacsi')
+ data3 = load_data('raw_data/tiensi')
+ data = data1 + data2 + data3
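+ # load_data lives in utility and is not part of this diff; a minimal sketch
+ # of what it plausibly does (an assumption, not the shipped implementation),
+ # mirroring the loaders the old version of this file used:
+ #
+ #   from langchain.document_loaders import DirectoryLoader, TextLoader
+ #
+ #   def load_data(path):
+ #       # read every text file under `path` into LangChain Documents
+ #       return DirectoryLoader(path, loader_cls=TextLoader).load()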
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
  # Embedding model
  embedding = HuggingFaceEmbeddings(
      model_kwargs={"device": "cpu"}
  )

+ # The splitter used to create smaller child chunks
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+ child_text_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
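+ # For example, split_text chops a long string into pieces of at most
+ # chunk_size characters, preferring paragraph and sentence boundaries:
+ #
+ #   chunks = child_text_splitter.split_text(long_regulation_text)
+ #   assert all(len(c) <= 400 for c in chunks)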
 
 
+
+ #####################################################################
+
+ # Build one vector store / retriever per audience; store 1 indexes the full
+ # combined corpus and serves as the default route
+ vectorstore1, retriever1 = process_data(data, child_text_splitter, embedding, "data")
+ vectorstore2, retriever2 = process_data(data2, child_text_splitter, embedding, "data2")
+ vectorstore3, retriever3 = process_data(data3, child_text_splitter, embedding, "data3")
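+ # process_data also comes from utility and is not shown in this diff. Given
+ # the child_text_splitter argument and the (vectorstore, retriever) return
+ # pair, a minimal sketch of the likely shape (an assumption, not the actual
+ # implementation) is a parent-document setup:
+ #
+ #   from langchain.retrievers import ParentDocumentRetriever
+ #   from langchain.storage import InMemoryStore
+ #   from langchain.vectorstores import Chroma
+ #
+ #   def process_data(docs, child_splitter, embedding, name):
+ #       vectorstore = Chroma(collection_name=name, embedding_function=embedding)
+ #       retriever = ParentDocumentRetriever(
+ #           vectorstore=vectorstore,
+ #           docstore=InMemoryStore(),
+ #           child_splitter=child_splitter,
+ #       )
+ #       retriever.add_documents(docs)  # indexes 400-char child chunks
+ #       return vectorstore, retriever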
 
+
+ ##############################################################################
+
+ ANYSCALE_API_BASE = "credential-1711634141163"
+ ANYSCALE_API_KEY = "esecret_chitz7splr5ut6vfvqpn72itd3"
+ ANYSCALE_MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
+ # ANYSCALE_MODEL_NAME = "meta-llama/Llama-3-8b-chat-hf"
+ # ANYSCALE_MODEL_NAME = "google/gemma-7b-it"
+ # ANYSCALE_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
+ # ANYSCALE_MODEL_NAME = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+ import os
+
+ os.environ["ANYSCALE_API_BASE"] = ANYSCALE_API_BASE
+ os.environ["ANYSCALE_API_KEY"] = ANYSCALE_API_KEY
+
+ from langchain.chains import LLMChain
+ from langchain_community.llms import Anyscale
+ from langchain_core.prompts import PromptTemplate
+ from langchain_community.chat_models import ChatAnyscale
+
+ # llm = Anyscale(model_name=ANYSCALE_MODEL_NAME)
+ llm = ChatAnyscale(model_name=ANYSCALE_MODEL_NAME, temperature=0)
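+ # The chat model can be exercised on its own (assuming the Anyscale
+ # credentials above are valid):
+ #
+ #   reply = llm.invoke("Xin chào!")  # returns an AIMessage
+ #   print(reply.content)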
+
+ #####################################################################
+
+ from langchain_openai.llms.azure import AzureOpenAI
+ llm_openai = AzureOpenAI(
+     deployment_name="gpt-35-turbo-instruct",
+     # deployment_name="gpt-35-turbo-16k",
+     api_key='c90c0e7fb1894a898c56123580a6ee3e',
+     api_version="2023-09-15-preview",
+     azure_endpoint="https://bkchatbot.openai.azure.com/",
+     temperature=0.0,
+     max_tokens=500
+ )
+
+ ##########################################################################
+
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.runnables import RunnablePassthrough
+
+ # Join the retrieved documents into one context string for the prompt
+ def format_docs(docs):
+     return "\n\n".join(doc.page_content for doc in docs)
+
+ # Build prompt (in Vietnamese: answer from the supplied regulations,
+ # synthesise a short and complete final answer, cite no sources, start with
+ # "Theo quy định của Trường ĐH Bách Khoa Tp.HCM, ..." ("According to the
+ # regulations of HCMC University of Technology, ..."), and reply with an
+ # apology when the regulations do not contain the answer)
+ from langchain.prompts import PromptTemplate
+ template = """
+ Trả lời câu hỏi dựa trên những quy định được cung cấp, tổng hợp thông tin và đưa ra câu trả lời ngắn gọn và đầy đủ cuối cùng.
+ Không cần ghi chú và trích dẫn nguồn thông tin đã tham khảo trong câu trả lời.
+ Câu trả lời nên bắt đầu bằng: "Theo quy định của Trường ĐH Bách Khoa Tp.HCM, ..."
+ Nếu trong văn bản quy định không có thông tin cho câu trả lời, vui lòng thông báo: "Xin lỗi, tôi không có thông tin cho câu hỏi này!"
+
+ Quy định: {context}
+
  Câu hỏi: {question}
+
  Câu trả lời:
+ """
+ QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
+
+ #############################################################################
+
+ from langchain_core.runnables import RunnableParallel
+
+ rag_chain_from_docs = (
+     RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
+     | QA_CHAIN_PROMPT
+     | llm
+     | StrOutputParser()
+ )
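+ # rag_chain_from_docs expects {"context": [Document, ...], "question": str}:
+ # it flattens the documents into the prompt, calls the Anyscale chat model,
+ # and parses the reply to a plain string, e.g. (hypothetical inputs):
+ #
+ #   rag_chain_from_docs.invoke({"context": docs, "question": "Học phí?"})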
+
+ ###############################################################################
+
+ from langchain.prompts import ChatPromptTemplate
+
+ # Multi Query: Different Perspectives (in Vietnamese: generate search queries
+ # that are semantically equivalent to the original question, and return a
+ # list containing the original question plus 2 alternatives)
+ template = """
+ ### Hãy tạo ra thêm các truy vấn tìm kiếm tương đương ngữ nghĩa với một câu hỏi ban đầu.
+ Kết quả hiển thị dạng list gồm câu hỏi ban đầu và 2 câu hỏi thay thế.
+
+ ### Câu hỏi ban đầu: {question}
+ ### Kết quả:
+
+ """
+ prompt_perspectives = ChatPromptTemplate.from_template(template)
+
+ from langchain_core.output_parsers import StrOutputParser
+ # from langchain_openai import ChatOpenAI
+
+ generate_queries = (
+     prompt_perspectives
+     | llm_openai
+     | StrOutputParser()
+     | (lambda x: x.split("\n"))
+ )
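+ # generate_queries assumes the completion model returns one query per line,
+ # so the final lambda yields a list like ["original", "alternative 1",
+ # "alternative 2"], e.g. (hypothetical question):
+ #
+ #   queries = generate_queries.invoke({"question": "Điều kiện tốt nghiệp là gì?"})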
+
+ #########################################################################################
+
+ from langchain.retrievers import BM25Retriever, EnsembleRetriever
+
+ # Initialise one BM25 (keyword) retriever per corpus and fuse it with the
+ # corresponding dense retriever
+ bm25_retriever1 = BM25Retriever.from_documents(data, k=25)
+ ensemble_retriever1 = EnsembleRetriever(retrievers=[bm25_retriever1, retriever1], weights=[0.5, 0.5])
+
+ bm25_retriever2 = BM25Retriever.from_documents(data2, k=25)
+ ensemble_retriever2 = EnsembleRetriever(retrievers=[bm25_retriever2, retriever2], weights=[0.5, 0.5])
+
+ bm25_retriever3 = BM25Retriever.from_documents(data3, k=25)
+ ensemble_retriever3 = EnsembleRetriever(retrievers=[bm25_retriever3, retriever3], weights=[0.5, 0.5])
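+ # EnsembleRetriever fuses the two ranked lists with weighted Reciprocal Rank
+ # Fusion; weights=[0.5, 0.5] gives the sparse and dense rankings equal say:
+ #
+ #   docs = ensemble_retriever1.get_relevant_documents("học bổng")  # hypothetical query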
+
+ #########################################################################################
+
+ # Wrap each ensemble retriever in the CustomRetriever imported from utility
+ custom_retriever1 = CustomRetriever(retriever=ensemble_retriever1)
+ custom_retriever2 = CustomRetriever(retriever=ensemble_retriever2)
+ custom_retriever3 = CustomRetriever(retriever=ensemble_retriever3)
+
+ # Fan the generated queries into retrieval
+ multiq_chain1 = generate_queries | custom_retriever1
+ multiq_chain2 = generate_queries | custom_retriever2
+ multiq_chain3 = generate_queries | custom_retriever3
+
+ rag_chain_with_source1 = RunnableParallel(
+     {"context": multiq_chain1, "question": RunnablePassthrough()}
+ ).assign(answer=rag_chain_from_docs)
+
+ rag_chain_with_source2 = RunnableParallel(
+     {"context": multiq_chain2, "question": RunnablePassthrough()}
+ ).assign(answer=rag_chain_from_docs)
+
+ rag_chain_with_source3 = RunnableParallel(
+     {"context": multiq_chain3, "question": RunnablePassthrough()}
+ ).assign(answer=rag_chain_from_docs)
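+ # Each rag_chain_with_source* returns a dict holding the retrieved documents
+ # ("context"), the original input ("question"), and the generated "answer".
+ # Note that CustomRetriever (from utility, not shown in this diff) receives
+ # the whole list of generated queries from generate_queries; presumably it
+ # pools and reranks the candidates, much like the CrossEncoder logic this
+ # commit removes from this file.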
+
+ ############################################################################################
+
+ from flashtext import KeywordProcessor
+ keyword_processor = KeywordProcessor()
+ # keyword_processor.add_keyword(<unclean name>, <standardised name>)
+ keyword_processor.add_keyword('thạc sĩ')           # master's
+ keyword_processor.add_keyword('học viên')          # graduate student
+ keyword_processor.add_keyword('nghiên cứu sinh')   # PhD candidate
+ keyword_processor.add_keyword('tiến sĩ')           # doctorate
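+ # extract_keywords scans a question for the registered role keywords
+ # (case-insensitively by default), e.g.:
+ #
+ #   keyword_processor.extract_keywords('Học viên cần tích lũy bao nhiêu tín chỉ?')
+ #   # -> ['học viên']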
+
+ ################################################################################
+
+ # Route by audience: index 0 = full corpus (default), 1 = master's, 2 = doctoral
+ rag_chain = [rag_chain_with_source1, rag_chain_with_source2, rag_chain_with_source3]
+
+ ###################################################################################

  def rag(question: str) -> str:
 
 
+     keywords_found = keyword_processor.extract_keywords(question)
+     if 'thạc sĩ' in keywords_found or 'học viên' in keywords_found:
+         response = rag_chain[1].invoke(question)
+     elif 'nghiên cứu sinh' in keywords_found or 'tiến sĩ' in keywords_found:
+         response = rag_chain[2].invoke(question)
+     else:
+         response = rag_chain[0].invoke(question)
+
+     return response['answer']
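+ # Example usage (hypothetical questions; answers depend on the indexed
+ # regulations):
+ #
+ #   rag("Điều kiện tốt nghiệp đại học là gì?")        # routed to rag_chain[0]
+ #   rag("Học viên thạc sĩ được học tối đa mấy năm?")  # routed to rag_chain[1]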
+
+ ###################################################################################
+
+
+ # # Run chain
+ # from langchain.chains import RetrievalQA
+
+ # qa_chain = RetrievalQA.from_chain_type(llm,
+ #                                        verbose=False,
+ #                                        # retriever=vectordb.as_retriever(),
+ #                                        retriever=custom_retriever,
+ #                                        return_source_documents=True,
+ #                                        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})
+
+ # def remove_special_characters(text):
+ #     text = text.replace('].', '')
+ #     text = text.replace('/.', '')
+ #     text = text.replace('/.-', '')
+ #     text = text.replace('-', '')
+ #     return text
+
+ # def rag(question: str) -> str:
+ #     # call the QA chain
+ #     response = qa_chain({"query": question})
+
+ #     return remove_special_characters(response["result"])
213