Hyma7 committed on
Commit
043f795
·
verified ·
1 Parent(s): 9fe5fd9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -70
app.py CHANGED
@@ -1,70 +1,37 @@
1
- import streamlit as st
2
- from data_preparation import load_dataset
3
- from retrieval import load_embedding_model, retrieve_top_k
4
- from reranking import load_ranking_model, rerank
5
- from evaluation import evaluate_ndcg
6
-
7
- # Set up the Streamlit interface
8
- st.title("Multi-Stage Text Retrieval Pipeline for QA")
9
-
10
- # Query Input
11
- query = st.text_input("Enter a question:", "What is the capital of France?")
12
-
13
- # Embedding model selection
14
- embedding_model = st.selectbox(
15
- "Select Embedding Model for Candidate Retrieval",
16
- ["sentence-transformers/all-MiniLM-L6-v2", "nvidia/nv-embedqa-e5-v5"]
17
- )
18
-
19
- # Ranking model selection
20
- ranking_model = st.selectbox(
21
- "Select Ranking Model for Re-Ranking",
22
- ["cross-encoder/ms-marco-MiniLM-L-12-v2", "nvidia/nv-rerankqa-mistral-4b-v3"]
23
- )
24
-
25
- # Run retrieval pipeline on button click
26
- if st.button("Run Retrieval"):
27
- # Load dataset
28
- st.write("Loading dataset...")
29
- corpus, queries, qrels = load_dataset("nq")
30
-
31
- # Load selected embedding model
32
- st.write(f"Loading embedding model: {embedding_model}...")
33
- embed_model = load_embedding_model(embedding_model)
34
-
35
- # Retrieve top-k passages using embedding model
36
- st.write("Retrieving top-k passages...")
37
- top_k_passages = retrieve_top_k(embed_model, query, corpus, k=10)
38
-
39
- # Display retrieved passages
40
- st.write("Top-k passages before reranking:")
41
- for i, (passage, score) in enumerate(top_k_passages):
42
- st.write(f"{i+1}. Passage: {passage}, Score: {score:.4f}")
43
-
44
- # Load selected ranking model
45
- st.write(f"Loading ranking model: {ranking_model}...")
46
- rank_model, rank_tokenizer = load_ranking_model(ranking_model)
47
-
48
- # Rerank the retrieved passages
49
- st.write("Reranking passages...")
50
- ranked_passages = rerank(rank_model, rank_tokenizer, query, top_k_passages)
51
-
52
- # Display reranked passages
53
- st.write("Top-k passages after reranking:")
54
- for i, (passage, score) in enumerate(ranked_passages):
55
- st.write(f"{i+1}. Passage: {passage}, Score: {score:.4f}")
56
-
57
- # Evaluate using NDCG@10
58
- st.write("Evaluating NDCG@10...")
59
- query_id = list(queries.keys())[0] # Assuming we are using the first query for evaluation
60
- ndcg_score = evaluate_ndcg(ranked_passages, qrels[query_id])
61
- st.write(f"NDCG@10: {ndcg_score:.4f}")
62
-
63
- # Sidebar with instructions
64
- st.sidebar.title("Instructions")
65
- st.sidebar.write("""
66
- 1. Enter a question in the text input.
67
- 2. Select the embedding model for candidate retrieval.
68
- 3. Select the ranking model for reranking the retrieved passages.
69
- 4. Click 'Run Retrieval' to start the pipeline and display the results.
70
- """)
 
1
+ import streamlit as st
2
+ from transformers import AutoModelForQuestionAnswering, AutoTokenizer
3
+ import torch
4
+
5
+ # Load model and tokenizer
6
# Raw string on purpose: in a normal string literal the backslash-U in
# "C:\Users" is parsed as the start of a 32-bit unicode escape (a SyntaxError),
# and backslash-n / backslash-b would silently become newline / backspace.
# NOTE(review): this is a machine-specific Windows path; pass ``model_path``
# explicitly when running anywhere else.
_DEFAULT_MODEL_PATH = r'C:\Users\neeli\Downloads\bert-tensorflow2-uncased-tf2-qa-v1'


@st.cache_resource
def load_model(model_path=_DEFAULT_MODEL_PATH):
    """Load the question-answering model and its tokenizer.

    Args:
        model_path: Directory or model identifier handed to
            ``from_pretrained`` for both the model and the tokenizer.
            Defaults to the checkpoint directory the app was written against.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # NOTE(review): the directory name suggests a TensorFlow checkpoint; if it
    # holds no PyTorch weights, from_pretrained may need from_tf=True - confirm.
    model = AutoModelForQuestionAnswering.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer


# load_model is wrapped in st.cache_resource, so Streamlit re-executing this
# script on each interaction reuses the already-loaded objects.
model, tokenizer = load_model()
14
+
15
+ # Function to get answer from question and context
16
def get_answer(question, context):
    """Extract the answer to ``question`` from ``context``.

    Uses the module-level ``model`` and ``tokenizer``. The answer is the token
    span running from the highest-scoring start position to the highest-scoring
    end position, decoded back to text.

    Args:
        question: The question text.
        context: The passage in which to look for the answer.

    Returns:
        The decoded answer string, or ``''`` when the predicted span is empty
        or inverted (end before start).
    """
    # Truncate only the context (the second sequence) so an over-long passage
    # cannot exceed the number of positions the model supports; the question
    # is kept intact.
    inputs = tokenizer(
        question,
        context,
        return_tensors='pt',
        truncation='only_second',
        max_length=getattr(model.config, 'max_position_embeddings', 512),
    )
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)

    answer_start = int(torch.argmax(outputs.start_logits))
    # +1 because the slice end below is exclusive.
    answer_end = int(torch.argmax(outputs.end_logits)) + 1
    if answer_end <= answer_start:
        # Start and end are chosen independently, so the span can be inverted.
        return ''

    answer_ids = inputs['input_ids'][0][answer_start:answer_end]
    # skip_special_tokens keeps markers such as [CLS]/[SEP] out of the answer.
    return tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
25
+
26
+ # Streamlit UI
27
st.title("Question Answering Application")

question = st.text_input("Enter your question:")
context = st.text_area("Enter context text:", height=200)

if st.button("Get Answer"):
    # Whitespace-only input counts as missing, so the model is never run on
    # blank text.
    if question.strip() and context.strip():
        answer = get_answer(question, context)
        if answer.strip():
            st.write(f"**Answer:** {answer}")
        else:
            # An empty span would otherwise render as a bare "Answer:" line.
            st.warning("No answer could be found in the provided context.")
    else:
        st.warning("Please enter both a question and context.")