Spaces:
Sleeping
Sleeping
kartavya23
commited on
Upload 4 files
Browse files- rag_101/__pycache__/retriever.cpython-39.pyc +0 -0
- rag_101/client.py +61 -0
- rag_101/rag.py +52 -0
- rag_101/retriever.py +160 -0
rag_101/__pycache__/retriever.cpython-39.pyc
ADDED
Binary file (4.88 kB). View file
|
|
rag_101/client.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.callbacks import FileCallbackHandler
|
2 |
+
from langchain_community.chat_models import ChatOllama
|
3 |
+
from langchain_core.output_parsers import StrOutputParser
|
4 |
+
from langchain_core.prompts import ChatPromptTemplate
|
5 |
+
from loguru import logger
|
6 |
+
|
7 |
+
from rag_101.retriever import (
|
8 |
+
RAGException,
|
9 |
+
create_parent_retriever,
|
10 |
+
load_embedding_model,
|
11 |
+
load_pdf,
|
12 |
+
load_reranker_model,
|
13 |
+
retrieve_context,
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
class RAGClient:
|
18 |
+
embedding_model = load_embedding_model()
|
19 |
+
reranker_model = load_reranker_model()
|
20 |
+
|
21 |
+
def __init__(self, files, model="mistral"):
|
22 |
+
docs = load_pdf(files=files)
|
23 |
+
self.retriever = create_parent_retriever(docs, self.embedding_model)
|
24 |
+
|
25 |
+
llm = ChatOllama(model=model)
|
26 |
+
prompt_template = ChatPromptTemplate.from_template(
|
27 |
+
(
|
28 |
+
"Please answer the following question based on the provided `context` that follows the question.\n"
|
29 |
+
"Think step by step before coming to answer. If you do not know the answer then just say 'I do not know'\n"
|
30 |
+
"question: {question}\n"
|
31 |
+
"context: ```{context}```\n"
|
32 |
+
)
|
33 |
+
)
|
34 |
+
self.chain = prompt_template | llm | StrOutputParser()
|
35 |
+
|
36 |
+
def stream(self, query: str) -> dict:
|
37 |
+
try:
|
38 |
+
context, similarity_score = self.retrieve_context(query)[0]
|
39 |
+
context = context.page_content
|
40 |
+
if similarity_score < 0.005:
|
41 |
+
context = "This context is not confident. " + context
|
42 |
+
except RAGException as e:
|
43 |
+
context, similarity_score = e.args[0], 0
|
44 |
+
logger.info(context)
|
45 |
+
for r in self.chain.stream({"context": context, "question": query}):
|
46 |
+
yield r
|
47 |
+
|
48 |
+
def retrieve_context(self, query: str):
|
49 |
+
return retrieve_context(
|
50 |
+
query, retriever=self.retriever, reranker_model=self.reranker_model
|
51 |
+
)
|
52 |
+
|
53 |
+
def generate(self, query: str) -> dict:
|
54 |
+
contexts = self.retrieve_context(query)
|
55 |
+
|
56 |
+
return {
|
57 |
+
"contexts": contexts,
|
58 |
+
"response": self.chain.invoke(
|
59 |
+
{"context": contexts[0][0].page_content, "question": query}
|
60 |
+
),
|
61 |
+
}
|
rag_101/rag.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from typing import List, Optional, Union
|
3 |
+
|
4 |
+
from langchain_community.chat_models import ChatOllama
|
5 |
+
from langchain_core.output_parsers import StrOutputParser
|
6 |
+
from langchain_core.prompts import ChatPromptTemplate
|
7 |
+
from retriever import (
|
8 |
+
create_parent_retriever,
|
9 |
+
load_embedding_model,
|
10 |
+
load_pdf,
|
11 |
+
load_reranker_model,
|
12 |
+
retrieve_context,
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
def main(
|
17 |
+
file: str = "2401.08406v3.pdf",
|
18 |
+
llm_name="mistral",
|
19 |
+
):
|
20 |
+
docs = load_pdf(files=file)
|
21 |
+
|
22 |
+
embedding_model = load_embedding_model()
|
23 |
+
retriever = create_parent_retriever(docs, embedding_model)
|
24 |
+
reranker_model = load_reranker_model()
|
25 |
+
|
26 |
+
llm = ChatOllama(model=llm_name)
|
27 |
+
prompt_template = ChatPromptTemplate.from_template(
|
28 |
+
(
|
29 |
+
"Please answer the following question based on the provided `context` that follows the question.\n"
|
30 |
+
"If you do not know the answer then just say 'I do not know'\n"
|
31 |
+
"question: {question}\n"
|
32 |
+
"context: ```{context}```\n"
|
33 |
+
)
|
34 |
+
)
|
35 |
+
chain = prompt_template | llm | StrOutputParser()
|
36 |
+
|
37 |
+
while True:
|
38 |
+
query = input("Ask question: ")
|
39 |
+
context = retrieve_context(
|
40 |
+
query, retriever=retriever, reranker_model=reranker_model
|
41 |
+
)[0]
|
42 |
+
print("LLM Response: ", end="")
|
43 |
+
for e in chain.stream({"context": context[0].page_content, "question": query}):
|
44 |
+
print(e, end="")
|
45 |
+
print()
|
46 |
+
time.sleep(0.1)
|
47 |
+
|
48 |
+
|
49 |
+
if __name__ == "__main__":
|
50 |
+
from jsonargparse import CLI
|
51 |
+
|
52 |
+
CLI(main)
|
rag_101/retriever.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["HF_HOME"] = "weights"
|
4 |
+
os.environ["TORCH_HOME"] = "weights"
|
5 |
+
|
6 |
+
from typing import List, Optional, Union
|
7 |
+
|
8 |
+
from langchain.callbacks import FileCallbackHandler
|
9 |
+
from langchain.retrievers import ContextualCompressionRetriever, ParentDocumentRetriever
|
10 |
+
from langchain.retrievers.document_compressors import EmbeddingsFilter
|
11 |
+
from langchain.storage import InMemoryStore
|
12 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
13 |
+
from langchain_community.document_loaders import UnstructuredFileLoader
|
14 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
15 |
+
from langchain_community.vectorstores import FAISS, Chroma
|
16 |
+
from langchain_core.documents import Document
|
17 |
+
from loguru import logger
|
18 |
+
from rich import print
|
19 |
+
from sentence_transformers import CrossEncoder
|
20 |
+
from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs
|
21 |
+
|
22 |
+
logfile = "log/output.log"
|
23 |
+
logger.add(logfile, colorize=True, enqueue=True)
|
24 |
+
handler = FileCallbackHandler(logfile)
|
25 |
+
|
26 |
+
|
27 |
+
persist_directory = None
|
28 |
+
|
29 |
+
|
30 |
+
class RAGException(Exception):
|
31 |
+
def __init__(self, *args, **kwargs):
|
32 |
+
super().__init__(*args, **kwargs)
|
33 |
+
|
34 |
+
|
35 |
+
def rerank_docs(reranker_model, query, retrieved_docs):
|
36 |
+
query_and_docs = [(query, r.page_content) for r in retrieved_docs]
|
37 |
+
scores = reranker_model.predict(query_and_docs)
|
38 |
+
return sorted(list(zip(retrieved_docs, scores)), key=lambda x: x[1], reverse=True)
|
39 |
+
|
40 |
+
|
41 |
+
def load_pdf(
|
42 |
+
files: Union[str, List[str]] = "2401.08406v3.pdf"
|
43 |
+
) -> List[Document]:
|
44 |
+
if isinstance(files, str):
|
45 |
+
loader = UnstructuredFileLoader(
|
46 |
+
files,
|
47 |
+
post_processors=[clean_extra_whitespace, group_broken_paragraphs],
|
48 |
+
)
|
49 |
+
return loader.load()
|
50 |
+
|
51 |
+
loaders = [
|
52 |
+
UnstructuredFileLoader(
|
53 |
+
file,
|
54 |
+
post_processors=[clean_extra_whitespace, group_broken_paragraphs],
|
55 |
+
)
|
56 |
+
for file in files
|
57 |
+
]
|
58 |
+
docs = []
|
59 |
+
for loader in loaders:
|
60 |
+
docs.extend(
|
61 |
+
loader.load(),
|
62 |
+
)
|
63 |
+
return docs
|
64 |
+
|
65 |
+
|
66 |
+
def create_parent_retriever(
|
67 |
+
docs: List[Document], embeddings_model: HuggingFaceEmbeddings()
|
68 |
+
):
|
69 |
+
parent_splitter = RecursiveCharacterTextSplitter(
|
70 |
+
separators=["\n\n\n", "\n\n"],
|
71 |
+
chunk_size=2000,
|
72 |
+
length_function=len,
|
73 |
+
is_separator_regex=False,
|
74 |
+
)
|
75 |
+
|
76 |
+
# This text splitter is used to create the child documents
|
77 |
+
child_splitter = RecursiveCharacterTextSplitter(
|
78 |
+
separators=["\n\n\n", "\n\n"],
|
79 |
+
chunk_size=1000,
|
80 |
+
chunk_overlap=300,
|
81 |
+
length_function=len,
|
82 |
+
is_separator_regex=False,
|
83 |
+
)
|
84 |
+
# The vectorstore to use to index the child chunks
|
85 |
+
vectorstore = Chroma(
|
86 |
+
collection_name="split_documents",
|
87 |
+
embedding_function=embeddings_model,
|
88 |
+
persist_directory=persist_directory,
|
89 |
+
)
|
90 |
+
# The storage layer for the parent documents
|
91 |
+
store = InMemoryStore()
|
92 |
+
retriever = ParentDocumentRetriever(
|
93 |
+
vectorstore=vectorstore,
|
94 |
+
docstore=store,
|
95 |
+
child_splitter=child_splitter,
|
96 |
+
parent_splitter=parent_splitter,
|
97 |
+
k=10,
|
98 |
+
)
|
99 |
+
retriever.add_documents(docs)
|
100 |
+
return retriever
|
101 |
+
|
102 |
+
|
103 |
+
def retrieve_context(query, retriever, reranker_model):
|
104 |
+
retrieved_docs = retriever.get_relevant_documents(query)
|
105 |
+
|
106 |
+
if len(retrieved_docs) == 0:
|
107 |
+
raise RAGException(
|
108 |
+
f"Couldn't retrieve any relevant document with the query `{query}`. Try modifying your question!"
|
109 |
+
)
|
110 |
+
reranked_docs = rerank_docs(
|
111 |
+
query=query, retrieved_docs=retrieved_docs, reranker_model=reranker_model
|
112 |
+
)
|
113 |
+
return reranked_docs
|
114 |
+
|
115 |
+
|
116 |
+
def load_embedding_model(
|
117 |
+
model_name: str = "BAAI/bge-large-en-v1.5", device: str = "cuda"
|
118 |
+
) -> HuggingFaceEmbeddings:
|
119 |
+
model_kwargs = {"device": device}
|
120 |
+
encode_kwargs = {
|
121 |
+
"normalize_embeddings": True
|
122 |
+
} # set True to compute cosine similarity
|
123 |
+
embedding_model = HuggingFaceEmbeddings(
|
124 |
+
model_name=model_name,
|
125 |
+
model_kwargs=model_kwargs,
|
126 |
+
encode_kwargs=encode_kwargs,
|
127 |
+
)
|
128 |
+
return embedding_model
|
129 |
+
|
130 |
+
|
131 |
+
def load_reranker_model(
|
132 |
+
reranker_model_name: str = "BAAI/bge-reranker-large", device: str = "cuda"
|
133 |
+
) -> CrossEncoder:
|
134 |
+
reranker_model = CrossEncoder(
|
135 |
+
model_name=reranker_model_name, max_length=1024, device=device
|
136 |
+
)
|
137 |
+
return reranker_model
|
138 |
+
|
139 |
+
|
140 |
+
def main(
|
141 |
+
file: str = "2401.08406v3.pdf",
|
142 |
+
query: Optional[str] = None,
|
143 |
+
llm_name="mistral",
|
144 |
+
):
|
145 |
+
docs = load_pdf(files=file)
|
146 |
+
|
147 |
+
embedding_model = load_embedding_model()
|
148 |
+
retriever = create_parent_retriever(docs, embedding_model)
|
149 |
+
reranker_model = load_reranker_model()
|
150 |
+
|
151 |
+
context = retrieve_context(
|
152 |
+
query, retriever=retriever, reranker_model=reranker_model
|
153 |
+
)[0]
|
154 |
+
print("context:\n", context, "\n", "=" * 50, "\n")
|
155 |
+
|
156 |
+
|
157 |
+
if __name__ == "__main__":
|
158 |
+
from jsonargparse import CLI
|
159 |
+
|
160 |
+
CLI(main)
|