from rank_bm25 import BM25Okapi
from tokenizing import tokenize_text, tokenize_doc, tokenize_doc_to_str
import numpy as np
from tqdm import tqdm
import os
import pickle
import torch
from sentence_transformers import SentenceTransformer

# Define the base path
base_path = "./Data"  # "/mnt/d/Semester7/NLP/RAG/Data"


class Retriever:
    def __init__(self, docs: list[dict]) -> None:
        self.docs = docs
        self.tokenized_docs_path = os.path.join(base_path, "tokenized_docs.pkl")
        self.bm25_path = os.path.join(base_path, "bm25.pkl")
        self.sbert_embeddings_path = os.path.join(base_path, "embeddings_parts")

        # Initialize SBERT
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.sbert = SentenceTransformer('sentence-transformers/all-distilroberta-v1', device=device)

        # Load the cached artifacts if they all exist, otherwise build them from scratch
        if os.path.exists(self.tokenized_docs_path) and os.path.exists(self.bm25_path) and os.path.exists(self.sbert_embeddings_path):
            print("Loading cache")
            self.load_cache()
            print("Cache loaded")
        else:
            self.tokenize_and_initialize()

    def tokenize_and_initialize(self):
        # Tokenize the documents with a progress bar
        self.tokenized_docs = [tokenize_doc(doc) for doc in tqdm(self.docs, desc="Tokenizing documents")]
        self.str_docs = [tokenize_doc_to_str(doc) for doc in self.docs]

        # Ensure that the tokenized_docs list is not empty
        if not self.tokenized_docs:
            raise ValueError("The list of tokenized documents is empty. Please check the input documents and the tokenization process.")

        # Initialize BM25 with the tokenized texts
        self.bm25 = BM25Okapi(self.tokenized_docs)

        # Get embeddings for all the tokenized documents; encode as a tensor so the
        # chunked saving logic in save_cache can use torch tensor methods
        self.sbert_embeddings = self.sbert.encode(self.str_docs, convert_to_tensor=True, show_progress_bar=True)
        self.sbert_embeddings = self.sbert_embeddings.cpu()

        # Save the tokenized documents, the BM25 model, and the embeddings
        self.save_cache()

    def save_cache(self):
        with open(self.tokenized_docs_path, 'wb') as f:
            pickle.dump(self.tokenized_docs, f)
        with open(self.bm25_path, 'wb') as f:
            pickle.dump(self.bm25, f)

        # Make sure the embeddings directory exists before writing the parts
        os.makedirs(self.sbert_embeddings_path, exist_ok=True)
        split_size = 1000  # Number of rows per split
        embeddings_size = self.sbert_embeddings.size(0)

        # Split the embeddings tensor into chunks and save each chunk as its own file
        for i in range(0, embeddings_size, split_size):
            end_idx = min(i + split_size, embeddings_size)
            part = self.sbert_embeddings[i:end_idx].clone()
            torch.save(part, os.path.join(self.sbert_embeddings_path, f"embeddings_part_{i // split_size}.pt"))

    def load_cache(self):
        with open(self.tokenized_docs_path, 'rb') as f:
            self.tokenized_docs = pickle.load(f)
        with open(self.bm25_path, 'rb') as f:
            self.bm25 = pickle.load(f)

        print("Loading SBERT embeddings")
        # Keep only the embedding part files, then sort them numerically by the
        # index in the filename so concatenation restores the original row order
        files = [f for f in os.listdir(self.sbert_embeddings_path)
                 if f.startswith("embeddings_part_") and f.endswith(".pt")]
        sorted_files = sorted(files, key=lambda x: int(x.split('_')[2].split('.')[0]))

        loaded_parts = []
        for counter, file in enumerate(sorted_files):
            part_path = os.path.join(self.sbert_embeddings_path, file)
            loaded_parts.append(torch.load(part_path))
            if counter % 50 == 0:
                print("Loaded", file)
        self.sbert_embeddings = torch.cat(loaded_parts, dim=0)
        print("SBERT embeddings loaded")

    def get_docs(self, user_message: str, n: int = 30, bm25_only: bool = False,
                 semantic_only: bool = False, scores_combination: bool = True,
                 bm_koef: float = 0.75) -> list[dict]:
        # Return the top n documents for the user message. The flags are checked
        # in order (bm25_only, semantic_only, scores_combination), so if several
        # are set to True, the first one wins.
        # Strip the instruction phrases so only the joke topic is scored
        user_message = user_message.replace("tell me a joke about", "").replace("tell me a joke and its title about", "")

        if bm25_only:
            semantic_only = False
            scores_combination = False
            print("BM25 only")
            scores = torch.tensor(self._get_bm25_scores(user_message))
        elif semantic_only:
            scores_combination = False
            print("Semantic only")
            scores = self.get_semantic_scores(user_message)
        elif scores_combination:
            print("Combination")
            bm_scores = self._get_bm25_scores(user_message)
            semantic_scores = self.get_semantic_scores(user_message)
            scores = bm_koef * torch.tensor(bm_scores) + (1 - bm_koef) * semantic_scores
        else:
            raise ValueError("At least one of bm25_only, semantic_only, or scores_combination must be True.")

        # argsort is ascending, so take the last n positive-score entries and
        # reverse them so the most relevant documents come first
        sorted_doc_indices = torch.argsort(scores)
        result_docs = [self.docs[i] for i in sorted_doc_indices[-n:].tolist() if scores[i] > 0]
        return result_docs[::-1]

    def _get_bm25_scores(self, user_message: str) -> np.ndarray:
        tokenized_user_message = tokenize_text(user_message)
        return self.bm25.get_scores(tokenized_user_message)

    def get_semantic_scores(self, user_message: str) -> torch.Tensor:
        user_message_embedding = self.sbert.encode(user_message)
        # similarity() returns a (1, num_docs) matrix; return the single row
        scores = self.sbert.similarity(user_message_embedding, self.sbert_embeddings)
        return scores[0]
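

# Example usage: a minimal sketch, not part of the original module. The document
# structure below (dicts with "title" and "body" keys) is an assumption made for
# illustration; the real corpus layout is whatever tokenize_doc and
# tokenize_doc_to_str in tokenizing.py expect.
if __name__ == "__main__":
    sample_docs = [
        {"title": "Dark mode", "body": "Why do programmers prefer dark mode? Because light attracts bugs."},
        {"title": "Light bulbs", "body": "How many programmers does it take to change a light bulb? None, that's a hardware problem."},
    ]
    retriever = Retriever(sample_docs)

    # Default mode combines BM25 and semantic scores, weighted 0.75 toward BM25
    top_docs = retriever.get_docs("tell me a joke about programmers", n=5)
    for doc in top_docs:
        print(doc["title"])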