import streamlit as st
from transformers import pipeline
import re
import docx2txt
from io import BytesIO

@st.cache_resource
def load_qa_pipeline():
    """Create the Hugging Face question-answering pipeline.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the same
    pipeline object instead of rebuilding it. The same hub identifier is used
    for both the model and its tokenizer.
    """
    model_id = "tareeb23/Roberta_SQUAD_V2"
    return pipeline("question-answering", model=model_id, tokenizer=model_id)

def chunk_text(text, chunk_size=1000):
    """Split *text* into sentence-aligned chunks of about *chunk_size* characters.

    Sentences are split after ``.``, ``!`` or ``?`` followed by whitespace and
    re-joined with single spaces, so original whitespace between sentences
    (e.g. blank lines) is collapsed. A sentence is added to the current chunk
    while the accumulated length stays within *chunk_size*. A single sentence
    longer than *chunk_size* becomes its own (oversized) chunk.

    Args:
        text: The text to split.
        chunk_size: Soft upper bound on chunk length, in characters.

    Returns:
        A list of non-empty, stripped chunk strings; ``[]`` for empty or
        whitespace-only input.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current = []
    current_len = 0  # each sentence counts its length plus one joining space

    for sentence in sentences:
        # Only flush a non-empty chunk; flushing an empty one would emit "".
        if current and current_len + len(sentence) > chunk_size:
            chunks.append(" ".join(current).strip())
            current, current_len = [], 0
        current.append(sentence)
        current_len += len(sentence) + 1

    if current:
        chunks.append(" ".join(current).strip())

    # Drop chunks that are empty after stripping (whitespace-only input).
    return [chunk for chunk in chunks if chunk]

def _chunk_spans(text, chunk_size=1000):
    """Return ``(start, end)`` offsets of sentence-aligned chunks of *text*.

    Mirrors ``chunk_text``'s sentence split and size limit, but returns
    positions in the original string instead of whitespace-normalized copies.
    An offset inside ``text[start:end]`` therefore maps back to *text* by
    adding ``start``. Leading/trailing whitespace is trimmed from each span and
    whitespace-only spans are dropped.
    """
    # (start, end) of each sentence; the whitespace run after ., ! or ? separates them.
    sentences = []
    pos = 0
    for gap in re.finditer(r'(?<=[.!?])\s+', text):
        sentences.append((pos, gap.start()))
        pos = gap.end()
    sentences.append((pos, len(text)))

    raw_spans = []
    span_start = span_end = None
    used = 0  # each sentence counts its length plus one joining space
    for start, end in sentences:
        if span_start is not None and used + (end - start) > chunk_size:
            raw_spans.append((span_start, span_end))
            span_start, used = None, 0
        if span_start is None:
            span_start = start
        span_end = end
        used += end - start + 1
    if span_start is not None:
        raw_spans.append((span_start, span_end))

    spans = []
    for start, end in raw_spans:
        piece = text[start:end]
        stripped = piece.strip()
        if stripped:
            lead = len(piece) - len(piece.lstrip())
            spans.append((start + lead, start + lead + len(stripped)))
    return spans


def get_top_answers(qa_pipeline, question, context, top_k=3, score_limit=0.1, chunk_size=1000):
    """Run *qa_pipeline* over each chunk of *context* and return the best answers.

    Args:
        qa_pipeline: Callable invoked as ``qa_pipeline(question=..., context=...)``
            that returns a dict with at least ``'score'``, ``'start'`` and
            ``'answer'`` (``'start'`` is an offset into the given context).
        question: The query string.
        context: The full document text.
        top_k: Maximum number of results to return.
        score_limit: Results with ``score`` below this value are discarded.
        chunk_size: Soft maximum chunk length in characters.

    Returns:
        Up to *top_k* result dicts sorted by descending ``score``. Each dict is
        a copy of the pipeline's output plus ``'chunk_index'`` and
        ``'chunk_start'``. ``'chunk_start'`` is the exact offset of the chunk in
        *context*, so ``chunk_start + start`` locates the answer in *context*.
    """
    results = []
    for i, (start, end) in enumerate(_chunk_spans(context, chunk_size)):
        # Copy so we never mutate an object owned by the pipeline.
        result = dict(qa_pipeline(question=question, context=context[start:end]))
        result['chunk_index'] = i
        result['chunk_start'] = start
        results.append(result)

    # Filter by score limit, then sort by score and keep the top k.
    filtered_results = [r for r in results if r['score'] >= score_limit]
    return sorted(filtered_results, key=lambda r: r['score'], reverse=True)[:top_k]

def highlight_answer(text, answer, start):
    """Return *text* with ``len(answer)`` characters at *start* replaced by bold *answer*.

    The characters at ``text[start:start + len(answer)]`` are overwritten by
    ``**answer**``. If *start* does not point at *answer*, the original
    characters there are lost.
    """
    stop = start + len(answer)
    head = text[:start]
    tail = text[stop:]
    return f"{head}**{answer}**{tail}"

def highlight_all_answers(text, results):
    """Return *text* with every answer in *results* wrapped in Markdown bold.

    Each result's position is taken as ``result.get('chunk_start', 0) +
    result['start']``. If the text at that position is not the answer, the
    occurrence of the answer nearest to that position is used instead. Answers
    that do not occur in *text* at all are skipped.

    Spans are applied in ascending position order and overlapping spans are
    skipped, so inserting markers never shifts another span's offset. The
    original characters are kept; they are never overwritten.
    """
    spans = []
    for result in results:
        answer = result['answer']
        if not answer:
            continue
        guess = result.get('chunk_start', 0) + result['start']
        if 0 <= guess and text[guess:guess + len(answer)] == answer:
            begin = guess
        else:
            hits = [m.start() for m in re.finditer(re.escape(answer), text)]
            if not hits:
                continue
            begin = min(hits, key=lambda pos: abs(pos - guess))
        spans.append((begin, begin + len(answer)))

    pieces = []
    cursor = 0
    for begin, end in sorted(spans):
        if begin < cursor:
            continue  # overlaps a span that is already highlighted
        pieces.append(text[cursor:begin])
        pieces.append(f"**{text[begin:end]}**")
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces)


def main():
    """Render the Streamlit document-search page.

    The user uploads a .docx file or types/edits context text, then enters a
    query, a top-K and a score threshold. On "Search" the page shows the
    ranked answers and the context with those answers in bold.
    """
    st.title("Document Search Engine")

    # Load the QA pipeline (memoized by st.cache_resource on load_qa_pipeline)
    qa_pipeline = load_qa_pipeline()

    # File upload for Word documents
    uploaded_file = st.file_uploader("Upload a Word document", type=['docx'])
    if uploaded_file is not None:
        # getvalue() returns the whole upload regardless of the buffer's read position.
        doc_text = docx2txt.process(BytesIO(uploaded_file.getvalue()))
        st.session_state['context'] = doc_text

    # Context input, kept in session state across reruns
    if 'context' not in st.session_state:
        st.session_state['context'] = ""
    context = st.text_area("Enter or edit the context:", value=st.session_state['context'], height=300)
    st.session_state['context'] = context

    # Search input and options
    col1, col2, col3 = st.columns([3, 1, 1])
    with col1:
        question = st.text_input("Enter your search query:")
    with col2:
        top_k = st.number_input("Top K results", min_value=1, max_value=10, value=3)
    with col3:
        score_limit = st.number_input("Score limit", min_value=0.0, max_value=1.0, value=0.1, step=0.05)

    if not st.button("Search"):
        return
    if not (context and question):
        st.warning("Please provide both context and search query.")
        return

    top_results = get_top_answers(qa_pipeline, question, context, top_k=top_k, score_limit=score_limit)
    if not top_results:
        st.warning("No results found above the score limit.")
        return

    st.subheader(f"Top {len(top_results)} Results:")
    for i, result in enumerate(top_results, 1):
        st.markdown(f"{i}. Answer: **{result['answer']}** (Confidence: {result['score']:.4f})")

    st.subheader("Context with Highlighted Answers:")
    # Highlight all answers in one position-ordered pass. Inserting markers one
    # at a time in score order would shift the offsets of later answers.
    st.markdown(highlight_all_answers(context, top_results))

# Entry point: render the app when this module runs as the main script.
if __name__ == "__main__":
    main()