Spaces:
Sleeping
Sleeping
File size: 4,470 Bytes
11f5db5 e28d050 f782117 e28d050 d47b442 e28d050 d47b442 043f795 d47b442 043f795 d47b442 f4add15 d47b442 7847e6f e28d050 d47b442 11f5db5 e28d050 8dc7a1c e28d050 11f5db5 8dc7a1c 11f5db5 e28d050 d47b442 e28d050 8dc7a1c e28d050 8dc7a1c e28d050 11f5db5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import streamlit as st
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
from sklearn.metrics import ndcg_score
# Helper function to load the dataset
def download_and_extract_dataset():
    """Fetch the BEIR Natural Questions archive (if absent) and extract it.

    Returns:
        str: Local path of the extracted dataset directory.
    """
    import os
    import urllib.request
    import zipfile

    dataset_url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nq.zip"
    dataset_zip_path = "nq.zip"
    data_path = "./datasets/nq"

    # Only download when the archive is not already on disk.
    if not os.path.exists(dataset_zip_path):
        st.write("Downloading the dataset... This may take a few minutes.")
        urllib.request.urlretrieve(dataset_url, dataset_zip_path)
        st.write("Download complete!")

    # Only extract when the target directory is missing.
    if not os.path.exists(data_path):
        st.write("Unzipping the dataset...")
        with zipfile.ZipFile(dataset_zip_path, 'r') as zip_ref:
            zip_ref.extractall("./datasets")
        st.write("Dataset unzipped!")

    return data_path
# Function to load corpus, queries, and qrels
def load_dataset():
    """Load the NQ test split (corpus, queries, qrels) via BEIR's loader."""
    from beir.datasets.data_loader import GenericDataLoader

    dataset_dir = download_and_extract_dataset()
    corpus, queries, qrels = GenericDataLoader(dataset_dir).load(split="test")
    return corpus, queries, qrels
# Stage 1: Candidate retrieval using Sentence Transformer
def candidate_retrieval(query, corpus, top_k=10):
    """Return ids of the top_k corpus documents most similar to the query.

    Encodes the whole corpus and the query with a bi-encoder, then runs
    cosine-similarity semantic search over the embeddings.
    NOTE(review): the model is loaded and the corpus re-encoded on every
    call — acceptable for a demo, expensive for repeated queries.
    """
    encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    doc_ids = list(corpus.keys())
    doc_texts = [corpus[doc_id]['text'] for doc_id in doc_ids]
    doc_embeddings = encoder.encode(doc_texts, convert_to_tensor=True)
    query_embedding = encoder.encode(query, convert_to_tensor=True)
    matches = util.semantic_search(query_embedding, doc_embeddings, top_k=top_k)[0]
    return [doc_ids[match['corpus_id']] for match in matches]
# Stage 2: Reranking using cross-encoder
def rerank(retrieved_docs, query, corpus, top_k=5):
    """Rescore candidate documents with a cross-encoder and keep the best.

    Args:
        retrieved_docs: List of candidate doc ids from stage 1.
        query: The user's query string.
        corpus: Mapping doc_id -> {'text': ...}.
        top_k: Number of documents to keep after reranking.

    Returns:
        tuple: (reranked_doc_ids, scores) where `scores` is the raw
        cross-encoder logit per candidate, in `retrieved_docs` order.
    """
    tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
    model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
    # FIX: run in inference mode — eval() disables dropout for deterministic
    # scores, and no_grad() avoids building the autograd graph (large memory
    # and speed win for pure scoring).
    model.eval()
    scores = []
    with torch.no_grad():
        for doc_id in retrieved_docs:
            text = corpus[doc_id]['text']
            inputs = tokenizer(query, text, return_tensors="pt", truncation=True, padding=True)
            outputs = model(**inputs)
            scores.append(outputs.logits.item())
    # Sort candidates by descending score and keep the top_k ids.
    reranked_indices = np.argsort(scores)[::-1][:top_k]
    reranked_docs = [retrieved_docs[idx] for idx in reranked_indices]
    return reranked_docs, scores
# Function to evaluate using NDCG@10
def evaluate_ndcg(reranked_docs, qrels, query_id, k=10):
    """Compute NDCG@k of `reranked_docs` against ground-truth relevance.

    Args:
        reranked_docs: Ranked list of doc ids (best first).
        qrels: BEIR-style nested mapping qrels[query_id][doc_id] -> grade.
        query_id: Id of the query being evaluated.
        k: Rank cutoff.

    Returns:
        float: NDCG@k in [0, 1]; 0.0 when no retrieved doc is relevant.

    BUG FIX: the original looked up the tuple key `(query_id, doc_id)` in
    qrels, but BEIR's GenericDataLoader returns a nested dict keyed by
    query_id then doc_id — the tuple key never matched, so relevance was
    always all-zero and the score degenerate.
    """
    query_rels = qrels.get(query_id, {})
    true_relevance = [query_rels.get(doc_id, 0) for doc_id in reranked_docs]
    # With no relevant docs the ideal DCG is 0; define NDCG as 0.0 rather
    # than letting ndcg_score hit a 0/0.
    if not any(true_relevance):
        return 0.0
    ideal_relevance = sorted(true_relevance, reverse=True)
    # ndcg_score expects 2D arrays of shape (n_queries, n_docs)
    return ndcg_score([ideal_relevance], [true_relevance], k=k)
# Streamlit main function
def main():
    """Streamlit entry point: load data, retrieve, rerank, and evaluate."""
    st.title("Multi-Stage Retrieval Pipeline with Evaluation")
    st.write("Loading the dataset...")
    corpus, queries, qrels = load_dataset()
    st.write(f"Corpus Size: {len(corpus)}")

    # User input for asking a question
    user_query = st.text_input("Ask a question:")
    if user_query:
        st.write(f"Your query: {user_query}")

        st.write("Running Candidate Retrieval...")
        retrieved_docs = candidate_retrieval(user_query, corpus, top_k=10)

        st.write("Running Reranking...")
        reranked_docs, rerank_scores = rerank(retrieved_docs, user_query, corpus, top_k=5)

        st.write("Top Reranked Documents:")
        for doc_id in reranked_docs:
            st.write(f"Document ID: {doc_id}")
            st.write(f"Document Text: {corpus[doc_id]['text'][:500]}...")  # Show the first 500 characters of the document

        # Evaluation if ground-truth relevance labels exist for this query.
        # TODO: map the free-text user query to its real BEIR query id.
        query_id = list(queries.keys())[0]  # Dummy query ID for now
        # BUG FIX: the original tested `query_id in queries`, which is always
        # True because query_id was just taken from queries — the "no ground
        # truth" branch was dead. Ground truth lives in qrels.
        if query_id in qrels:
            ndcg_score_value = evaluate_ndcg(reranked_docs, qrels, query_id, k=10)
            st.write(f"NDCG@10 Score: {ndcg_score_value}")
        else:
            st.write("No ground truth available for this query.")

        st.write("Query executed successfully!")
# Run the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|