multi-stage-retrieval-QA / retrieval.py
from sentence_transformers import SentenceTransformer, util


def load_embedding_model(model_name):
    """Load a SentenceTransformer embedding model by name."""
    return SentenceTransformer(model_name)


def retrieve_top_k(model, query, corpus, k=10):
    """Return the top-k (passage_text, score) pairs most similar to the query.

    `corpus` is a dict mapping doc_id -> {"text": ...}.
    """
    doc_ids = list(corpus.keys())
    # Embed the query and every corpus passage (embedding order follows doc_ids).
    query_embedding = model.encode(query, convert_to_tensor=True)
    corpus_embeddings = model.encode([corpus[doc_id]["text"] for doc_id in doc_ids], convert_to_tensor=True)
    # Cosine-similarity search, then map each hit index back to its passage text and score.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=k)[0]
    return [(corpus[doc_ids[hit["corpus_id"]]]["text"], hit["score"]) for hit in hits]
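

# Minimal usage sketch. The model name and the toy corpus below are illustrative
# assumptions, not part of this repo's QA pipeline; any sentence-transformers
# model and any corpus dict of the form {doc_id: {"text": ...}} will work.
if __name__ == "__main__":
    model = load_embedding_model("all-MiniLM-L6-v2")
    corpus = {
        "d1": {"text": "The Eiffel Tower is in Paris."},
        "d2": {"text": "Mount Everest is the highest mountain on Earth."},
    }
    for passage, score in retrieve_top_k(model, "Where is the Eiffel Tower?", corpus, k=2):
        print(f"{score:.3f}  {passage}")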