# Streamlit document-search app (originally hosted on Hugging Face Spaces).
import re
from io import BytesIO

import docx2txt
import streamlit as st
from transformers import pipeline
@st.cache_resource
def load_qa_pipeline():
    """Build and cache the extractive question-answering pipeline.

    Returns:
        A transformers ``question-answering`` pipeline backed by the
        RoBERTa SQuAD-v2 checkpoint ``tareeb23/Roberta_SQUAD_V2``.

    Why the decorator: Streamlit reruns the whole script on every widget
    interaction; without ``st.cache_resource`` the (expensive) model load
    would repeat on each rerun. The cache keeps one instance per process.
    """
    return pipeline(
        "question-answering",
        model="tareeb23/Roberta_SQUAD_V2",
        tokenizer="tareeb23/Roberta_SQUAD_V2",
    )
def chunk_text(text, chunk_size=1000):
    """Split *text* into sentence-aligned chunks of at most ~chunk_size chars.

    Sentence boundaries are detected as ``. ! ?`` followed by whitespace.
    A single sentence longer than ``chunk_size`` becomes its own oversized
    chunk rather than being cut mid-sentence (matching the original intent).

    Args:
        text: Source text to split.
        chunk_size: Soft upper bound on chunk length, in characters.

    Returns:
        List of non-empty, stripped chunk strings; ``[]`` for empty or
        whitespace-only input (the original returned ``[""]`` here).
    """
    if not text or not text.strip():
        return []
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= chunk_size:
            current_chunk += sentence + " "
        else:
            # Bug fix: only flush a non-empty chunk. The original appended
            # "".strip() (an empty string) whenever the very first sentence
            # already exceeded chunk_size.
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " "
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks
def get_top_answers(qa_pipeline, question, context, top_k=3, score_limit=0.1):
    """Run *qa_pipeline* over each chunk of *context*; return the best answers.

    Args:
        qa_pipeline: Callable accepting ``question=`` and ``context=`` and
            returning a dict with at least ``answer``, ``score``, ``start``.
        question: The search query.
        context: Full document text (chunked via ``chunk_text``).
        top_k: Maximum number of results to return.
        score_limit: Minimum confidence score a result must reach.

    Returns:
        Up to ``top_k`` result dicts, sorted by descending score, each
        augmented with:
          - ``chunk_index``: which chunk the answer came from;
          - ``chunk_start``: the chunk's real character offset in *context*,
            so ``chunk_start + result['start']`` addresses the answer in the
            original text.
    """
    chunks = chunk_text(context)
    results = []
    offset = 0
    for i, chunk in enumerate(chunks):
        # Bug fix: the original hard-coded chunk_start = i * 1000, but chunks
        # are sentence-aligned and rarely exactly 1000 chars, so highlight
        # offsets drifted. Locate each chunk's actual start instead.
        pos = context.find(chunk, offset)
        if pos == -1:
            # Defensive fallback: stripping in chunk_text could, in theory,
            # prevent an exact match; keep the running offset rather than crash.
            pos = offset
        result = qa_pipeline(question=question, context=chunk)
        result['chunk_index'] = i
        result['chunk_start'] = pos
        results.append(result)
        offset = pos + len(chunk)
    # Filter by score limit, then keep the top_k highest-scoring results.
    filtered_results = [r for r in results if r['score'] >= score_limit]
    return sorted(filtered_results, key=lambda x: x['score'], reverse=True)[:top_k]
def highlight_answer(text, answer, start):
    """Return *text* with the answer occurrence at *start* wrapped in
    Markdown bold markers (``**``)."""
    end = start + len(answer)
    return f"{text[:start]}**{answer}**{text[end:]}"
def main():
    """Streamlit entry point: upload a .docx (or paste/edit text), search it
    with an extractive QA model, and show the top answers highlighted in
    the original context."""
    st.title("Document Search Engine")
    # Load the QA pipeline
    qa_pipeline = load_qa_pipeline()
    # File upload for Word documents
    uploaded_file = st.file_uploader("Upload a Word document", type=['docx'])
    if uploaded_file is not None:
        # docx2txt wants a file-like object, so wrap the uploaded bytes.
        # NOTE(review): while a file stays attached, this overwrites any
        # manual edits in the text area on every rerun — confirm intended.
        doc_text = docx2txt.process(BytesIO(uploaded_file.read()))
        st.session_state['context'] = doc_text
    # Context input — seed session state BEFORE the text_area reads it,
    # so the widget shows either the uploaded document or the prior edit.
    if 'context' not in st.session_state:
        st.session_state['context'] = ""
    context = st.text_area("Enter or edit the context:", value=st.session_state['context'], height=300)
    st.session_state['context'] = context
    # Search input and button
    col1, col2, col3 = st.columns([3, 1, 1])
    with col1:
        question = st.text_input("Enter your search query:")
    with col2:
        top_k = st.number_input("Top K results", min_value=1, max_value=10, value=3)
    with col3:
        score_limit = st.number_input("Score limit", min_value=0.0, max_value=1.0, value=0.1, step=0.05)
    if st.button("Search"):
        if context and question:
            top_results = get_top_answers(qa_pipeline, question, context, top_k=top_k, score_limit=score_limit)
            if top_results:
                st.subheader(f"Top {len(top_results)} Results:")
                for i, result in enumerate(top_results, 1):
                    st.markdown(f"{i}. Answer: **{result['answer']}** (Confidence: {result['score']:.4f})")
                st.subheader("Context with Highlighted Answers:")
                highlighted_context = context
                # NOTE(review): top_results is sorted by score, not position,
                # so reversed() does not guarantee descending start offsets;
                # an earlier insertion can still shift later ones — verify.
                for result in reversed(top_results):  # Reverse to avoid messing up indices
                    start = result['chunk_start'] + result['start']
                    highlighted_context = highlight_answer(highlighted_context, result['answer'], start)
                st.markdown(highlighted_context)
            else:
                st.warning("No results found above the score limit.")
        else:
            st.warning("Please provide both context and search query.")
# Entry point when executed directly (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()