"""
This module provides a custom implementation of a document retriever designed
for multi-stage retrieval. It uses an ensemble of BM25 and Chroma embedding
retrievers to find documents relevant to a given query, and fuses their
rankings with weighted Reciprocal Rank Fusion, building on LangChain's
EnsembleRetriever.

Classes:
--------
- MyEnsembleRetriever: Custom ensemble retriever for BM25 and Chroma embeddings.
- MyRetriever: Handles multi-stage retrieval.
"""

import re
import ast
import copy
import math
import logging
from typing import Dict, List, Optional

from langchain.chains import LLMChain
from langchain.schema import BaseRetriever, Document
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)

from toolkit.utils import Config, clean_text, DocIndexer, IndexerOperator
from toolkit.prompts import PromptTemplates

prompt_templates = PromptTemplates()
configs = Config("configparser.ini")
logger = logging.getLogger(__name__)

# Matches a bracketed list of non-negative integers, e.g. "[0, 2, 5]",
# inside the free-form text returned by the document-selection LLM chain.
_DOC_IDS_PATTERN = re.compile(r"\[\s*\d+\s*(?:,\s*\d+\s*)*\]")


def _span(lower: int, upper: int) -> set:
    """Return the inclusive integer interval ``[lower, upper]`` as a set."""
    return set(range(lower, upper + 1))


def _restore_original_content(doc: Document) -> Document:
    """
    Return a copy of ``doc`` whose page_content is the original (uncleaned) text.

    The original text is read from ``doc.metadata["page_content"]``. When that
    key is missing or None, the document's current page_content is kept.
    A new Document with a shallow copy of the metadata is returned, so the
    caller's document objects (which may be shared with an index) are never
    mutated.
    """
    original = doc.metadata.get("page_content")
    return Document(
        page_content=doc.page_content if original is None else original,
        metadata=copy.copy(doc.metadata),
    )


def _child_callbacks(run_manager):
    """Return ``run_manager.get_child()``, or None when no run manager is given."""
    return run_manager.get_child() if run_manager is not None else None


