tareeb23 committed on
Commit
1b5c287
·
verified ·
1 Parent(s): 806829b

Document Search Engine

Browse files
Files changed (1) hide show
  1. app.py +74 -49
app.py CHANGED
@@ -3,6 +3,8 @@ from transformers import pipeline
3
  import re
4
  from collections import Counter
5
  import string
 
 
6
 
7
  @st.cache_resource
8
  def load_qa_pipeline():
@@ -21,68 +23,91 @@ def normalize_answer(s):
21
  return text.lower()
22
  return white_space_fix(remove_articles(remove_punc(lower(s))))
23
 
24
def compute_exact_match(prediction, ground_truth):
    """Return 1 when the normalized prediction equals the normalized
    ground truth, otherwise 0 (SQuAD-style exact-match score)."""
    norm_pred = normalize_answer(prediction)
    norm_gold = normalize_answer(ground_truth)
    return 1 if norm_pred == norm_gold else 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
def compute_f1(prediction, ground_truth):
    """Token-level F1 between normalized prediction and ground truth
    (SQuAD-style): harmonic mean of token precision and recall.

    Returns 0 when the answers share no tokens.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()

    # Multiset intersection counts each shared token at most as often
    # as it appears in both answers.
    overlap = Counter(pred_tokens) & Counter(gold_tokens)
    shared = sum(overlap.values())
    if shared == 0:
        return 0

    precision = shared / len(pred_tokens)
    recall = shared / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)
 
 
 
 
 
38
 
39
def main():
    """Streamlit QA demo: take a context and a question, show the model's
    answer, and optionally compute EM/F1 against a user-supplied answer."""
    st.title("Question Answering with RoBERTa")

    # Load the (cached) extractive-QA pipeline.
    qa_pipeline = load_qa_pipeline()

    # Inputs, with surrounding whitespace removed.
    context = st.text_area("Enter the context:", height=200).strip()
    question = st.text_input("Enter your question:").strip()

    # Guard clauses: enforce the input-length limits before doing any work.
    if len(context) > 1500:
        st.warning("Context should not exceed 1500 characters.")
        return
    if len(question) > 150:
        st.warning("Question should not exceed 150 characters.")
        return

    # Optional scoring against a reference answer.
    calculate_scores = st.checkbox("Calculate scores")
    if calculate_scores:
        actual_answer = st.text_input("Enter the actual answer:")

    if st.button("Get Answer"):
        if context and question:
            result = qa_pipeline(question=question, context=context)

            st.subheader("Answer:")
            st.write(result['answer'])
            st.write(f"Confidence: {result['score']:.2f}")

            # Only score when the user both opted in and typed a reference.
            if calculate_scores and actual_answer:
                em_score = compute_exact_match(result['answer'], actual_answer)
                f1_score = compute_f1(result['answer'], actual_answer)
                st.subheader("Scores:")
                st.write(f"Exact Match: {em_score}")
                st.write(f"F1 Score: {f1_score:.4f}")
        else:
            st.warning("Please provide both context and question.")

if __name__ == "__main__":
    main()
 
3
  import re
4
  from collections import Counter
5
  import string
6
+ import docx2txt
7
+ from io import BytesIO
8
 
9
  @st.cache_resource
10
  def load_qa_pipeline():
 
23
  return text.lower()
24
  return white_space_fix(remove_articles(remove_punc(lower(s))))
25
 
26
def chunk_text(text, chunk_size=1000):
    """Split *text* into sentence-aligned chunks of at most ~chunk_size chars.

    Sentences are detected by splitting after ., ! or ? followed by
    whitespace. A single sentence longer than chunk_size becomes its own
    (oversized) chunk rather than being cut mid-sentence.

    Args:
        text: the document to split.
        chunk_size: soft upper bound on chunk length in characters.

    Returns:
        list[str] of non-empty, stripped chunks (empty list for empty input).
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = ""

    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= chunk_size:
            current_chunk += sentence + " "
        else:
            # Bug fix: when the very first sentence already exceeds
            # chunk_size, current_chunk is empty — the original appended
            # "" as a chunk here. Only flush non-empty chunks.
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " "

    if current_chunk.strip():
        chunks.append(current_chunk.strip())

    return chunks
42
 
