import streamlit as st
from transformers import pipeline
from diff_match_patch import diff_match_patch
from langdetect import detect, LangDetectException
import html
import time

# Load models
@st.cache_resource
def load_grammar_model():
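    # T5 model fine-tuned for grammatical error correction; it expects inputs prefixed with "grammar: "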
    return pipeline("text2text-generation", model="vennify/t5-base-grammar-correction")

@st.cache_resource
def load_explainer_model():
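    # Instruction-tuned model used to explain corrections and suggest writing improvements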
    return pipeline("text2text-generation", model="google/flan-t5-large")

@st.cache_resource
def load_translation_ur_to_en():
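    # MarianMT model for Urdu-to-English translation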
    return pipeline("translation", model="Helsinki-NLP/opus-mt-ur-en")

@st.cache_resource
def load_translation_en_to_ur():
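    # MarianMT model for English-to-Urdu translation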
    return pipeline("translation", model="Helsinki-NLP/opus-mt-en-ur")

# Initialize models
grammar_model = load_grammar_model()
explainer_model = load_explainer_model()
translate_ur_en = load_translation_ur_to_en()
translate_en_ur = load_translation_en_to_ur()
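# diff_match_patch computes the character-level diff used to highlight what changed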
dmp = diff_match_patch()

st.title("📝 AI Grammar & Writing Assistant (Multilingual)")
st.markdown("Supports English & Urdu inputs. Fix grammar, punctuation, spelling, tenses — with explanations and writing tips.")

# Initialize session state
if "corrected_text" not in st.session_state:
    st.session_state.corrected_text = ""
if "detected_lang" not in st.session_state:
    st.session_state.detected_lang = ""
if "history" not in st.session_state:
    st.session_state.history = []

user_input = st.text_area("✍️ Enter your sentence, paragraph, or essay:", height=200)

# Detect & Translate Urdu if needed
def detect_and_translate_input(text):
    # langdetect can raise on very short or ambiguous input; fall back to English
    try:
        lang = detect(text)
    except LangDetectException:
        lang = "en"
    if lang == "ur":
        st.info("🔄 Detected Urdu input. Translating to English for grammar correction...")
        translated = translate_ur_en(text)[0]['translation_text']
        return translated, lang
    return text, lang

# Button: Grammar Correction
if st.button("✅ Correct Grammar"):
    if user_input.strip():
        translated_input, lang = detect_and_translate_input(user_input)
        st.session_state.detected_lang = lang
        st.session_state.english_input = translated_input

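        # The correction model expects the "grammar: " task prefix; do_sample=False keeps the output deterministic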
        corrected = grammar_model(f"grammar: {translated_input}", max_length=512, do_sample=False)[0]["generated_text"]
        st.session_state.corrected_text = corrected

        # Show corrected text
        st.subheader("✅ Corrected Text (in English)")
        st.success(corrected)

        # Highlight changes
        st.subheader("🔍 Changes Highlighted")
        diffs = dmp.diff_main(translated_input, corrected)
        dmp.diff_cleanupSemantic(diffs)
        html_diff = ""
        for (op, data) in diffs:
            if op == -1:
                html_diff += f'<span style="background-color:#fbb;">{data}</span>'
            elif op == 1:
                html_diff += f'<span style="background-color:#bfb;">{data}</span>'
            else:
                html_diff += data
        st.markdown(f"<div style='font-family:monospace;'>{html_diff}</div>", unsafe_allow_html=True)

        # Optional Urdu output
        if lang == "ur":
            urdu_back = translate_en_ur(corrected)[0]['translation_text']
            st.subheader("🔄 Corrected Text (Back in Urdu)")
            st.success(urdu_back)

        # Save to history
        st.session_state.history.append({
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "original": user_input,
            "corrected": corrected,
            "lang": lang
        })
    else:
        st.warning("Please enter some text first.")

# Button: Explanation
if st.button("🧠 Explain Corrections"):
    if st.session_state.corrected_text:
        st.subheader("Line-by-Line Explanation")
        original_lines = user_input.split(".")
        for line in original_lines:
            if line.strip():
                prompt = f"Explain and fix issues in this sentence:\n'{line.strip()}.'"
                explanation = explainer_model(prompt, max_length=100)[0]["generated_text"]
                st.markdown(f"**🔸 {line.strip()}**")
                st.info(explanation)
    else:
        st.warning("Please correct the grammar first.")

# Button: Suggest Improvements
if st.button("💡 Suggest Writing Improvements"):
    if st.session_state.corrected_text:
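        # Ask the explainer model for clarity and style suggestions on the corrected English text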
        prompt = f"Suggest improvements to make this text clearer and more professional:\n\n{st.session_state.corrected_text}"
        suggestion = explainer_model(prompt, max_length=150)[0]["generated_text"]
        st.subheader("Improvement Suggestions")
        st.warning(suggestion)
    else:
        st.warning("Please correct the grammar first.")

# Download corrected text
if st.session_state.corrected_text:
    st.download_button("⬇️ Download Corrected Text", st.session_state.corrected_text, file_name="corrected_text.txt")

# History viewer
if st.checkbox("📜 Show My Correction History"):
    st.subheader("Correction History")
    for record in st.session_state.history:
        st.markdown(f"🕒 **{record['timestamp']}** | Language: `{record['lang']}`")
        st.markdown(f"**Original:** {record['original']}")
        st.markdown(f"**Corrected:** {record['corrected']}")
        st.markdown("---")