class MyEnsembleRetriever(EnsembleRetriever):
    """
    Custom ensemble retriever for BM25 and Chroma embeddings.

    Retrievers are keyed by name. The retriever stored under ``"bm25"`` is
    queried with ``clean_text(query)``; every other retriever receives the raw
    query.
    """

    retrievers: Dict[str, BaseRetriever]

    def rank_fusion(
        self, query: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Retrieve the results of every retriever and fuse them with
        weighted_reciprocal_rank.

        Args:
            query: The query to search for.
            run_manager: Callback manager of the current retriever run.

        Returns:
            A list of reranked documents.
        """
        retriever_docs = []
        for key, retriever in self.retrievers.items():
            # BM25 gets the cleaned query (presumably matching how its corpus
            # was cleaned); the embedding retriever gets the raw query.
            retriever_query = clean_text(query) if key == "bm25" else query
            retriever_docs.append(
                retriever.get_relevant_documents(
                    retriever_query,
                    callbacks=run_manager.get_child(tag=f"retriever_{key}"),
                )
            )
        return self.weighted_reciprocal_rank(retriever_docs)

    async def arank_fusion(
        self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Asynchronously retrieve the results of every retriever and fuse them
        with weighted_reciprocal_rank.

        Args:
            query: The query to search for.
            run_manager: Async callback manager of the current retriever run.

        Returns:
            A list of reranked documents.
        """
        retriever_docs = []
        for key, retriever in self.retrievers.items():
            if key == "bm25":
                # The BM25 retriever is called synchronously, even on this path.
                res = retriever.get_relevant_documents(
                    clean_text(query),
                    callbacks=run_manager.get_child(tag=f"retriever_{key}"),
                )
            else:
                res = await retriever.aget_relevant_documents(
                    query, callbacks=run_manager.get_child(tag=f"retriever_{key}")
                )
            retriever_docs.append(res)
        return self.weighted_reciprocal_rank(retriever_docs)

    def weighted_reciprocal_rank(
        self, doc_lists: List[List[Document]]
    ) -> List[Document]:
        """
        Perform weighted Reciprocal Rank Fusion on multiple rank lists.

        You can find more details about RRF here:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

        Each document is scored ``sum(weight / (rank + self.c))`` over the lists
        it appears in. Documents are identified by their original page_content,
        which is restored from ``metadata["page_content"]``.

        Args:
            doc_lists: A list of rank lists, where each rank list contains
                unique items.

        Returns:
            New Document objects (the inputs are left untouched), sorted by
            weighted RRF score in descending order. Ties keep first-seen order.

        Raises:
            ValueError: If the number of rank lists differs from the number
                of weights.
        """
        if len(doc_lists) != len(self.weights):
            raise ValueError(
                "Number of rank lists must be equal to the number of weights."
            )

        # Work on copies: the incoming documents may be the very objects held
        # by an index (e.g. the BM25 corpus), so mutating them in place would
        # leak the uncleaned text back into later retrievals.
        restored_lists = [
            [_restore_original_content(doc) for doc in doc_list]
            for doc_list in doc_lists
        ]

        rrf_score_dic: Dict[str, float] = {}
        page_content_to_doc_map: Dict[str, Document] = {}
        for doc_list, weight in zip(restored_lists, self.weights):
            for rank, doc in enumerate(doc_list, start=1):
                key = doc.page_content
                rrf_score_dic[key] = rrf_score_dic.get(key, 0.0) + weight * (
                    1 / (rank + self.c)
                )
                # The last occurrence across the lists represents the content.
                page_content_to_doc_map[key] = doc

        # sorted() is stable, so equal scores keep first-seen order.
        sorted_documents = sorted(
            rrf_score_dic, key=rrf_score_dic.__getitem__, reverse=True
        )
        return [page_content_to_doc_map[content] for content in sorted_documents]


class MyRetriever:
    """
    Retriever class to handle multi-stage retrieval.

    Stage 1 searches the whole small-chunk corpus. An LLM then selects the
    relevant hits. Stage 2 searches the small chunks of each selected file.
    Stage 3 searches medium chunks restricted to the windows found around the
    stage 1/2 hits.
    """

    def __init__(
        self,
        llm,
        embedding_chunks_small: List[Document],
        embedding_chunks_medium: List[Document],
        docs_chunks_small: DocIndexer,
        docs_chunks_medium: DocIndexer,
        first_retrieval_k: int,
        second_retrieval_k: int,
        num_windows: int,
        retriever_weights: List[float],
    ):
        """
        Initialize the MyRetriever class.

        Args:
            llm: Language model used to select relevant documents.
            embedding_chunks_small: Vector store of small chunks. It must expose
                ``as_retriever`` (see get_retriever), despite the
                ``List[Document]`` annotation.
            embedding_chunks_medium: Vector store of medium chunks. It must
                expose ``as_retriever``.
            docs_chunks_small: Small chunks, wrapped here in a DocIndexer.
            docs_chunks_medium: Medium chunks, wrapped here in a DocIndexer.
            first_retrieval_k (int): Number of top documents to retrieve in the
                first retrieval.
            second_retrieval_k (int): Number of top documents to retrieve in the
                second retrieval.
            num_windows (int): Number of overlapping windows to consider.
            retriever_weights (List[float]): Weights for ensemble retrieval.
        """
        self.llm = llm
        self.embedding_chunks_small = embedding_chunks_small
        self.embedding_chunks_medium = embedding_chunks_medium
        self.docs_index_small = DocIndexer(docs_chunks_small)
        self.docs_index_medium = DocIndexer(docs_chunks_medium)
        self.first_retrieval_k = first_retrieval_k
        self.second_retrieval_k = second_retrieval_k
        self.num_windows = num_windows
        self.retriever_weights = retriever_weights

    def get_retriever(
        self,
        docs_chunks,
        emb_chunks,
        emb_filter=None,
        k=2,
        weights=(0.5, 0.5),
    ):
        """
        Initialize and return a retriever instance with specified parameters.

        Args:
            docs_chunks: The document chunks for the BM25 retriever.
            emb_chunks: Vector store for the embedding retriever. It must expose
                ``as_retriever``.
            emb_filter: Metadata filter for the embedding retriever.
            k (int): The number of top documents each retriever returns.
            weights (Sequence[float]): Weights for ensemble retrieval, in the
                order (bm25, chroma).

        Returns:
            MyEnsembleRetriever: An instance of MyEnsembleRetriever.
        """
        bm25_retriever = BM25Retriever.from_documents(docs_chunks)
        bm25_retriever.k = k
        # search_type is a parameter of as_retriever itself. Placed inside
        # search_kwargs, it is only passed through to the search call as an
        # extra keyword and does not switch the retriever to MMR.
        emb_retriever = emb_chunks.as_retriever(
            search_type="mmr",
            search_kwargs={"filter": emb_filter, "k": k},
        )
        return MyEnsembleRetriever(
            retrievers={"bm25": bm25_retriever, "chroma": emb_retriever},
            weights=list(weights),
        )

    def find_overlaps(self, doc: List[Document]) -> List[int]:
        """
        Group overlapping large-chunk intervals and return their centroid indices.

        Each document contributes the inclusive interval
        ``(metadata["large_chunks_idx_lower_bound"],
        metadata["large_chunks_idx_upper_bound"])``. Intervals are grouped
        greedily: the first remaining interval starts a group, and every later
        interval that overlaps a group member joins it. The centroid(s) of a
        group are:

        * the rounded-up midpoint of an interval that was a subset of the member
          it joined (the last such interval), if any;
        * otherwise, for groups of more than two intervals, every index from the
          largest index shared by the first two intervals to the smallest index
          shared by the last two;
        * otherwise, the rounded mean of the members' rounded-up midpoints.

        Args:
            doc (List[Document]): Documents carrying the interval metadata.

        Returns:
            list: Centroid large-chunk indices, in group order.
        """
        remaining_intervals = [
            (
                item.metadata["large_chunks_idx_lower_bound"],
                item.metadata["large_chunks_idx_upper_bound"],
            )
            for item in doc
        ]
        centroids: List[int] = []

        while remaining_intervals:
            curr_group = [remaining_intervals.pop(0)]
            subset_interval = None
            for start, end in list(remaining_intervals):
                candidate = _span(start, end)
                for s, e in curr_group:
                    member = _span(s, e)
                    if not member & candidate:
                        continue
                    curr_group.append((start, end))
                    remaining_intervals.remove((start, end))
                    if candidate.issubset(member):
                        subset_interval = (start, end)
                    # Stop at the first overlapping member. Continuing would
                    # revisit the interval just appended and try to remove it
                    # from remaining_intervals a second time (ValueError).
                    break

            if subset_interval:
                centroid = [math.ceil((subset_interval[0] + subset_interval[1]) / 2)]
            elif len(curr_group) > 2:
                # curr_group[1] joined while curr_group[0] was the only member,
                # so these two always overlap and max() is safe.
                first_overlap = max(_span(*curr_group[0]) & _span(*curr_group[1]))
                last_overlap_set = _span(*curr_group[-1]) & _span(*curr_group[-2])
                # Fallback if the last two intervals do not overlap each other.
                last_overlap = (
                    min(last_overlap_set) if last_overlap_set else first_overlap
                )
                step = 1 if first_overlap <= last_overlap else -1
                centroid = list(range(first_overlap, last_overlap + step, step))
            else:
                centroid = [
                    round(
                        sum(math.ceil((s + e) / 2) for s, e in curr_group)
                        / len(curr_group)
                    )
                ]
            centroids.extend(centroid)
        return centroids

    def get_filter(self, top_k: int, file_md5: str, doc: List[Document]):
        """
        Create retriever filters that select chunks covering the top windows.

        Args:
            top_k (int): Number of centroid windows (from find_overlaps) to keep.
            file_md5 (str): MD5 hash of the file to filter on.
            doc (List[Document]): Documents whose intervals define the windows.

        Returns:
            tuple: A tuple containing the DocIndexer filter and the Chroma
            ``where`` filter.

        Raises:
            ValueError: If no overlapping intervals are found.
        """
        overlaps = self.find_overlaps(doc)
        if not overlaps:
            raise ValueError("No overlapping intervals found.")
        overlaps_k = overlaps[:top_k]
        logger.info("windows_at_2nd_retrieval: %s", overlaps_k)

        # A chunk matches a window index when its large-chunk interval
        # contains that index and it comes from the requested file.
        search_dict_docindexer = {
            "OR": [
                {
                    "large_chunks_idx_lower_bound": (
                        IndexerOperator.LTE,
                        chunk_idx,
                    ),
                    "large_chunks_idx_upper_bound": (
                        IndexerOperator.GTE,
                        chunk_idx,
                    ),
                    "source_md5": (IndexerOperator.EQ, file_md5),
                }
                for chunk_idx in overlaps_k
            ]
        }

        def chroma_clause(chunk_idx):
            return {
                "$and": [
                    {"large_chunks_idx_lower_bound": {"$lte": chunk_idx}},
                    {"large_chunks_idx_upper_bound": {"$gte": chunk_idx}},
                    {"source_md5": {"$eq": file_md5}},
                ]
            }

        # Chroma does not accept an "$or" with a single condition, so a single
        # window is emitted as a bare "$and" clause.
        if len(overlaps_k) == 1:
            search_dict_chroma = chroma_clause(overlaps_k[0])
        else:
            search_dict_chroma = {"$or": [chroma_clause(idx) for idx in overlaps_k]}
        return search_dict_docindexer, search_dict_chroma

    def get_relevant_doc_ids(self, docs: List[Document], query: str):
        """
        Ask the LLM which of ``docs`` are relevant to ``query``.

        Args:
            docs (List[Document]): Candidate documents. The prompt numbers them
                ``Context 0``, ``Context 1``, and so on.
            query (str): The query string.

        Returns:
            list: The indices parsed from the first ``[i, j, ...]`` list in the
            LLM output, or an empty list when none is found.
        """
        snippets = "\n\n\n".join(
            [
                f"Context {idx}:\n{{{doc.page_content}}}. {{source: {doc.metadata['source']}}}"
                for idx, doc in enumerate(docs)
            ]
        )

        id_chain = LLMChain(
            llm=self.llm,
            prompt=prompt_templates.get_docs_selection_template(configs.model_name),
            output_key="IDs",
        )
        ids = id_chain.run({"query": query, "snippets": snippets})
        logger.info("relevant doc ids: %s", ids)
        match = _DOC_IDS_PATTERN.search(ids)
        if match:
            # The regex only admits a bracketed list of digits, so literal_eval
            # can only produce a list of ints here.
            return ast.literal_eval(match.group(0))
        return []

    # ----- helpers shared by the sync and async pipelines -----

    def _first_stage_retriever(self) -> MyEnsembleRetriever:
        """Build the retriever over the whole small-chunk corpus."""
        return self.get_retriever(
            docs_chunks=self.docs_index_small.documents,
            emb_chunks=self.embedding_chunks_small,
            emb_filter=None,
            k=self.first_retrieval_k,
            weights=self.retriever_weights,
        )

    def _second_stage_retriever(self, source_md5: str) -> MyEnsembleRetriever:
        """Build the retriever over the small chunks of one source file."""
        second_docs_chunks = self.docs_index_small.retrieve_metadata(
            {
                "source_md5": (IndexerOperator.EQ, source_md5),
            }
        )
        return self.get_retriever(
            docs_chunks=second_docs_chunks,
            emb_chunks=self.embedding_chunks_small,
            emb_filter={"source_md5": source_md5},
            k=self.second_retrieval_k,
            weights=self.retriever_weights,
        )

    def _third_stage_retriever(
        self, source_md5: str, docs: List[Document], k: int
    ) -> MyEnsembleRetriever:
        """Build the retriever over medium chunks within the windows around ``docs``."""
        docindexer_filter, chroma_filter = self.get_filter(
            self.num_windows, source_md5, docs
        )
        third_docs_chunks = self.docs_index_medium.retrieve_metadata(docindexer_filter)
        return self.get_retriever(
            docs_chunks=third_docs_chunks,
            emb_chunks=self.embedding_chunks_medium,
            emb_filter=chroma_filter,
            k=k,
            weights=self.retriever_weights,
        )

    @staticmethod
    def _group_by_source(
        first: List[Document], ids_clean: List[int]
    ) -> Dict[str, List[Document]]:
        """
        Map each selected file's source_md5 to its first selected document.

        Out-of-range ids are ignored. When no id is usable, the top document of
        ``first`` is used (``first`` must be non-empty in that case).
        """
        source_md5_dict: Dict[str, List[Document]] = {}
        for ids_c in ids_clean:
            if not 0 <= ids_c < len(first):
                continue
            source_md5 = first[ids_c].metadata["source_md5"]
            # Key on the md5, not on the integer id, so the first-ranked
            # selection per file is kept instead of being overwritten.
            if source_md5 not in source_md5_dict:
                source_md5_dict[source_md5] = [first[ids_c]]
        if not source_md5_dict and first:
            source_md5_dict[first[0].metadata["source_md5"]] = [first[0]]
        return source_md5_dict

    @staticmethod
    def _third_retrieval_k(num_docs: int, num_query: int) -> int:
        """Split the LLM context budget (in medium chunks) across files and queries."""
        chunks_in_context = configs.max_llm_context / (
            configs.base_chunk_size * configs.chunk_scale
        )
        return max(1, int(chunks_in_context // (num_docs * num_query)))

    @staticmethod
    def _collect_third(
        qa_chunks: Dict[str, List[Document]], third: List[Document]
    ) -> None:
        """Log the third-stage documents and add them to ``qa_chunks`` by file name."""
        if not third:
            logger.warning("----3rd retrieval---- returned no documents.")
            return
        for doc in third:
            logger.info("----3rd retrieval----page_content: %s", [doc.page_content])
            # These documents come from MyEnsembleRetriever, whose
            # weighted_reciprocal_rank returns fresh copies, so clearing the
            # duplicated text here does not touch any indexed document.
            mtdata = doc.metadata
            mtdata["page_content"] = None
            logger.info("----3rd retrieval----metadata: %s", mtdata)
        file_name = third[0].metadata["source"].split("/")[-1]
        qa_chunks.setdefault(file_name, []).extend(third)

    def get_relevant_documents(
        self,
        query: str,
        num_query: int,
        *,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, List[Document]]:
        """
        Perform multi-stage retrieval to get relevant documents.

        Args:
            query (str): The query string.
            num_query (int): Number of queries sharing the LLM context budget.
            run_manager (Optional[CallbackManagerForChainRun], optional):
                Callback manager for the chain run.

        Returns:
            Dict[str, List[Document]]: Third-stage documents keyed by the file
            name (last path component of ``metadata["source"]``). The dict is
            empty when nothing is selected.
        """
        qa_chunks: Dict[str, List[Document]] = {}

        # ! First retrieval
        first = self._first_stage_retriever().get_relevant_documents(
            query, callbacks=_child_callbacks(run_manager)
        )
        for doc in first:
            logger.info("----1st retrieval----: %s", doc)
        if not first:
            logger.info("No documents found at 1st retrieval.")
            return qa_chunks

        ids_clean = self.get_relevant_doc_ids(first, query)
        logger.info("relevant cleaned doc ids: %s", ids_clean)
        if not (ids_clean and isinstance(ids_clean, list)):
            return qa_chunks

        source_md5_dict = self._group_by_source(first, ids_clean)
        third_num_k = self._third_retrieval_k(len(source_md5_dict), num_query)

        for source_md5, docs in source_md5_dict.items():
            logger.info(
                "selected_docs_at_1st_retrieval: %s", docs[0].metadata["source"]
            )

            # ! Second retrieval
            second = self._second_stage_retriever(source_md5).get_relevant_documents(
                query, callbacks=_child_callbacks(run_manager)
            )
            for doc in second:
                logger.info("----2nd retrieval----: %s", doc)
            docs.extend(second)

            # ! Third retrieval
            third_retriever = self._third_stage_retriever(source_md5, docs, third_num_k)
            third_temp = third_retriever.get_relevant_documents(
                query, callbacks=_child_callbacks(run_manager)
            )
            # The ensemble can return up to 2 * k documents; keep the top k.
            self._collect_third(qa_chunks, third_temp[:third_num_k])

        return qa_chunks

    async def aget_relevant_documents(
        self,
        query: str,
        num_query: int,
        *,
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> Dict[str, List[Document]]:
        """
        Asynchronous version of the get_relevant_documents method.

        Args:
            query (str): The query string.
            num_query (int): Number of queries sharing the LLM context budget.
            run_manager (AsyncCallbackManagerForChainRun): Callback manager for
                the asynchronous chain run.

        Returns:
            Dict[str, List[Document]]: Third-stage documents keyed by file name.
            The dict is empty when nothing is selected.
        """
        qa_chunks: Dict[str, List[Document]] = {}

        # ! First retrieval
        first = await self._first_stage_retriever().aget_relevant_documents(
            query, callbacks=_child_callbacks(run_manager)
        )
        for doc in first:
            logger.info("----1st retrieval----: %s", doc)
        if not first:
            logger.info("No documents found at 1st retrieval.")
            return qa_chunks

        # NOTE(review): this LLM call is synchronous and blocks the event loop.
        ids_clean = self.get_relevant_doc_ids(first, query)
        logger.info("relevant doc ids: %s", ids_clean)
        if not (ids_clean and isinstance(ids_clean, list)):
            return qa_chunks

        source_md5_dict = self._group_by_source(first, ids_clean)
        third_num_k = self._third_retrieval_k(len(source_md5_dict), num_query)

        for source_md5, docs in source_md5_dict.items():
            logger.info(
                "selected_docs_at_1st_retrieval: %s", docs[0].metadata["source"]
            )

            # ! Second retrieval
            second_retriever = self._second_stage_retriever(source_md5)
            second = await second_retriever.aget_relevant_documents(
                query, callbacks=_child_callbacks(run_manager)
            )
            for doc in second:
                logger.info("----2nd retrieval----: %s", doc)
            docs.extend(second)

            # ! Third retrieval
            third_retriever = self._third_stage_retriever(source_md5, docs, third_num_k)
            third_temp = await third_retriever.aget_relevant_documents(
                query, callbacks=_child_callbacks(run_manager)
            )
            # The ensemble can return up to 2 * k documents; keep the top k.
            self._collect_third(qa_chunks, third_temp[:third_num_k])

        return qa_chunks