43
def highlight_text(text, start_indices, chunk_size):
    """Wrap a fixed 10-character window at each start index in <mark> tags.

    Args:
        text: the full document the indices refer to.
        start_indices: absolute character offsets into *text* (the callers
            already add the chunk offset before calling).
        chunk_size: unused; kept for backward compatibility with existing
            callers.

    Returns:
        *text* with "<mark>...</mark>" inserted around each window.

    Bug fixes vs. the original:
    - removed the extra `i * 7` shift — prior insertions are already
      accounted for by the running offset (and each <mark></mark> pair is
      13 characters, not 7);
    - removed the `chunk_index * chunk_size` re-adjustment — callers pass
      absolute offsets, so adding the chunk offset again pushed later
      highlights far past the answer;
    - indices are processed in ascending order so the running offset is
      valid for every insertion.
    NOTE: the highlight width stays a fixed 10 characters because callers
    supply only start offsets, not answer lengths.
    """
    HIGHLIGHT_LEN = 10
    TAG_OPEN, TAG_CLOSE = "<mark>", "</mark>"

    highlighted_text = text
    offset = 0  # total characters inserted so far
    for start in sorted(start_indices):
        pos = start + offset
        highlighted_text = (
            highlighted_text[:pos] +
            TAG_OPEN +
            highlighted_text[pos:pos + HIGHLIGHT_LEN] +
            TAG_CLOSE +
            highlighted_text[pos + HIGHLIGHT_LEN:]
        )
        offset += len(TAG_OPEN) + len(TAG_CLOSE)
    return highlighted_text
59
 
60
def main():
    """Streamlit document-search app: upload or paste a document, run
    extractive QA over sentence-aligned chunks, show the top 3 answers and
    highlight them in the context."""
    st.title("Document Search Engine")

    # Load the (cached) extractive-QA pipeline.
    qa_pipeline = load_qa_pipeline()

    # File upload for Word documents; extracted text seeds the context box.
    uploaded_file = st.file_uploader("Upload a Word document", type=['docx'])
    if uploaded_file is not None:
        doc_text = docx2txt.process(BytesIO(uploaded_file.read()))
        st.session_state['context'] = doc_text

    # Context input, persisted in session state across Streamlit reruns.
    if 'context' not in st.session_state:
        st.session_state['context'] = ""
    context = st.text_area("Enter or edit the context:", value=st.session_state['context'], height=300)
    st.session_state['context'] = context

    # Search input and button side by side.
    col1, col2 = st.columns([3, 1])
    with col1:
        question = st.text_input("Enter your search query:")
    with col2:
        search_button = st.button("Search")

    if search_button:
        if context and question:
            chunk_size = 1000  # must match the default in chunk_text
            chunks = chunk_text(context, chunk_size)

            # Bug fix: chunks are sentence-aligned and almost never exactly
            # chunk_size characters long, so `chunk_index * chunk_size` is
            # the wrong offset. Locate each chunk's true start position in
            # the context instead (scanning forward so repeated text maps
            # to the right occurrence).
            chunk_offsets = []
            search_pos = 0
            for chunk in chunks:
                found = context.find(chunk, search_pos)
                if found == -1:
                    # Fallback: keep scanning position if the stripped
                    # chunk can't be matched (shouldn't normally happen).
                    found = search_pos
                chunk_offsets.append(found)
                search_pos = found + len(chunk)

            # Run the QA model over every chunk, remembering its index.
            results = []
            for i, chunk in enumerate(chunks):
                result = qa_pipeline(question=question, context=chunk)
                result['chunk_index'] = i
                results.append(result)

            # Sort results by score and keep the top 3.
            top_results = sorted(results, key=lambda x: x['score'], reverse=True)[:3]

            st.subheader("Top 3 Results:")
            for i, result in enumerate(top_results, 1):
                st.write(f"{i}. Answer: {result['answer']}")
                st.write(f"   Confidence: {result['score']:.2f}")

            # Translate each answer's chunk-relative start into an absolute
            # offset in the full context, then highlight.
            start_indices = [result['start'] + chunk_offsets[result['chunk_index']] for result in top_results]
            highlighted_context = highlight_text(context, start_indices, chunk_size)

            st.subheader("Context with Highlighted Answers:")
            st.markdown(highlighted_context, unsafe_allow_html=True)
        else:
            st.warning("Please provide both context and search query.")

if __name__ == "__main__":
    main()