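"""Streamlit document search engine.

Upload a Word (.docx) document or paste text, split it into sentence-aligned
chunks, run an extractive question-answering model (tareeb23/Roberta_SQUAD_V2)
over each chunk, and highlight the top answers in the original context.

Run with: streamlit run <this_file>.py
"""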
import re
import string
from io import BytesIO

import docx2txt
import streamlit as st
from transformers import pipeline


@st.cache_resource  # cache the model so it is loaded only once per session
def load_qa_pipeline():
    return pipeline("question-answering", model="tareeb23/Roberta_SQUAD_V2")
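

# Standard SQuAD-style answer normalization; defined as a helper for
# comparing answers, though the app does not currently call it.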
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def chunk_text(text, chunk_size=1000):
    """Split text into sentence-aligned chunks of roughly chunk_size chars.

    Returns (chunk, offset) pairs, where offset is the chunk's start index
    in the original text, so answer spans found inside a chunk can be mapped
    back for highlighting. An oversized single sentence stays in one chunk.
    """
    chunks = []
    chunk_start = 0
    last_break = 0
    for match in re.finditer(r'(?<=[.!?])\s+', text):
        # Cut before the sentence that would push the chunk past chunk_size.
        if match.end() - chunk_start > chunk_size and last_break > chunk_start:
            chunks.append((text[chunk_start:last_break], chunk_start))
            chunk_start = last_break
        last_break = match.end()
    # A trailing sentence may still push the final chunk over the limit.
    if len(text) - chunk_start > chunk_size and last_break > chunk_start:
        chunks.append((text[chunk_start:last_break], chunk_start))
        chunk_start = last_break
    if chunk_start < len(text):
        chunks.append((text[chunk_start:], chunk_start))
    return chunks
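# Example: chunk_text("A. B. C.", chunk_size=4) -> [("A. ", 0), ("B. ", 3), ("C.", 6)]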


def highlight_text(text, spans):
    """Wrap each (start, end) character span of text in <mark> tags.

    Spans are applied from the end of the text backwards so that earlier
    insertions do not shift the positions of spans yet to be applied.
    Assumes the spans do not overlap one another.
    """
    for start, end in sorted(spans, reverse=True):
        text = text[:start] + "<mark>" + text[start:end] + "</mark>" + text[end:]
    return text
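# Example: highlight_text("The cat sat.", [(4, 7)]) -> "The <mark>cat</mark> sat."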


def main():
    st.title("Document Search Engine")

    # Load the (cached) QA pipeline
    qa_pipeline = load_qa_pipeline()

    # File upload for Word documents
    uploaded_file = st.file_uploader("Upload a Word document", type=['docx'])
    if uploaded_file is not None:
        doc_text = docx2txt.process(BytesIO(uploaded_file.read()))
        st.session_state['context'] = doc_text

    # Context input
    if 'context' not in st.session_state:
        st.session_state['context'] = ""
    context = st.text_area("Enter or edit the context:",
                           value=st.session_state['context'], height=300)
    st.session_state['context'] = context

    # Search input and button
    col1, col2 = st.columns([3, 1])
    with col1:
        question = st.text_input("Enter your search query:")
    with col2:
        search_button = st.button("Search")

    if search_button:
        if context and question:
            chunks = chunk_text(context)
            results = []
            for chunk, offset in chunks:
                # Run extractive QA on each chunk independently.
                result = qa_pipeline(question=question, context=chunk)
                result['offset'] = offset
                results.append(result)

            # Sort results by model confidence and keep the top 3.
            top_results = sorted(results, key=lambda x: x['score'], reverse=True)[:3]

            st.subheader("Top 3 Results:")
            for i, result in enumerate(top_results, 1):
                st.write(f"{i}. Answer: {result['answer']}")
                st.write(f"   Confidence: {result['score']:.2f}")

            # Map each answer span back to its position in the full context
            # using the chunk's recorded start offset, then highlight it.
            spans = [(r['offset'] + r['start'], r['offset'] + r['end'])
                     for r in top_results]
            highlighted_context = highlight_text(context, spans)

            st.subheader("Context with Highlighted Answers:")
            st.markdown(highlighted_context, unsafe_allow_html=True)
        else:
            st.warning("Please provide both context and search query.")


if __name__ == "__main__":
    main()