Document Search Engine
Browse files
@@ -3,6 +3,8 @@ from transformers import pipeline
3 |
import re
4 |
from collections import Counter
5 |
import string
6 |
7 |
8 |
def load_qa_pipeline():
@@ -21,68 +23,91 @@ def normalize_answer(s):
21 |
return text.lower()
22 |
return white_space_fix(remove_articles(remove_punc(lower(s))))
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
def main():
40 |
41 |
42 |
# Load the QA pipeline
43 |
qa_pipeline = load_qa_pipeline()
44 |
45 |
46 |
47 |
48 |
49 |
50 |
question = st.text_input("Enter your question:")
51 |
question = question.strip() # Remove leading/trailing whitespace
52 |
53 |
54 |
55 |
56 |
57 |
58 |
st.warning("Question should not exceed 150 characters.")
59 |
60 |
61 |
62 |
63 |
64 |
65 |
actual_answer = st.text_input("Enter the actual answer:")
66 |
67 |
if st.button("Get Answer"):
68 |
if context and question:
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
st.warning("Please provide both context and
86 |
87 |
if __name__ == "__main__":
88 |
3 |
import re
4 |
from collections import Counter
5 |
import string
6 |
import docx2txt
7 |
from io import BytesIO
8 |
9 |
10 |
def load_qa_pipeline():
23 |
return text.lower()
24 |
return white_space_fix(remove_articles(remove_punc(lower(s))))
25 |
26 |
def chunk_text(text, chunk_size=1000):
    """Split *text* into chunks of whole sentences.

    Sentences are detected by splitting on whitespace that follows ``.``,
    ``!`` or ``?``. Consecutive sentences are packed into the same chunk
    while the running length (each sentence plus one separating space)
    stays within ``chunk_size``.

    Args:
        text: The text to split.
        chunk_size: Soft upper bound on the length of a chunk, in characters.

    Returns:
        A list of non-blank chunk strings, in document order. A single
        sentence longer than ``chunk_size`` is not split and becomes a chunk
        of its own, so a chunk may exceed ``chunk_size``. Blank input yields
        an empty list.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_sentences = []
    current_length = 0

    for sentence in sentences:
        # re.split returns [""] for empty input; never emit blank chunks.
        if not sentence.strip():
            continue

        if current_length + len(sentence) <= chunk_size:
            current_sentences.append(sentence)
            current_length += len(sentence) + 1  # +1 for the joining space
            continue

        # The sentence does not fit: close the chunk in progress (if there is
        # one) and start a new chunk with this sentence.
        if current_sentences:
            chunks.append(" ".join(current_sentences))
        current_sentences = [sentence]
        current_length = len(sentence) + 1

    if current_sentences:
        chunks.append(" ".join(current_sentences))

    return chunks
42 |
43 |
def highlight_text(text, start_indices, chunk_size, highlight_length=10):
    """Wrap short spans of *text* in ``<mark>`` tags.

    Each entry of ``start_indices`` is treated as an absolute character
    offset into ``text``. Only start offsets are supplied, so every span has
    the same fixed length (``highlight_length`` characters, clamped to the
    end of the text).

    Offsets may be given in any order. Spans are resolved against the
    original text and the result is assembled in a single pass, so inserting
    one tag never shifts the position of another. Overlapping or touching
    spans are merged into one ``<mark>`` element.

    Args:
        text: The text to annotate.
        start_indices: Iterable of integer start offsets into ``text``.
            Offsets outside ``[0, len(text))`` are ignored.
        chunk_size: Accepted for backward compatibility and not used; the
            offsets are already absolute, so no per-chunk adjustment is made.
        highlight_length: Number of characters to highlight from each start.

    Returns:
        ``text`` with ``<mark>...</mark>`` inserted around each span. The
        text itself is not HTML-escaped.
    """
    if highlight_length <= 0:
        return text

    text_length = len(text)

    # Resolve every span against the ORIGINAL text, in ascending order.
    spans = sorted(
        (start, min(start + highlight_length, text_length))
        for start in start_indices
        if 0 <= start < text_length
    )

    # Merge overlapping/touching spans so the emitted tags never interleave.
    merged = []
    for start, end in spans:
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])

    pieces = []
    cursor = 0
    for start, end in merged:
        pieces.append(text[cursor:start])
        pieces.append("<mark>" + text[start:end] + "</mark>")
        cursor = end
    pieces.append(text[cursor:])

    return "".join(pieces)
59 |
60 |
def main():
61 |
st.title("Document Search Engine")
62 |
63 |
# Load the QA pipeline
64 |
qa_pipeline = load_qa_pipeline()
65 |
66 |
# File upload for Word documents
67 |
uploaded_file = st.file_uploader("Upload a Word document", type=['docx'])
68 |
if uploaded_file is not None:
69 |
doc_text = docx2txt.process(BytesIO(
70 |
st.session_state['context'] = doc_text
71 |
72 |
# Context input
73 |
if 'context' not in st.session_state:
74 |
st.session_state['context'] = ""
75 |
context = st.text_area("Enter or edit the context:", value=st.session_state['context'], height=300)
76 |
st.session_state['context'] = context
77 |
78 |
# Search input and button
79 |
col1, col2 = st.columns([3, 1])
80 |
with col1:
81 |
question = st.text_input("Enter your search query:")
82 |
with col2:
83 |
search_button = st.button("Search")
84 |
85 |
if search_button:
86 |
if context and question:
87 |
chunks = chunk_text(context)
88 |
results = []
89 |
for i, chunk in enumerate(chunks):
90 |
result = qa_pipeline(question=question, context=chunk)
91 |
result['chunk_index'] = i
92 |
93 |
94 |
# Sort results by score and get top 3
95 |
top_results = sorted(results, key=lambda x: x['score'], reverse=True)[:3]
96 |
97 |
st.subheader("Top 3 Results:")
98 |
for i, result in enumerate(top_results, 1):
99 |
st.write(f"{i}. Answer: {result['answer']}")
100 |
st.write(f" Confidence: {result['score']:.2f}")
101 |
102 |
# Highlight answers in the context
103 |
chunk_size = 1000 # Make sure this matches the chunk_size in chunk_text function
104 |
start_indices = [result['start'] + (result['chunk_index'] * chunk_size) for result in top_results]
105 |
highlighted_context = highlight_text(context, start_indices, chunk_size)
106 |
107 |
st.subheader("Context with Highlighted Answers:")
108 |
st.markdown(highlighted_context, unsafe_allow_html=True)
109 |
110 |
st.warning("Please provide both context and search query.")
111 |
112 |
if __name__ == "__main__":
113 |