tareeb23 committed on
Commit 7cf5172 · verified · 1 Parent(s): 824e52b

Create app.py

Files changed (1)
app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
+ import re
+ import string
+ from collections import Counter
+
+ import streamlit as st
+ from transformers import pipeline
+
+ @st.cache_resource
+ def load_qa_pipeline():
+     # Cache the pipeline so the model is loaded only once per session
+     return pipeline("question-answering", model="tareeb23/Roberta_SQUAD_V2")
+
+ def normalize_answer(s):
+     """Lower text and remove punctuation, articles and extra whitespace."""
+     def remove_articles(text):
+         return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+     def white_space_fix(text):
+         return ' '.join(text.split())
+
+     def remove_punc(text):
+         exclude = set(string.punctuation)
+         return ''.join(ch for ch in text if ch not in exclude)
+
+     def lower(text):
+         return text.lower()
+
+     return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+ def compute_exact_match(prediction, ground_truth):
+     return int(normalize_answer(prediction) == normalize_answer(ground_truth))
+
+ def compute_f1(prediction, ground_truth):
+     # Token-level F1 over the normalized answers, as in SQuAD evaluation
+     prediction_tokens = normalize_answer(prediction).split()
+     ground_truth_tokens = normalize_answer(ground_truth).split()
+     common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+     num_same = sum(common.values())
+     if num_same == 0:
+         return 0.0
+     precision = num_same / len(prediction_tokens)
+     recall = num_same / len(ground_truth_tokens)
+     return (2 * precision * recall) / (precision + recall)
+
+ def main():
+     st.title("Question Answering with RoBERTa")
+
+     # Load the QA pipeline
+     qa_pipeline = load_qa_pipeline()
+
+     # User input for context and question
+     context = st.text_area("Enter the context:", height=200)
+     question = st.text_input("Enter your question:")
+
+     if st.button("Get Answer"):
+         if context and question:
+             result = qa_pipeline(question=question, context=context)
+             # Store the result in session state so it survives the
+             # reruns triggered by the widgets below
+             st.session_state.last_answer = result['answer']
+             st.session_state.last_score = result['score']
+         else:
+             st.warning("Please provide both context and question.")
+
+     if 'last_answer' in st.session_state:
+         # Display the most recent result
+         st.subheader("Answer:")
+         st.write(st.session_state.last_answer)
+         st.write(f"Confidence: {st.session_state.last_score:.2f}")
+
+         # Option to score the prediction against a reference answer
+         st.subheader("Calculate Scores")
+         if st.checkbox("Show score calculation"):
+             actual_answer = st.text_input("Enter the actual answer:")
+             if st.button("Calculate Scores"):
+                 if actual_answer:
+                     em_score = compute_exact_match(st.session_state.last_answer, actual_answer)
+                     f1_score = compute_f1(st.session_state.last_answer, actual_answer)
+                     st.subheader("Scores:")
+                     st.write(f"Exact Match: {em_score}")
+                     st.write(f"F1 Score: {f1_score:.4f}")
+                 else:
+                     st.warning("Please enter the actual answer.")
+
+ if __name__ == "__main__":
+     main()
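
For a quick sanity check of the SQuAD-style metrics above, the helpers can be exercised from a Python shell. This is a minimal sketch, assuming the file is saved as app.py and is importable from the working directory:

    from app import compute_exact_match, compute_f1

    # normalize_answer lowercases and strips the article "The",
    # so these two strings normalize to the same text
    print(compute_exact_match("The Eiffel Tower", "eiffel tower"))  # 1

    # one of two predicted tokens overlaps the single reference token:
    # precision 0.5, recall 1.0, F1 = 2 * 0.5 * 1.0 / 1.5
    print(compute_f1("eiffel tower", "eiffel"))  # 0.6666...

The app itself is launched with streamlit run app.py.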