File size: 4,470 Bytes
11f5db5
e28d050
 
 
 
 
f782117
e28d050
d47b442
e28d050
 
 
 
d47b442
 
 
043f795
d47b442
 
 
 
 
043f795
d47b442
 
 
 
 
 
f4add15
d47b442
7847e6f
e28d050
d47b442
 
 
 
 
11f5db5
e28d050
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8dc7a1c
 
 
 
 
 
 
 
 
e28d050
 
11f5db5
8dc7a1c
11f5db5
e28d050
d47b442
e28d050
 
 
 
 
 
 
 
 
 
 
 
8dc7a1c
e28d050
 
 
 
 
 
8dc7a1c
 
 
 
 
 
 
 
e28d050
11f5db5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import streamlit as st
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
from sklearn.metrics import ndcg_score

# Helper function to download and unzip the BEIR NQ dataset (skipped if already present)
def download_and_extract_dataset():
    """Ensure the BEIR Natural Questions (NQ) dataset is available locally.

    Downloads ``nq.zip`` from the BEIR mirror and unpacks it under
    ``./datasets``. Each step is skipped when its output already exists.
    Both steps first write to a temporary location and move the result into
    place only once it is complete. An interrupted run therefore cannot leave
    a truncated zip or a half-extracted directory that later runs would
    mistake for a finished one.

    Returns:
        str: Path to the extracted dataset directory (``./datasets/nq``).
    """
    import os
    import shutil
    import tempfile
    import urllib.request
    import zipfile

    dataset_url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nq.zip"
    dataset_zip_path = "nq.zip"
    datasets_dir = "./datasets"
    data_path = "./datasets/nq"

    # Already extracted: no need to download (or even keep) the zip.
    if os.path.exists(data_path):
        return data_path

    # Download to a ".part" file and rename only on success. A crash
    # mid-download would otherwise leave a truncated nq.zip that every later
    # run would skip re-downloading and then fail to unzip.
    if not os.path.exists(dataset_zip_path):
        st.write("Downloading the dataset... This may take a few minutes.")
        partial_path = dataset_zip_path + ".part"
        urllib.request.urlretrieve(dataset_url, partial_path)
        os.replace(partial_path, dataset_zip_path)
        st.write("Download complete!")

    st.write("Unzipping the dataset...")
    os.makedirs(datasets_dir, exist_ok=True)
    # Extract into a scratch directory next to the target. Then move the
    # finished folder into place, so a partial extraction is never visible
    # at data_path.
    staging_dir = tempfile.mkdtemp(dir=datasets_dir)
    try:
        with zipfile.ZipFile(dataset_zip_path, 'r') as zip_ref:
            zip_ref.extractall(staging_dir)
        # Assumes the archive's top-level folder is "nq". The previous
        # implementation relied on the same layout when extracting into
        # ./datasets and then using ./datasets/nq.
        os.replace(os.path.join(staging_dir, "nq"), data_path)
    finally:
        shutil.rmtree(staging_dir, ignore_errors=True)
    st.write("Dataset unzipped!")

    return data_path

# Function to load corpus, queries, and qrels
def _load_nq_split(split="test"):
    """Download (if needed) and parse one split of the BEIR NQ dataset.

    Returns:
        tuple: ``(corpus, queries, qrels)`` as produced by BEIR's
        ``GenericDataLoader.load``.
    """
    from beir.datasets.data_loader import GenericDataLoader
    data_path = download_and_extract_dataset()
    return GenericDataLoader(data_path).load(split=split)


# Streamlit re-runs the whole script on every widget interaction. Without a
# cache, the full corpus would be re-parsed from disk for every query typed.
# st.cache_resource keeps one in-memory copy across reruns. On Streamlit
# versions without it, fall back to loading each time, as before.
if hasattr(st, "cache_resource"):
    _load_nq_split = st.cache_resource(_load_nq_split)


def load_dataset():
    """Return ``(corpus, queries, qrels)`` for the BEIR NQ test split.

    When Streamlit caching is available, the returned objects are shared
    across reruns and sessions, so callers should treat them as read-only.
    """
    corpus, queries, qrels = _load_nq_split("test")
    return corpus, queries, qrels

