import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Force light theme globally
st.markdown("""
    <style>
        /* Hide Streamlit's menu and footer */
        #MainMenu {visibility: hidden;}
        footer {visibility: hidden;}
        header {visibility: hidden;}
        
        /* Center and size the logo */
        .block-container {
            padding-top: 1rem;
        }
    </style>
""", unsafe_allow_html=True)

# Load model and tokenizer from Hugging Face Hub
@st.cache_resource
def load_model_and_tokenizer():
    model_name = "dejanseo/bulgarian-search-query-intent"
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

# Load resources
model, tokenizer = load_model_and_tokenizer()

# Page layout with clickable logo
st.markdown("""
    <div style="display: flex; justify-content: space-between; align-items: center;">
        <h1>Класификация на намерения за търсене</h1>
        <a href="https://dejan.ai" target="_blank">
            <img src="https://huggingface.co/spaces/dejanseo/bulgarian-search-query-intent-classifier/resolve/main/dejan-300x103.png" width="300">
        </a>
    </div>
""", unsafe_allow_html=True)

st.write(
    "Въведете една или повече заявки (всеки на нов ред) или качете `.txt` файл, в който "
    "всяка заявка е на отделен ред без допълнителни параметри. "
    "Моделът е създаден от [DEJAN AI](https://dejan.ai)."
)

# Текстово поле за въвеждане на заявки
queries_input = st.text_area("Въведете вашите заявки (по една на ред):")

# Качване на `.txt` файл
uploaded_file = st.file_uploader(
    "Качете `.txt` файл с заявки (всеки ред съдържа една заявка)", type=["txt"]
)

# Събиране на заявките от текстовото поле и/или файла
queries = []
if queries_input.strip():
    queries.extend([line.strip() for line in queries_input.splitlines() if line.strip()])
if uploaded_file is not None:
    file_content = uploaded_file.read().decode("utf-8")
    queries.extend([line.strip() for line in file_content.splitlines() if line.strip()])

# UI for button with spinner
button_disabled = False
if queries:
    button_disabled = False
else:
    button_disabled = True

if st.button("Класифицирай", disabled=button_disabled):
    if queries:
        with st.spinner("Обработване..."):
            # Tokenize in batch
            inputs = tokenizer(
                queries,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=256
            )

            # Run inference
            with torch.no_grad():
                outputs = model(**inputs)

            logits = outputs.logits
            predictions = logits.argmax(dim=-1).tolist()
            probabilities = F.softmax(logits, dim=-1)
            confidence_scores = probabilities.max(dim=-1).values.tolist()

            # Използване на наличната label mapping от модела
            id2label = model.config.id2label

            results = []
            for query, pred, conf in zip(queries, predictions, confidence_scores):
                predicted_label = id2label.get(str(pred), id2label.get(pred, "Неизвестно"))
                results.append({
                    "Заявка": query,
                    "Предсказано намерение": predicted_label,
                    "Доверие": f"{conf:.2f}"
                })

            st.write("### Резултати:")
            st.dataframe(results, use_container_width=True)
    else:
        st.warning("Моля, въведете поне една заявка, преди да класифицирате.")