Hyma7 committed on
Commit
11f5db5
·
verified ·
1 Parent(s): 6694ac3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -51
app.py CHANGED
@@ -1,66 +1,100 @@
1
- import streamlit as st
2
- import pandas as pd
3
  import numpy as np
 
 
4
  from sentence_transformers import SentenceTransformer
5
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
- import torch
 
 
7
 
8
- # Load dataset
9
- def load_qa_dataset():
10
- # Example using BEIR's NQ dataset
11
- dataset = load_dataset('beir/nq')
12
- return dataset['train'] # Use the training set for simplicity
 
 
 
 
 
13
 
14
- # Load embedding models
15
- small_embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
16
- large_embedding_model = SentenceTransformer('sentence-transformers/paraphrase-mpnet-base-v2')
 
 
 
17
 
18
- # Load ranking models
19
- small_ranking_model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
20
- small_ranking_tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
21
 
22
- large_ranking_model = AutoModelForSequenceClassification.from_pretrained("nvidia/nv-rerankqa-mistral-4b-v3")
23
- large_ranking_tokenizer = AutoTokenizer.from_pretrained("nvidia/nv-rerankqa-mistral-4b-v3")
24
 
25
- # Streamlit UI
26
- st.title("Multi-Stage Text Retrieval Pipeline for QA")
27
 
28
- # Load dataset
29
- dataset = load_qa_dataset()
 
 
30
 
31
- # Input question
32
- question = st.text_input("Enter a question:")
33
- if question:
34
- passages = dataset['text'][:100] # Limit to 100 passages for demo purposes
 
 
 
 
35
 
36
- # Stage 1: Candidate Retrieval using Embedding Models
37
- st.write("**Stage 1: Candidate Retrieval**")
38
- question_embedding = small_embedding_model.encode(question)
39
- passage_embeddings = small_embedding_model.encode(passages)
 
40
 
41
- # Find top-k similar passages using cosine similarity
42
- top_k = 5
43
- similarities = np.inner(question_embedding, passage_embeddings)
44
- top_k_indices = np.argsort(similarities)[-top_k:][::-1]
45
 
46
- retrieved_passages = [passages[i] for i in top_k_indices]
47
- st.write("Top-k retrieved passages:")
48
- for passage in retrieved_passages:
49
- st.write(passage)
50
 
51
- # Stage 2: Reranking with Ranking Models
52
- st.write("**Stage 2: Reranking**")
53
- inputs = [small_ranking_tokenizer(question, passage, return_tensors='pt', truncation=True, padding=True) for passage in retrieved_passages]
54
- reranked_scores = []
 
 
 
 
55
 
56
- for input_pair in inputs:
57
- with torch.no_grad():
58
- ranking_outputs = small_ranking_model(**input_pair)
59
- score = ranking_outputs.logits.softmax(dim=1)[:, 1].item() # Score for positive class
60
- reranked_scores.append(score)
 
61
 