# Stage 1: Candidate retrieval using Sentence Transformer
def _load_bi_encoder():
    """Load the sentence-transformer used for first-stage retrieval."""
    return SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')


def _corpus_embedding_store():
    """Return the dict holding the most recently encoded corpus embeddings."""
    return {}


# Streamlit re-executes this script for every query, so plain module state is
# rebuilt each time. When st.cache_resource exists, use it to keep the model
# and the encoded corpus alive between queries instead of recomputing them.
if hasattr(st, "cache_resource"):
    _load_bi_encoder = st.cache_resource(_load_bi_encoder)
    _corpus_embedding_store = st.cache_resource(_corpus_embedding_store)


def candidate_retrieval(query, corpus, top_k=10):
    """Return ids of the ``top_k`` corpus documents most similar to ``query``.

    Query and documents are embedded with ``all-MiniLM-L6-v2``. Candidates are
    ranked with ``util.semantic_search``, best first.

    Encoding the corpus dominates the cost, so the embeddings are reused for
    later calls over the same ordered document ids. This assumes a document's
    text does not change while its id stays the same.

    Args:
        query: The question text.
        corpus: Mapping of doc id -> dict with at least a ``'text'`` entry.
        top_k: Maximum number of document ids to return.

    Returns:
        list: Document ids ordered from most to least similar. Empty when the
        corpus is empty.
    """
    corpus_ids = list(corpus.keys())
    if not corpus_ids:
        return []

    model = _load_bi_encoder()
    store = _corpus_embedding_store()
    cache_key = tuple(corpus_ids)
    corpus_embeddings = store.get(cache_key)
    if corpus_embeddings is None:
        corpus_embeddings = model.encode(
            [corpus[doc_id]['text'] for doc_id in corpus_ids],
            convert_to_tensor=True,
        )
        # Keep only the latest corpus so memory does not grow without bound.
        store.clear()
        store[cache_key] = corpus_embeddings

    query_embedding = model.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
    return [corpus_ids[hit['corpus_id']] for hit in hits]

# Stage 2: Reranking using cross-encoder
def _load_cross_encoder():
    """Load the tokenizer and model of the MS MARCO MiniLM cross-encoder."""
    name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSequenceClassification.from_pretrained(name)
    model.eval()  # inference only: make sure dropout is disabled
    return tokenizer, model


# Streamlit reruns the script for every query. Cache the weights when
# st.cache_resource is available instead of reloading them each time.
if hasattr(st, "cache_resource"):
    _load_cross_encoder = st.cache_resource(_load_cross_encoder)


