# GitChat / rag_101 / retriever.py
# "Upload retriever.py" by kartavya23 (commit fdb4a30, verified)
import os
# Point the Hugging Face and torch caches at a local "weights" directory.
# This runs before the third-party imports below — presumably so those
# libraries read the cache location when they load; keep this ordering.
os.environ.update({"HF_HOME": "weights", "TORCH_HOME": "weights"})
from typing import List, Optional, Union
from langchain.callbacks import FileCallbackHandler
from langchain.retrievers import ContextualCompressionRetriever, ParentDocumentRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.storage import InMemoryStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS, Chroma
from langchain_core.documents import Document
from loguru import logger
from rich import print
from sentence_transformers import CrossEncoder
from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs
# Log file shared by the loguru sink and the LangChain callback handler below.
logfile: str = "log/output.log"
# NOTE(review): colorize=True on a file sink may write ANSI color codes into
# the log file — confirm that is intended. enqueue=True routes records through
# loguru's queue before they reach the sink (see loguru `logger.add` docs).
logger.add(logfile, colorize=True, enqueue=True)
# LangChain callback handler writing to the same file. It is not referenced
# anywhere else in the visible code of this file — presumably callers pass it
# as a callback; verify before removing.
# NOTE(review): keep this after logger.add — loguru appears to create the
# log/ directory when the sink is added, which this handler may rely on.
handler: FileCallbackHandler = FileCallbackHandler(logfile)
# Chroma persistence directory read by create_parent_retriever. None means no
# directory is passed, presumably keeping the index in memory only.
persist_directory: Optional[str] = None
class RAGException(Exception):
    """Raised when the retrieval pipeline cannot produce usable context.

    Accepts the same constructor arguments as the built-in ``Exception``.
    """
def rerank_docs(reranker_model, query, retrieved_docs):
    """Score each retrieved document against ``query`` and sort by relevance.

    Args:
        reranker_model: Object exposing ``predict(pairs)`` that returns one
            score per ``(query, text)`` pair, e.g. a sentence-transformers
            ``CrossEncoder``.
        query: The user question.
        retrieved_docs: Iterable of documents with a ``page_content``
            attribute. Any iterable is accepted, including a generator.

    Returns:
        A list of ``(document, score)`` tuples, highest score first. Documents
        with equal scores keep their retrieval order (``sorted`` is stable,
        also with ``reverse=True``). Empty input yields ``[]`` and the model
        is not called.
    """
    # Materialize once: the documents are walked twice (to build the pairs and
    # to zip with the scores), which would silently drop everything for a
    # one-shot iterator such as a generator.
    docs = list(retrieved_docs)
    if not docs:
        # Avoid calling predict() with an empty batch, which some CrossEncoder
        # versions do not handle.
        return []
    pairs = [(query, doc.page_content) for doc in docs]
    scores = reranker_model.predict(pairs)
    return sorted(zip(docs, scores), key=lambda item: item[1], reverse=True)
def load_pdf(
    files: Union[str, os.PathLike, List[str]] = "2401.08406v3.pdf"
) -> List[Document]:
    """Load one or more files into LangChain ``Document`` objects.

    Each file is read with ``UnstructuredFileLoader`` (so it is not limited to
    PDFs). The ``clean_extra_whitespace`` and ``group_broken_paragraphs``
    cleaners are passed to the loader as post-processors.

    Args:
        files: A single path (``str`` or ``os.PathLike``) or a list of paths.

    Returns:
        The loaded documents of all files, in input order. An empty list of
        paths yields an empty list.
    """
    # Treat a single path as a one-element batch so both cases share the same
    # loading code. Checking os.PathLike too keeps a pathlib.Path from being
    # iterated as if it were a list of files.
    if isinstance(files, (str, os.PathLike)):
        files = [files]
    docs: List[Document] = []
    for file in files:
        loader = UnstructuredFileLoader(
            file,
            post_processors=[clean_extra_whitespace, group_broken_paragraphs],
        )
        docs.extend(loader.load())
    return docs
