import functools
import hashlib
import os
import re

import gradio as gr
from huggingface_hub import InferenceClient
from langchain.retrievers import EnsembleRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import Chroma
from unidecode import unidecode

# Page styling: cap the container width, centre the <h1> title and hide the Gradio footer.
css = (
    "\n"
    ".gradio-container{max-width: 1000px !important}\n"
    "h1{text-align:center}\n"
    "footer {visibility: hidden}\n"
)

# Hugging Face inference client for the chat model used by generate().
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.3")

# Hybrid retriever built by initialize_retriever(); stays None until documents are processed.
global_retriever = None

# Text pre-processing.
# Page-number markers left in the extracted text by PDF headers/footers, e.g.
# "Page 3", "Page 3 of 10", "Página 3 de 10" (also unaccented "Pagina").
# The leading \b keeps words that merely end in "page" (e.g. "homepage 5") intact.
_PAGE_MARKER_RE = re.compile(
    r"\b(?:p[aá]gina|page)\s+\d+(?:\s+(?:of|de)\s+\d+)?",
    flags=re.IGNORECASE,
)


def preprocess_text(text):
    """Normalize raw PDF page text for indexing.

    Removes page-number markers, collapses every run of whitespace to a single
    space, trims the ends, lowercases, and transliterates non-ASCII characters
    with ``unidecode``.

    Args:
        text: Raw text of one PDF page.

    Returns:
        The normalized text.
    """
    text = _PAGE_MARKER_RE.sub("", text)
    text = re.sub(r"\s+", " ", text).strip()
    return unidecode(text.lower())

