# Document Search Engine — tareeb23, "Clean Output" (commit 7525fd7)
import streamlit as st
from transformers import pipeline
import re
import docx2txt
from io import BytesIO
@st.cache_resource
def load_qa_pipeline():
    """Build the extractive question-answering pipeline.

    Cached by Streamlit for the lifetime of the app process so the model
    weights are only downloaded and loaded once.
    """
    model_name = "tareeb23/Roberta_SQUAD_V2"
    return pipeline(
        "question-answering",
        model=model_name,
        tokenizer=model_name,
    )
def chunk_text(text, chunk_size=1000):
    """Split *text* into sentence-aligned chunks of roughly *chunk_size* chars.

    Sentences (split on ``.``, ``!`` or ``?`` followed by whitespace) are
    accumulated until adding the next one would exceed *chunk_size*; a single
    sentence longer than *chunk_size* becomes its own chunk.

    Args:
        text: The document text to split.
        chunk_size: Soft upper bound on each chunk's length in characters.

    Returns:
        A list of non-empty, stripped chunk strings ([] for empty input).
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) <= chunk_size:
            current_chunk += sentence + " "
        else:
            # Flush the accumulated chunk. Guard against appending an empty
            # string, which happened when the very first sentence already
            # exceeded chunk_size (current_chunk was still "").
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
            current_chunk = sentence + " "
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks
def get_top_answers(qa_pipeline, question, context, top_k=3, score_limit=0.1):
    """Run *question* against every chunk of *context*; return the best answers.

    Args:
        qa_pipeline: A question-answering callable accepting
            ``question=`` / ``context=`` keyword arguments and returning a
            dict with at least ``answer``, ``score`` and ``start`` keys.
        question: The query string.
        context: The full document text.
        top_k: Maximum number of results to return.
        score_limit: Minimum confidence score a result must reach.

    Returns:
        Up to *top_k* result dicts (highest score first), each augmented with
        ``chunk_index`` and ``chunk_start`` (the chunk's offset in *context*).
    """
    chunks = chunk_text(context)
    results = []
    search_pos = 0
    for i, chunk in enumerate(chunks):
        # Locate the chunk in the original context so 'chunk_start' is an
        # exact offset. The previous `i * 1000` approximation misplaced
        # every downstream highlight, because sentence-aligned chunks
        # almost never end exactly at 1000-character boundaries.
        chunk_start = context.find(chunk, search_pos)
        if chunk_start == -1:
            # Whitespace was normalized during chunking, so the chunk text
            # may not appear verbatim; fall back to the old approximation
            # rather than dropping the result.
            chunk_start = i * 1000
        else:
            search_pos = chunk_start + len(chunk)
        result = qa_pipeline(question=question, context=chunk)
        result['chunk_index'] = i
        result['chunk_start'] = chunk_start
        results.append(result)
    # Keep only answers above the confidence floor, best score first.
    filtered_results = [r for r in results if r['score'] >= score_limit]
    return sorted(filtered_results, key=lambda x: x['score'], reverse=True)[:top_k]
def highlight_answer(text, answer, start):
    """Wrap the occurrence of *answer* at offset *start* in Markdown bold."""
    end = start + len(answer)
    return f"{text[:start]}**{answer}**{text[end:]}"
def main():
    """Streamlit entry point: upload/edit a document, then search it with QA.

    Renders a file uploader (.docx), an editable context area, and a search
    form; answers above the score limit are listed and highlighted in bold
    inside the context.
    """
    st.title("Document Search Engine")

    # Load the QA pipeline (cached across reruns by st.cache_resource).
    qa_pipeline = load_qa_pipeline()

    # File upload for Word documents; the extracted text seeds the context.
    uploaded_file = st.file_uploader("Upload a Word document", type=['docx'])
    if uploaded_file is not None:
        doc_text = docx2txt.process(BytesIO(uploaded_file.read()))
        st.session_state['context'] = doc_text

    # Context input, persisted in session state so edits survive reruns.
    if 'context' not in st.session_state:
        st.session_state['context'] = ""
    context = st.text_area("Enter or edit the context:", value=st.session_state['context'], height=300)
    st.session_state['context'] = context

    # Search input and tuning controls.
    col1, col2, col3 = st.columns([3, 1, 1])
    with col1:
        question = st.text_input("Enter your search query:")
    with col2:
        top_k = st.number_input("Top K results", min_value=1, max_value=10, value=3)
    with col3:
        score_limit = st.number_input("Score limit", min_value=0.0, max_value=1.0, value=0.1, step=0.05)

    if st.button("Search"):
        if context and question:
            top_results = get_top_answers(qa_pipeline, question, context, top_k=top_k, score_limit=score_limit)
            if top_results:
                st.subheader(f"Top {len(top_results)} Results:")
                for i, result in enumerate(top_results, 1):
                    st.markdown(f"{i}. Answer: **{result['answer']}** (Confidence: {result['score']:.4f})")

                st.subheader("Context with Highlighted Answers:")
                highlighted_context = context
                # Resolve each answer's absolute position first, then apply
                # highlights from the END of the text backwards so earlier
                # insertions never shift later offsets. (The original code
                # iterated in reversed *score* order, which corrupted the
                # offsets whenever a lower-scored answer preceded a
                # higher-scored one in the text.)
                positioned = []
                for result in top_results:
                    answer = result['answer']
                    start = result['chunk_start'] + result['start']
                    if highlighted_context[start:start + len(answer)] != answer:
                        # Offset was approximate; fall back to a direct search,
                        # skipping the answer entirely if it can't be found.
                        start = highlighted_context.find(answer)
                        if start == -1:
                            continue
                    positioned.append((start, answer))
                for start, answer in sorted(positioned, reverse=True):
                    highlighted_context = highlight_answer(highlighted_context, answer, start)
                st.markdown(highlighted_context)
            else:
                st.warning("No results found above the score limit.")
        else:
            st.warning("Please provide both context and search query.")


if __name__ == "__main__":
    main()