import heapq

from datasets import load_dataset
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util

from config import CONFIG


class Retriever:
    """Chunk retriever offering BM25 (lexical) and embedding (semantic) search.

    Typical use: call ``load_and_prepare_dataset()`` first, then
    ``prepare_bm25()`` and/or ``compute_embeddings()`` before calling the
    matching ``retrieve_documents_*`` method.
    """

    def __init__(self):
        self.corpus = None            # list[str] of text chunks; set by load_and_prepare_dataset
        self.bm25 = None              # BM25Okapi index over self.corpus; set by prepare_bm25
        self.model = None             # SentenceTransformer; set by compute_embeddings
        self.chunk_embeddings = None  # tensor of chunk embeddings; set by compute_embeddings

    @staticmethod
    def _tokenize(text):
        """Split *text* into lowercase, whitespace-delimited tokens for BM25.

        Used for both corpus chunks and queries so lexical matching is
        consistent and case-insensitive. ``str.split()`` with no argument
        also collapses runs of spaces, tabs and newlines (``split(" ")``
        would produce empty tokens or fuse words across newlines).
        """
        return text.lower().split()

    def load_and_prepare_dataset(self):
        """Load CONFIG['DATASET'] and build ``self.corpus`` from its train split.

        Keeps at most CONFIG['MAX_NUM_OF_RECORDS'] rows, chunks each row's
        'abstract' field with ``chunk_text``, and flattens all chunks into
        ``self.corpus``.
        """
        dataset = load_dataset(CONFIG['DATASET'])
        train = dataset['train']
        # Clamp to the split size: selecting indices past the end would fail
        # when the split is smaller than the configured maximum.
        num_records = min(CONFIG['MAX_NUM_OF_RECORDS'], len(train))
        train = train.select(range(num_records))
        train = train.map(lambda x: {'chunks': self.chunk_text(x['abstract'])})
        self.corpus = [chunk for chunks in train['chunks'] for chunk in chunks]

    def prepare_bm25(self):
        """Build the BM25Okapi index over ``self.corpus``."""
        self.bm25 = BM25Okapi([self._tokenize(doc) for doc in self.corpus])

    def compute_embeddings(self, model_name='all-MiniLM-L6-v2'):
        """Load a SentenceTransformer and embed every chunk in ``self.corpus``.

        Args:
            model_name: SentenceTransformer model id; defaults to the
                previously hard-coded 'all-MiniLM-L6-v2'.
        """
        self.model = SentenceTransformer(model_name)
        # _first_module() is a private SentenceTransformer API, used here only
        # to reach the underlying tokenizer.
        tokenizer = self.model._first_module().tokenizer
        # Fall back to the EOS token for padding when no pad token is defined
        # (only meaningful if the tokenizer actually has an EOS token).
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        self.chunk_embeddings = self.model.encode(self.corpus, convert_to_tensor=True)

    def chunk_text(self, text, chunk_size=None):
        """Split *text* into consecutive chunks of at most *chunk_size* words.

        Args:
            text: input string; ``None`` or empty yields no chunks.
            chunk_size: words per chunk; defaults to CONFIG['CHUNK_SIZE'],
                read at call time rather than at class-definition time.

        Returns:
            list[str]: chunks whose words are re-joined by single spaces.

        Raises:
            ValueError: if chunk_size is not positive.
        """
        if chunk_size is None:
            chunk_size = CONFIG['CHUNK_SIZE']
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not text:
            return []
        words = text.split()
        return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]

    def retrieve_documents_bm25(self, query):
        """Return up to CONFIG['TOP_DOCS'] chunks ranked by BM25 score for *query*."""
        scores = self.bm25.get_scores(self._tokenize(query))
        # heapq.nlargest is documented as equivalent to
        # sorted(..., reverse=True)[:n] (ties keep corpus order) without
        # sorting every score.
        top_docs = heapq.nlargest(CONFIG['TOP_DOCS'], range(len(scores)), key=scores.__getitem__)
        return [self.corpus[i] for i in top_docs]

    def retrieve_documents_semantic(self, query):
        """Return up to CONFIG['TOP_DOCS'] chunks ranked by cosine similarity to *query*."""
        query_embedding = self.model.encode(query, convert_to_tensor=True)
        scores = util.pytorch_cos_sim(query_embedding, self.chunk_embeddings)[0]
        # topk raises if k exceeds the number of scores; clamp so a small
        # corpus behaves like the BM25 path's slice.
        k = min(CONFIG['TOP_DOCS'], len(scores))
        top_chunks = scores.topk(k).indices.tolist()
        return [self.corpus[i] for i in top_chunks]