def rerank(retrieved_docs, query, corpus, top_k=5):
    """Re-score candidate documents with a cross-encoder and keep the best ``top_k``.

    Each ``(query, document text)`` pair is scored by
    ``cross-encoder/ms-marco-MiniLM-L-12-v2``. Higher logits rank first.

    Args:
        retrieved_docs: Candidate document ids (e.g. from ``candidate_retrieval``).
        query: The question text.
        corpus: Mapping of doc id -> dict with at least a ``'text'`` entry.
        top_k: Number of reranked ids to return.

    Returns:
        tuple: ``(reranked_docs, scores)``.
            ``reranked_docs`` holds up to ``top_k`` ids, best first.
            ``scores`` holds one float per entry of ``retrieved_docs``, in the
            original retrieved order (not the reranked order).
            Both are empty when ``retrieved_docs`` is empty.
    """
    if len(retrieved_docs) == 0:
        return [], []

    tokenizer, model = _load_cross_encoder()
    texts = [corpus[doc_id]['text'] for doc_id in retrieved_docs]
    # Score all pairs in one padded batch. The attention mask produced by the
    # tokenizer keeps the padding from influencing each pair's logit.
    inputs = tokenizer(
        [query] * len(texts),
        texts,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    # No gradients are needed; skipping autograd saves memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Assumes one relevance logit per pair, as the previous `.item()` call did.
    scores = logits.squeeze(-1).tolist()

    reranked_indices = np.argsort(scores)[::-1][:top_k]
    reranked_docs = [retrieved_docs[idx] for idx in reranked_indices]
    return reranked_docs, scores

# Function to evaluate using NDCG@10
def _judgements_for_query(qrels, query_id):
    """Return ``{doc_id: relevance}`` for ``query_id`` from ``qrels``.

    Two layouts are accepted:
        * nested ``{query_id: {doc_id: relevance}}``, as returned by BEIR's
          ``GenericDataLoader.load``;
        * flat ``{(query_id, doc_id): relevance}``.
    """
    nested = qrels.get(query_id)
    if isinstance(nested, dict):
        return nested
    return {
        key[1]: rel
        for key, rel in qrels.items()
        if isinstance(key, tuple) and len(key) == 2 and key[0] == query_id
    }


def evaluate_ndcg(reranked_docs, qrels, query_id, k=10):
    """Compute NDCG@k of a ranked list of document ids for one query.

    Uses the standard definition with linear gains and a ``log2(rank + 1)``
    discount. The ideal DCG is built from every judged document for the
    query, including relevant documents the ranking failed to retrieve.
    Negative relevance labels are treated as 0.

    Args:
        reranked_docs: Document ids in ranked order, best first.
        qrels: Relevance judgements, nested or flat (see
            ``_judgements_for_query``).
        query_id: Id of the query the ranking was produced for.
        k: Rank cutoff.

    Returns:
        float: NDCG@k in ``[0, 1]``. Returns 0.0 when the query has no
        positive judgements or ``k <= 0``.
    """
    judgements = _judgements_for_query(qrels, query_id)
    if k <= 0 or not judgements:
        return 0.0

    def _dcg(gains):
        gains = np.clip(np.asarray(gains, dtype=float), 0.0, None)
        discounts = np.log2(np.arange(2, gains.size + 2))
        return float(np.sum(gains / discounts))

    ranked_gains = [judgements.get(doc_id, 0) for doc_id in list(reranked_docs)[:k]]
    ideal_gains = sorted(judgements.values(), reverse=True)[:k]

    idcg = _dcg(ideal_gains)
    if idcg <= 0.0:
        return 0.0
    return _dcg(ranked_gains) / idcg

# Streamlit main function
def _match_query_id(queries, query_text):
    """Return the id of the dataset query whose text equals ``query_text``.

    Matching ignores case and surrounding whitespace. Returns None when the
    question is not one of the dataset's queries.
    """
    wanted = query_text.strip().casefold()
    for qid, text in queries.items():
        if str(text).strip().casefold() == wanted:
            return qid
    return None


def main():
    """Streamlit app: retrieve, rerank, display and (when possible) score a query.

    NDCG@10 is reported only when the typed question matches a dataset query
    that has relevance judgements. Any other question has no ground truth.
    """
    st.title("Multi-Stage Retrieval Pipeline with Evaluation")

    st.write("Loading the dataset...")
    corpus, queries, qrels = load_dataset()
    st.write(f"Corpus Size: {len(corpus)}")

    # User input for asking a question
    user_query = st.text_input("Ask a question:")
    if not user_query:
        return

    st.write(f"Your query: {user_query}")

    st.write("Running Candidate Retrieval...")
    retrieved_docs = candidate_retrieval(user_query, corpus, top_k=10)

    st.write("Running Reranking...")
    # Rerank all 10 candidates, not just the 5 shown. NDCG@10 below then
    # scores a full 10-document ranking instead of silently treating
    # ranks 6-10 as empty.
    reranked_docs, rerank_scores = rerank(retrieved_docs, user_query, corpus, top_k=10)

    st.write("Top Reranked Documents:")
    for doc_id in reranked_docs[:5]:
        text = corpus[doc_id]['text']
        st.write(f"Document ID: {doc_id}")
        # Show the first 500 characters; mark truncation only when it happened.
        suffix = "..." if len(text) > 500 else ""
        st.write(f"Document Text: {text[:500]}{suffix}")

    # Evaluate only if the typed question is a dataset query with judgements.
    # The qrels from load_dataset are keyed by query id.
    query_id = _match_query_id(queries, user_query)
    if query_id is not None and query_id in qrels:
        ndcg_score_value = evaluate_ndcg(reranked_docs, qrels, query_id, k=10)
        st.write(f"NDCG@10 Score: {ndcg_score_value}")
    else:
        st.write("No ground truth available for this query.")

    st.write("Query executed successfully!")

# Script entry point: build the app when this file is executed as the main
# module (presumably launched via `streamlit run <file>` — verify).
if __name__ == "__main__":
    main()