Hyma7 committed
Commit 7847e6f · verified · 1 Parent(s): acadc69

Update app.py

Files changed (1)
  1. app.py +32 -31
app.py CHANGED
@@ -1,45 +1,46 @@
  import streamlit as st
- import pandas as pd
- from datasets import load_dataset
+ import numpy as np
  from sentence_transformers import SentenceTransformer
- from transformers import pipeline
+ from transformers import CrossEncoder

- # Load a subset of the Natural Questions dataset for testing
- def load_nq_dataset():
-     return load_dataset("nq", split='train[:1%]')  # Load a small subset for testing
+ # Sample passages
+ passages = [
+     "The sky is blue.",
+     "The grass is green.",
+     "The sun is bright.",
+     "Rain falls from the sky.",
+     "Flowers bloom in spring."
+ ]

  # Load models
- @st.cache_resource
- def load_models():
-     embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
-     ranking_model = pipeline("text-classification", model="cross-encoder/ms-marco-MiniLM-L-12-v2")
-     return embedding_model, ranking_model
+ embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+ ranking_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
+
+ def get_relevant_passages(question, passages):
+     keywords = question.lower().split()
+     relevant_passages = [p for p in passages if any(keyword in p.lower() for keyword in keywords)]
+     return relevant_passages if relevant_passages else passages  # Return all if no match

- # Main function for the Streamlit app
  def main():
      st.title("Multi-Stage Text Retrieval Pipeline for QA")
-
-     # Load dataset
-     dataset = load_nq_dataset()
-
-     # User input
      question = st.text_input("Enter a question:")

      if question:
-         # Load models
-         embedding_model, ranking_model = load_models()
-
-         # Retrieve passages (mock implementation, replace with actual logic)
-         top_k_passages = dataset['context'][:5]  # Replace with actual retrieval logic
-         embeddings = embedding_model.encode(top_k_passages)
-
-         # Re-rank passages
-         ranked_passages = ranking_model([question + " " + passage for passage in top_k_passages])
-
-         # Display results
-         st.write("Top Retrieved Passages:")
-         for i, (passage, score) in enumerate(zip(top_k_passages, ranked_passages)):
-             st.write(f"{i + 1}: {passage} (Score: {score['score']:.2f})")
+         relevant_passages = get_relevant_passages(question, passages)
+         st.write("Relevant passages:")
+         for p in relevant_passages:
+             st.write(f"- {p}")
+
+         # Embedding and ranking
+         if st.button("Retrieve Answers"):
+             passage_embeddings = embedding_model.encode(relevant_passages)
+             question_embedding = embedding_model.encode(question)
+             scores = np.dot(passage_embeddings, question_embedding.T)
+             ranked_indices = np.argsort(scores)[::-1]
+
+             st.write("Ranked passages:")
+             for idx in ranked_indices:
+                 st.write(f"- {relevant_passages[idx]} (Score: {scores[idx]:.2f})")

  if __name__ == "__main__":
      main()
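
Note on the new version: CrossEncoder is provided by sentence_transformers rather than transformers, and the loaded ranking_model is never called, so ranking relies only on the bi-encoder dot product. Below is a minimal sketch (not part of the commit) of how the cross-encoder could supply the second re-ranking stage; the helper retrieve_and_rerank and its top_k parameter are illustrative assumptions, not code from the commit.

# Minimal sketch: two-stage retrieval with the same model names as the commit.
# Note: CrossEncoder is imported from sentence_transformers, not transformers.
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder

embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
ranking_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

def retrieve_and_rerank(question, passages, top_k=3):
    # Stage 1: bi-encoder retrieval, scored by dot-product similarity.
    passage_embeddings = embedding_model.encode(passages)
    question_embedding = embedding_model.encode(question)
    dense_scores = np.dot(passage_embeddings, question_embedding)
    candidate_indices = np.argsort(dense_scores)[::-1][:top_k]
    candidates = [passages[i] for i in candidate_indices]

    # Stage 2: cross-encoder re-scoring of (question, passage) pairs.
    rerank_scores = ranking_model.predict([(question, p) for p in candidates])
    order = np.argsort(rerank_scores)[::-1]
    return [(candidates[i], float(rerank_scores[i])) for i in order]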