# Document Search Engine — Streamlit app (Hugging Face Space, rev 1b5c287)
import streamlit as st
from transformers import pipeline
import re
from collections import Counter
import string
import docx2txt
from io import BytesIO
@st.cache_resource
def load_qa_pipeline():
    """Build the extractive question-answering pipeline once per process.

    Streamlit's resource cache keeps the loaded model alive across reruns,
    so the (slow) model download/initialization happens only on first use.
    """
    qa = pipeline("question-answering", model="tareeb23/Roberta_SQUAD_V2")
    return qa
def normalize_answer(s):
    """Normalize an answer string SQuAD-style for comparison.

    Lowercases, strips all punctuation, removes the articles a/an/the,
    and collapses runs of whitespace to single spaces.
    """
    punctuation = set(string.punctuation)
    text = s.lower()
    text = ''.join(ch for ch in text if ch not in punctuation)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())
def chunk_text(text, chunk_size=1000):
    """Greedily pack sentences into chunks of at most ~chunk_size chars.

    Sentences are split on whitespace following '.', '!' or '?'. Each chunk
    is flushed (stripped of the trailing separator) once the next sentence
    would push it past chunk_size; the remainder is flushed at the end.
    """
    chunks = []
    buffer = ""
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        # Flush the current chunk before this sentence would overflow it.
        if len(buffer) + len(sentence) > chunk_size:
            chunks.append(buffer.strip())
            buffer = ""
        buffer += sentence + " "
    if buffer:
        chunks.append(buffer.strip())
    return chunks
def highlight_text(text, start_indices, chunk_size):
    """Wrap a fixed 10-character window at each start index in <mark> tags.

    Args:
        text: the full document text.
        start_indices: absolute character offsets into `text`. The caller is
            expected to have already converted chunk-relative answer offsets
            to absolute ones (as `main` does); they are NOT re-adjusted here.
        chunk_size: retained for interface compatibility only — offsets are
            treated as absolute, so no per-chunk adjustment is applied.

    Returns:
        `text` with "<mark>"/"</mark>" inserted around each highlighted span.

    Fixes over the previous version:
      * the shift caused by previously inserted tags was double-counted
        (both `i * 7` and the running offset, 20 chars/insert instead of
        the actual 13), corrupting every highlight after the first;
      * the chunk offset was added a second time even though the caller
        already passes absolute indices;
      * indices are now processed in ascending order so the running offset
        stays valid, and out-of-range indices are skipped.
    """
    HIGHLIGHT_LEN = 10
    TAG_OPEN, TAG_CLOSE = "<mark>", "</mark>"
    result = text
    offset = 0  # cumulative shift introduced by previously inserted tags
    for start in sorted(start_indices):
        if not 0 <= start < len(text):
            continue  # ignore bad offsets rather than corrupt the output
        pos = start + offset
        result = (
            result[:pos]
            + TAG_OPEN
            + result[pos:pos + HIGHLIGHT_LEN]
            + TAG_CLOSE
            + result[pos + HIGHLIGHT_LEN:]
        )
        offset += len(TAG_OPEN) + len(TAG_CLOSE)  # 13 chars per highlight
    return result
def main():
    """Streamlit entry point: upload or edit a document, run an extractive
    QA search over it, and display the top answers with highlighted context.
    """
    st.title("Document Search Engine")
    # Load the QA pipeline (cached via st.cache_resource, so cheap on reruns)
    qa_pipeline = load_qa_pipeline()
    # File upload for Word documents; extracted text seeds the context box
    uploaded_file = st.file_uploader("Upload a Word document", type=['docx'])
    if uploaded_file is not None:
        doc_text = docx2txt.process(BytesIO(uploaded_file.read()))
        st.session_state['context'] = doc_text
    # Context input — initialize session state on the first run
    if 'context' not in st.session_state:
        st.session_state['context'] = ""
    context = st.text_area("Enter or edit the context:", value=st.session_state['context'], height=300)
    st.session_state['context'] = context
    # Search query input and button, laid out side by side
    col1, col2 = st.columns([3, 1])
    with col1:
        question = st.text_input("Enter your search query:")
    with col2:
        search_button = st.button("Search")
    if search_button:
        if context and question:
            # Run QA on each chunk separately (long documents exceed the
            # model's input window); record which chunk each answer came
            # from so its offset can be converted to an absolute position.
            chunks = chunk_text(context)
            results = []
            for i, chunk in enumerate(chunks):
                result = qa_pipeline(question=question, context=chunk)
                result['chunk_index'] = i
                results.append(result)
            # Sort results by score and get top 3
            top_results = sorted(results, key=lambda x: x['score'], reverse=True)[:3]
            st.subheader("Top 3 Results:")
            for i, result in enumerate(top_results, 1):
                st.write(f"{i}. Answer: {result['answer']}")
                st.write(f" Confidence: {result['score']:.2f}")
            # Highlight answers in the context
            chunk_size = 1000  # Make sure this matches the chunk_size in chunk_text function
            # NOTE(review): this offset math assumes every chunk is exactly
            # chunk_size characters, but chunk_text packs whole sentences so
            # chunks are usually shorter — computed positions can drift from
            # the true answer location; verify against chunk lengths.
            start_indices = [result['start'] + (result['chunk_index'] * chunk_size) for result in top_results]
            highlighted_context = highlight_text(context, start_indices, chunk_size)
            st.subheader("Context with Highlighted Answers:")
            st.markdown(highlighted_context, unsafe_allow_html=True)
        else:
            st.warning("Please provide both context and search query.")
# Script entry point (launched via `streamlit run <file>`).
if __name__ == "__main__":
    main()