|
|
|
from typing import List |
|
from langchain.vectorstores import Chroma |
|
from langchain.retrievers.multi_vector import MultiVectorRetriever |
|
from langchain.storage import InMemoryStore |
|
import uuid |
|
from langchain.document_loaders import TextLoader, DirectoryLoader |
|
import os |
|
from sentence_transformers.cross_encoder import CrossEncoder |
|
import numpy as np |
|
from langchain.schema import BaseRetriever, Document |
|
from typing import List |
|
from langchain.callbacks.manager import CallbackManagerForRetrieverRun |
|
from langchain.vectorstores import VectorStore |
|
from langchain.load import dumps, loads |
|
from typing import Any |
|
|
|
|
|
def load_data(data_path):
    """Load all text documents found in the subfolders of *data_path*.

    Every entry directly under *data_path* is treated as a folder of text
    files; each folder is read with a ``DirectoryLoader`` using ``TextLoader``
    for the individual files.

    Args:
        data_path: Root directory whose subfolders contain the text files.

    Returns:
        A flat list of the loaded ``Document`` objects from all subfolders.
    """
    loaded_documents = []
    # Single pass: load each subfolder as we discover it (the original built
    # an intermediate list of loaders and then iterated it a second time).
    for folder in os.listdir(data_path):
        dir_loader = DirectoryLoader(os.path.join(data_path, folder), loader_cls=TextLoader)
        loaded_documents.extend(dir_loader.load())

    return loaded_documents
|
|
|
def process_data(
    data: List[Document],
    child_text_splitter,
    embedding,
    vectorstore_name: str,
) -> tuple:
    """Index *data* into a Chroma vectorstore behind a MultiVectorRetriever.

    Each parent document is split into child chunks; the chunks are embedded
    and stored in Chroma while the full parent documents are kept in an
    in-memory docstore keyed by a generated UUID, so retrieval over small
    chunks can still return the complete parent documents.

    Args:
        data: Parent documents to index. (The original annotation said
            ``List[str]``, but ``split_documents`` requires Document objects.)
        child_text_splitter: Text splitter used to produce the child chunks.
        embedding: Embedding function handed to the Chroma collection.
        vectorstore_name: Name of the Chroma collection.

    Returns:
        A ``(vectorstore, retriever)`` tuple. (The original return
        annotation of ``MultiVectorRetriever`` was wrong — two values
        are returned.)
    """
    vectorstore = Chroma(
        collection_name=vectorstore_name,
        embedding_function=embedding,
    )

    # Parent documents live here; the vectorstore only holds child chunks.
    store = InMemoryStore()
    id_key = "doc_id"

    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
        search_kwargs={"k": 25},
    )

    # One UUID per parent document; every child chunk carries it in metadata
    # so a matched chunk can be mapped back to its parent.
    doc_ids = [str(uuid.uuid4()) for _ in data]
    sub_docs = []
    for doc_id, doc in zip(doc_ids, data):
        chunks = child_text_splitter.split_documents([doc])
        for chunk in chunks:
            chunk.metadata[id_key] = doc_id
        sub_docs.extend(chunks)

    retriever.vectorstore.add_documents(sub_docs)
    retriever.docstore.mset(list(zip(doc_ids, data)))

    return vectorstore, retriever
|
|
|
# Process-wide cache for the cross-encoder re-ranking model.  The original
# code re-loaded the model from disk on EVERY retrieval call, which is very
# expensive; loading lazily once and reusing it preserves behavior.
_CROSS_ENCODER = None


def _get_cross_encoder():
    """Return the cached CrossEncoder re-ranker, loading it on first use."""
    global _CROSS_ENCODER
    if _CROSS_ENCODER is None:
        _CROSS_ENCODER = CrossEncoder('nnngoc/ms-marco-MiniLM-L-6-v2-32-6M-1')
    return _CROSS_ENCODER


class CustomRetriever(BaseRetriever):
    """Retriever that fans out over multiple query variants, fuses the
    per-query rankings with Reciprocal Rank Fusion, re-ranks the fused list
    with a cross-encoder against the first query, and returns the top 10
    documents."""

    # Underlying retriever used to answer each individual query variant.
    retriever: Any

    def reciprocal_rank_fusion(self, results: list[list], k=60):
        """Fuse multiple ranked document lists with Reciprocal Rank Fusion.

        Args:
            results: One ranked list of documents per query.
            k: RRF damping constant; larger values flatten rank differences.

        Returns:
            Unique documents sorted by descending fused score.
        """
        fused_scores = {}
        for docs in results:
            for rank, doc in enumerate(docs):
                # Serialize each document so identical documents coming from
                # different result lists collapse onto a single key.
                doc_str = dumps(doc)
                if doc_str not in fused_scores:
                    fused_scores[doc_str] = 0
                fused_scores[doc_str] += 1 / (rank + k)

        # Highest fused score first; deserialize back into Document objects.
        return [
            loads(doc_str)
            for doc_str, _score in sorted(
                fused_scores.items(), key=lambda item: item[1], reverse=True
            )
        ]

    def _get_relevant_documents(
        self, queries: list, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve, fuse, cross-encoder re-rank, and return the top 10 docs.

        Args:
            queries: Query variants; ``queries[0]`` is treated as the primary
                query for the cross-encoder re-ranking step.
            run_manager: Callback manager propagated to the inner retriever.

        Returns:
            At most 10 documents, ordered by cross-encoder similarity to the
            first query.
        """
        # One ranked result list per query variant.
        per_query_results = [
            self.retriever.get_relevant_documents(
                query, callbacks=run_manager.get_child()
            )
            for query in queries
        ]

        unique_documents = self.reciprocal_rank_fusion(per_query_results)
        docs_content = [doc.page_content for doc in unique_documents]

        # Score every fused document against the primary query only.
        model = _get_cross_encoder()
        sentence_combinations = [[queries[0], content] for content in docs_content]
        similarity_scores = model.predict(sentence_combinations)

        # Descending similarity; keep the ten best documents.
        order = np.argsort(similarity_scores)[::-1]
        return [unique_documents[idx] for idx in order[:10]]