def create_parent_retriever(
    docs: List[Document],
    # Annotate with the class itself: an instance expression here would be
    # evaluated when the module is imported, constructing an embedding model.
    embeddings_model: HuggingFaceEmbeddings,
):
    """Build a ParentDocumentRetriever over ``docs`` and index them.

    ``docs`` are split into parent chunks (target 2000 characters) kept in an
    in-memory docstore. Each parent is split again into child chunks (target
    1000 characters, 300 overlap), which are embedded with
    ``embeddings_model`` into a Chroma collection. Per LangChain's
    ParentDocumentRetriever, queries are matched against the child chunks and
    the corresponding parent chunks are returned.

    Args:
        docs: Documents to index (e.g. the output of ``load_pdf``).
        embeddings_model: Embedding model used for the child chunks.

    Returns:
        The retriever, with ``docs`` already added.
    """
    # Both splitters break only on runs of blank lines (no "" fallback), so a
    # single paragraph longer than chunk_size may be kept as one oversized chunk.
    parent_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n\n", "\n\n"],
        chunk_size=2000,
        length_function=len,
        is_separator_regex=False,
    )
    # This text splitter is used to create the child documents.
    child_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n\n", "\n\n"],
        chunk_size=1000,
        chunk_overlap=300,
        length_function=len,
        is_separator_regex=False,
    )
    # The vectorstore to use to index the child chunks.
    # NOTE(review): with persist_directory None, repeated calls in one process
    # may reuse the same in-memory "split_documents" collection — confirm.
    vectorstore = Chroma(
        collection_name="split_documents",
        embedding_function=embeddings_model,
        persist_directory=persist_directory,
    )
    # The storage layer for the parent documents.
    store = InMemoryStore()
    retriever = ParentDocumentRetriever(
        vectorstore=vectorstore,
        docstore=store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        # ParentDocumentRetriever has no `k` field; the number of child chunks
        # fetched from the vector store must be given through search_kwargs.
        search_kwargs={"k": 10},
    )
    retriever.add_documents(docs)
    return retriever
def retrieve_context(query, retriever, reranker_model):
    """Fetch candidate documents for ``query`` and rerank them.

    Args:
        query: The user question.
        retriever: LangChain retriever exposing ``get_relevant_documents``.
        reranker_model: Model passed through to ``rerank_docs``.

    Returns:
        The ``(document, score)`` pairs produced by ``rerank_docs``.

    Raises:
        RAGException: If the retriever returns no documents.
    """
    candidates = retriever.get_relevant_documents(query)
    if len(candidates) == 0:
        raise RAGException(
            f"Couldn't retrieve any relevant document with the query `{query}`. Try modifying your question!"
        )
    return rerank_docs(reranker_model, query, candidates)
def load_embedding_model(
    model_name: str = "BAAI/bge-large-en-v1.5", device: str = "cpu"
) -> HuggingFaceEmbeddings:
    """Create the HuggingFace sentence-embedding model used for indexing.

    Args:
        model_name: Hugging Face model id or local path.
        device: Torch device string forwarded via ``model_kwargs``.

    Returns:
        A ``HuggingFaceEmbeddings`` instance that normalizes its embeddings.
    """
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=dict(device=device),
        # Unit-length vectors make inner-product scores equal cosine similarity.
        encode_kwargs=dict(normalize_embeddings=True),
    )
def load_reranker_model(
    reranker_model_name: str = "BAAI/bge-reranker-large",
    device: str = "cpu",
    max_length: int = 512,
) -> CrossEncoder:
    """Load the sentence-transformers CrossEncoder used to rerank documents.

    Args:
        reranker_model_name: Hugging Face model id or local path.
        device: Torch device string, e.g. "cpu" or "cuda".
        max_length: Maximum tokenized length of a (query, document) pair;
            longer pairs are truncated. Defaults to 512, the input limit of
            the default BAAI/bge-reranker-large (an XLM-RoBERTa encoder).
            A larger value lets long pairs overflow the model's position
            embeddings, so raise it only for rerankers trained on longer
            contexts.

    Returns:
        The loaded ``CrossEncoder``.
    """
    return CrossEncoder(
        model_name=reranker_model_name, max_length=max_length, device=device
    )
def main(
    file: str = "2401.08406v3.pdf",
    query: Optional[str] = None,
    llm_name="mistral",
):
    """Index ``file`` and print the best-matching context for ``query``.

    Args:
        file: Path of the document to index.
        query: Question to retrieve context for. Required; the ``None``
            default is kept only so the CLI signature stays unchanged.
        llm_name: Accepted for CLI compatibility; not used in this function.

    Raises:
        ValueError: If ``query`` is missing or blank.
        RAGException: If no document matches the query.
    """
    # Validate first: loading the document and both models below is slow, and
    # a missing query would otherwise only fail deep inside retrieval.
    if query is None or not query.strip():
        raise ValueError("A non-empty `query` is required.")
    docs = load_pdf(files=file)
    embedding_model = load_embedding_model()
    retriever = create_parent_retriever(docs, embedding_model)
    reranker_model = load_reranker_model()
    # Top-ranked (document, score) pair; retrieve_context raises rather than
    # returning an empty list, so indexing [0] is safe.
    context = retrieve_context(
        query, retriever=retriever, reranker_model=reranker_model
    )[0]
    print("context:\n", context, "\n", "=" * 50, "\n")
if __name__ == "__main__":
    # jsonargparse is needed only for command-line use, so import it here.
    import jsonargparse

    # Build a CLI from main's signature (file, query, llm_name) and run it.
    jsonargparse.CLI(main)