# Retriever setup.
@functools.lru_cache(maxsize=1)
def _get_embeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Load the sentence-transformers embedding model once and reuse it across uploads."""
    return HuggingFaceEmbeddings(model_name=model_name)


def _chunk_id(chunk):
    """Return a deterministic ID for a chunk, derived from its source, page and text."""
    key = "\x1f".join(
        (
            str(chunk.metadata.get("source", "")),
            str(chunk.metadata.get("page", "")),
            chunk.page_content,
        )
    )
    return hashlib.sha256(key.encode("utf-8")).hexdigest()


def initialize_retriever(file_objs, persist_directory="chroma_db"):
    """Load PDF files, pre-process them and build a hybrid (semantic + BM25) retriever.

    On success the module-level ``global_retriever`` is replaced by an
    ``EnsembleRetriever`` combining a Chroma similarity retriever (weight 0.6)
    and a BM25 retriever over the uploaded chunks (weight 0.4), each returning
    2 documents.

    Args:
        file_objs: Uploaded files from Gradio. Each item is a path (``str`` /
            ``os.PathLike``) or an object exposing the path as ``.name``.
        persist_directory: Directory where the Chroma collection is persisted.

    Returns:
        A user-facing status message (in Portuguese) describing the outcome.
    """
    global global_retriever
    if not file_objs:
        return "Nenhum documento carregado."

    documents = []
    for file_obj in file_objs:
        # Gradio >= 4 passes str paths (NamedString); older versions passed
        # tempfile wrappers whose path lives in `.name`.
        if isinstance(file_obj, (str, os.PathLike)):
            path = os.fspath(file_obj)
        else:
            path = file_obj.name
        if not path.lower().endswith(".pdf"):
            return f"Erro: O arquivo '{path}' não é um PDF válido. Apenas arquivos .pdf são aceitos."
        try:
            raw_docs = PyPDFLoader(path).load()
            source = os.path.basename(path)
            for doc in raw_docs:
                doc.page_content = preprocess_text(doc.page_content)
                doc.metadata["source"] = source
            documents.extend(raw_docs)
        except Exception as e:
            return f"Erro ao processar '{path}': {e}"

    if not documents:
        return "Nenhum conteúdo válido foi extraído dos PDFs."

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)
    # Key chunks by a deterministic ID. This drops duplicates within this upload
    # and makes Chroma upsert chunks that were already indexed, instead of
    # storing another copy on every re-processing.
    unique_chunks = {}
    for chunk in text_splitter.split_documents(documents):
        unique_chunks.setdefault(_chunk_id(chunk), chunk)
    if not unique_chunks:
        # e.g. scanned PDFs with no extractable text: BM25 cannot be built
        # from an empty corpus, so stop here with a clear message.
        return "Nenhum conteúdo válido foi extraído dos PDFs."
    ids = list(unique_chunks)
    splits = list(unique_chunks.values())

    embeddings = _get_embeddings()
    try:
        vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
        vectorstore.add_documents(splits, ids=ids)
    except Exception:
        # Fall back to creating the collection from scratch.
        vectorstore = Chroma.from_documents(
            documents=splits,
            embedding=embeddings,
            persist_directory=persist_directory,
            ids=ids,
        )

    # NOTE: the semantic retriever searches the whole persisted collection,
    # while BM25 only covers the chunks from this call.
    semantic_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
    bm25_retriever = BM25Retriever.from_documents(splits)
    bm25_retriever.k = 2

    global_retriever = EnsembleRetriever(
        retrievers=[semantic_retriever, bm25_retriever],
        weights=[0.6, 0.4],
    )

    return "Documentos processados com sucesso!"

# Prompt construction for RAG.
def _iter_history_turns(history):
    """Yield ``(role, text)`` for every text turn in a Gradio chat history.

    Accepts both formats Gradio can send: ``(user, assistant)`` pairs
    ("tuples") and ``{"role": ..., "content": ...}`` dicts ("messages", the
    format of ``gr.Chatbot(type="messages")``). Non-text entries (e.g. files)
    and ``None`` halves of a pair are skipped.
    """
    for item in history or ():
        if isinstance(item, dict):
            role, content = item.get("role"), item.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                yield role, content
        else:
            user_text, bot_text = item
            if isinstance(user_text, str):
                yield "user", user_text
            if isinstance(bot_text, str):
                yield "assistant", bot_text


def _source_label(doc):
    """Return a ``[source, Page N]`` citation label for a retrieved document."""
    page = doc.metadata.get("page")
    # PyPDFLoader numbers pages from 0; show the human (1-based) page number.
    if isinstance(page, int):
        page = page + 1
    elif page is None:
        page = "N/A"
    return f"[{doc.metadata.get('source', 'Unknown')}, Page {page}]"


def format_prompt(message, history, retriever=None, system_prompt=None):
    """Build a Mistral-Instruct prompt from history, system prompt and retrieved context.

    Layout::

        <s>[INST] u1 [/INST] a1</s> ... [INST] {system}\\n\\n[CONTEXT] ... [/CONTEXT]\\n\\n{message} [/INST]

    Mistral-Instruct has no dedicated system or context tokens. Its official
    chat template prepends the system message to the last user turn, so the
    system prompt and the retrieved context go inside the final ``[INST]``
    block rather than between turns.

    Args:
        message: The current user message.
        history: Previous turns, as tuples or "messages" dicts (may be None).
        retriever: Optional LangChain retriever queried with ``message``.
        system_prompt: Optional system instructions; ignored when empty.

    Returns:
        The prompt string to send to the text-generation endpoint.
    """
    parts = ["<s>"]
    for role, text in _iter_history_turns(history):
        if role == "user":
            parts.append(f"[INST] {text} [/INST]")
        else:
            parts.append(f" {text}</s> ")

    preamble = []
    if system_prompt:
        preamble.append(system_prompt)
    if retriever is not None:
        docs = retriever.invoke(message)
        if docs:
            context = "\n".join(f"{_source_label(doc)} {doc.page_content}" for doc in docs)
            preamble.append(f"[CONTEXT] {context} [/CONTEXT]")

    instruction = "\n\n".join([*preamble, message])
    parts.append(f"[INST] {instruction} [/INST]")
    return "".join(parts)

# Streaming generation with RAG.
def generate(
    prompt, history, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0
):
    """Stream an answer to ``prompt`` from the chat model, using the current retriever.

    Called by ``gr.ChatInterface`` as ``generate(message, history, system_prompt)``.
    The prompt is built by ``format_prompt`` with the module-level
    ``global_retriever`` (None until documents have been processed).

    Args:
        prompt: The current user message.
        history: Chat history as supplied by Gradio.
        system_prompt: Optional system instructions.
        temperature: Sampling temperature, clamped to at least 0.01.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Repetition penalty passed to the endpoint.

    Yields:
        The accumulated answer text after each generated (non-special) token.
    """
    generate_kwargs = dict(
        # Sampling (do_sample=True) needs a strictly positive temperature.
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=int(max_new_tokens),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history, global_retriever, system_prompt)
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )

    output = ""
    for response in stream:
        # Skip special tokens (e.g. the closing "</s>"), which would otherwise
        # show up at the end of the visible answer.
        if response.token.special:
            continue
        output += response.token.text
        yield output

# Gradio interface with RAG.
def create_demo():
    """Build the Gradio UI: a PDF upload/indexing panel plus a RAG chat interface.

    Returns:
        The (not yet launched) ``gr.Blocks`` application.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("<h1>RAG Chatbot</h1>")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Carregar Documentos")
                # Gradio expects file extensions with a leading dot; a bare "pdf"
                # is neither an extension nor a file category and rejects real PDFs.
                file_input = gr.Files(label="Upload PDFs", file_types=[".pdf"], file_count="multiple")
                process_btn = gr.Button("Processar Documentos")
                status_output = gr.Textbox(label="Status", value="Nenhum documento carregado")

        gr.ChatInterface(
            fn=generate,
            additional_inputs=[
                gr.Textbox(label="System Prompt", placeholder="Digite um prompt de sistema (opcional)", value=None)
            ],
            title="",
            # "messages" history format (role/content dicts); format_prompt also accepts tuples.
            chatbot=gr.Chatbot(type="messages"),
        )

        process_btn.click(
            fn=initialize_retriever,
            inputs=[file_input],
            outputs=[status_output],
        )

    return demo

# Build the app at import time; only start the (blocking) server when run as a
# script, so importing this module does not launch it.
demo = create_demo()

if __name__ == "__main__":
    demo.queue().launch(share=False)