62
- # Sort passages by ranking scores
63
- ranked_passages = sorted(zip(retrieved_passages, reranked_scores), key=lambda x: x[1], reverse=True)
64
- st.write("Ranked passages by relevance score:")
65
- for passage, score in ranked_passages:
66
- st.write(f"{passage} (Score: {score:.2f})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
 
2
  import numpy as np
3
+ import faiss
4
+ import streamlit as st
5
  from sentence_transformers import SentenceTransformer
6
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
7
+ from beir import util
8
+ from beir.datasets.data_loader import GenericDataLoader
9
+ from beir.evaluation.evaluator import EvaluateRetrieval
10
 
11
+ # Function to load the dataset
12
def load_dataset(dataset_name="nq", out_dir="datasets", split="test"):
    """Download (when necessary) and load a BEIR dataset.

    Args:
        dataset_name: Name of the BEIR dataset archive (without ``.zip``).
        out_dir: Directory the archive is downloaded to and extracted in.
        split: Which qrels split to load.

    Returns:
        The ``(corpus, queries, qrels)`` triple produced by
        ``GenericDataLoader.load``.
    """
    url = (
        "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/"
        f"{dataset_name}.zip"
    )
    # GenericDataLoader needs the *extracted folder*, not the .zip archive.
    # download_and_unzip returns that folder path; per the BEIR utility it
    # skips the download/extraction steps that are already done, so no
    # separate os.path.exists check is needed here.
    data_path = util.download_and_unzip(url, out_dir)

    corpus, queries, qrels = GenericDataLoader(data_path).load(split=split)
    return corpus, queries, qrels
21
 
22
+ # Function for candidate retrieval
23
def candidate_retrieval(corpus, queries, top_k=10):
    """Stage 1: retrieve the nearest corpus passages for every query.

    Passages and queries are embedded with the MiniLM sentence encoder and
    searched with an exact (flat) L2 FAISS index.

    Args:
        corpus: Mapping of passage id -> dict containing a ``"text"`` entry.
        queries: Mapping of query id -> query text.
        top_k: Maximum number of candidates per query.

    Returns:
        ``(retrieved_indices, corpus_ids)`` where ``retrieved_indices`` is a
        2-D integer array (one row per query, in ``queries`` iteration order)
        of positions into ``corpus_ids``.
    """
    corpus_ids = list(corpus)
    query_texts = list(queries.values())

    # Never ask FAISS for more neighbours than there are vectors: the surplus
    # slots come back as -1, which would silently index the *last* corpus id.
    k = max(min(top_k, len(corpus_ids)), 0)
    if k == 0 or not query_texts:
        # Nothing to search; skip loading the model altogether.
        return np.empty((len(query_texts), k), dtype=np.int64), corpus_ids

    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    corpus_texts = [corpus[pid]["text"] for pid in corpus_ids]
    corpus_embeddings = embed_model.encode(corpus_texts, convert_to_numpy=True)

    index = faiss.IndexFlatL2(corpus_embeddings.shape[1])
    index.add(corpus_embeddings)

    query_embeddings = embed_model.encode(query_texts, convert_to_numpy=True)

    _, retrieved_indices = index.search(query_embeddings, k)
    return retrieved_indices, corpus_ids
37
 
38
+ # Function for reranking
39
def rerank_passages(retrieved_indices, corpus, queries, corpus_ids=None):
    """Stage 2: reorder each query's candidates with a cross-encoder.

    Args:
        retrieved_indices: Per-query rows of positions into ``corpus_ids``
            (rows follow ``queries`` iteration order).
        corpus: Mapping of passage id -> dict containing a ``"text"`` entry.
        queries: Mapping of query id -> query text.
        corpus_ids: Passage ids in the order the positions refer to.
            Defaults to ``list(corpus)``, which is the order
            ``candidate_retrieval`` builds its index in.

    Returns:
        One list per query of ``(query, passage_text)`` tuples, sorted from
        highest to lowest cross-encoder score.
    """
    if not queries:
        return []

    # Local import: the module header does not import torch directly.
    import torch

    # The original body read an undefined `corpus_ids`; derive it when the
    # caller does not supply one.
    if corpus_ids is None:
        corpus_ids = list(corpus)

    model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
    cross_encoder_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cross_encoder_model.eval()

    reranked_passages = []
    for i, query in enumerate(queries.values()):
        # Negative positions mark "no result" slots and must not be used
        # as (wrapping) list indices.
        query_passage_pairs = [
            (query, corpus[corpus_ids[idx]]["text"])
            for idx in retrieved_indices[i]
            if idx >= 0
        ]
        if not query_passage_pairs:
            reranked_passages.append([])
            continue

        inputs = tokenizer(
            query_passage_pairs, padding=True, truncation=True, return_tensors="pt"
        )
        # Inference only: no autograd graph, and plain floats for sorting.
        with torch.no_grad():
            scores = cross_encoder_model(**inputs).logits.squeeze(-1).tolist()

        # Sort positions by score (stable, so ties keep retrieval order).
        order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        reranked_passages.append([query_passage_pairs[j] for j in order])

    return reranked_passages
53
+
54
+ # Function for evaluation
55
def _rank_scores(ranked_pids):
    """Turn an ordered id list into {id: score} with scores decreasing by rank."""
    total = len(ranked_pids)
    return {pid: float(total - rank) for rank, pid in enumerate(ranked_pids)}


def evaluate(qrels, retrieved_indices, reranked_passages, queries,
             corpus=None, corpus_ids=None):
    """Compute NDCG@10 for the retrieval stage and the reranking stage.

    Args:
        qrels: Relevance judgements, query id -> {passage id: relevance}.
        retrieved_indices: Per-query rows of positions into ``corpus_ids``.
        reranked_passages: Per-query lists of ``(query, passage_text)``
            tuples in reranked order.
        queries: Mapping of query id -> query text; its iteration order must
            match the rows of the two arguments above.
        corpus: Mapping of passage id -> dict containing a ``"text"`` entry.
            Required to map reranked texts back to passage ids.
        corpus_ids: Passage ids in index order. Defaults to ``list(corpus)``.

    Returns:
        ``(ndcg_stage1, ndcg_stage2)``, the NDCG@10 of each stage.

    Raises:
        ValueError: If ``corpus`` is not provided.
    """
    if corpus is None:
        raise ValueError(
            "evaluate() needs `corpus` to map passages back to document ids"
        )
    if corpus_ids is None:
        corpus_ids = list(corpus)

    results_stage1 = {}
    results_stage2 = {}
    for i, query_id in enumerate(queries):
        candidate_pids = [corpus_ids[idx] for idx in retrieved_indices[i] if idx >= 0]
        # Scores must encode the ranking; a constant score for every
        # candidate would leave the order undefined for the metric.
        results_stage1[query_id] = _rank_scores(candidate_pids)

        # Resolve reranked texts only among this query's own candidates
        # instead of scanning the whole corpus for every passage.
        text_to_pids = {}
        for pid in candidate_pids:
            text_to_pids.setdefault(corpus[pid]["text"], []).append(pid)

        reranked_pids = []
        for _, text in reranked_passages[i]:
            matching = text_to_pids.get(text)
            if matching:
                # pop so duplicate texts map to distinct passage ids
                reranked_pids.append(matching.pop(0))
        results_stage2[query_id] = _rank_scores(reranked_pids)

    # EvaluateRetrieval.evaluate returns (ndcg, map, recall, precision).
    ndcg_stage1, _, _, _ = EvaluateRetrieval.evaluate(qrels, results_stage1, [10])
    ndcg_stage2, _, _, _ = EvaluateRetrieval.evaluate(qrels, results_stage2, [10])
    return ndcg_stage1["NDCG@10"], ndcg_stage2["NDCG@10"]


# Streamlit app
def main():
    """Render the UI and run each pipeline stage on demand.

    Streamlit re-executes the script on every interaction, so the outputs of
    each stage are kept in ``st.session_state`` rather than local variables;
    otherwise a later button would see none of the earlier results.
    """
    st.title("Multi-Stage Text Retrieval Pipeline")
    state = st.session_state

    if st.button("Load Dataset"):
        state["corpus"], state["queries"], state["qrels"] = load_dataset()
        st.success("Dataset loaded successfully!")

    if st.button("Run Candidate Retrieval"):
        if "corpus" not in state:
            st.warning("Load the dataset first.")
        else:
            state["retrieved_indices"], state["corpus_ids"] = candidate_retrieval(
                state["corpus"], state["queries"]
            )
            st.success("Candidate retrieval completed!")
            st.write("Retrieved indices:", state["retrieved_indices"])

    if st.button("Run Reranking"):
        if "retrieved_indices" not in state:
            st.warning("Run candidate retrieval first.")
        else:
            state["reranked_passages"] = rerank_passages(
                state["retrieved_indices"], state["corpus"], state["queries"]
            )
            st.success("Reranking completed!")
            st.write("Reranked passages:", state["reranked_passages"])

    if st.button("Evaluate"):
        if "reranked_passages" not in state:
            st.warning("Run reranking first.")
        else:
            ndcg_score_stage1, ndcg_score_stage2 = evaluate(
                state["qrels"],
                state["retrieved_indices"],
                state["reranked_passages"],
                state["queries"],
                corpus=state["corpus"],
                corpus_ids=state["corpus_ids"],
            )
            st.write(f"NDCG@10 for Stage 1 (Candidate Retrieval): {ndcg_score_stage1}")
            st.write(f"NDCG@10 for Stage 2 (Reranking): {ndcg_score_stage2}")


if __name__ == "__main__":
    main()