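"""RAG chatbot: upload PDFs, index them with Chroma + BM25, and chat about
their contents with Mistral-7B-Instruct via the Hugging Face Inference API."""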
from huggingface_hub import InferenceClient
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
import os
import re
from unidecode import unidecode

# CSS for styling
css = '''
.gradio-container{max-width: 1000px !important}
h1{text-align:center}
footer {visibility: hidden}
'''

# Initialize the inference client
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

# Global variable that stores the retriever
global_retriever = None

# Text preprocessing function
def preprocess_text(text):
    """Preprocess text by removing noise and normalizing it."""
    # Strip "Página 3 of 10" / "Page 3" style headers and footers
    text = re.sub(r'(Página|Page)\s+\d+(?:\s+of\s+\d+)?', '', text, flags=re.IGNORECASE)
    # Collapse whitespace runs left behind by PDF extraction
    text = re.sub(r'\s+', ' ', text).strip()
    # Lowercase and strip accents so keyword (BM25) matching is accent-insensitive
    text = unidecode(text.lower())
    return text

# Set up the retriever
def initialize_retriever(file_objs, persist_directory="chroma_db"):
    """Load PDF documents, preprocess them, and build a hybrid retriever."""
    global global_retriever
    if not file_objs:
        return "Nenhum documento carregado."
    
    documents = []
    for file_obj in file_objs:
        # Validate that the file is a PDF
        if not file_obj.name.lower().endswith('.pdf'):
            return f"Erro: O arquivo '{file_obj.name}' não é um PDF válido. Apenas arquivos .pdf são aceitos."
        try:
            loader = PyPDFLoader(file_obj.name)
            raw_docs = loader.load()
            for doc in raw_docs:
                doc.page_content = preprocess_text(doc.page_content)
                doc.metadata.update({"source": os.path.basename(file_obj.name)})
            documents.extend(raw_docs)
        except Exception as e:
            return f"Erro ao processar '{file_obj.name}': {str(e)}"
    
    if not documents:
        return "Nenhum conteúdo válido foi extraído dos PDFs."
    
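    # ~2048-character chunks with 128-character overlap keep whole paragraphs
    # together while staying small enough for the k=2 retrieval below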
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)
    splits = text_splitter.split_documents(documents)
    
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    try:
        # Reuse an existing persisted Chroma collection if one is on disk
        vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
        vectorstore.add_documents(splits)
    except Exception:
        # Otherwise build a fresh persisted collection from the splits
        vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings, persist_directory=persist_directory)
    
    semantic_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
    bm25_retriever = BM25Retriever.from_documents(splits)
    bm25_retriever.k = 2
    
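    # Hybrid retrieval: blend dense (embedding) similarity with sparse (BM25)
    # keyword matching, weighting the semantic side slightly higher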
    global_retriever = EnsembleRetriever(
        retrievers=[semantic_retriever, bm25_retriever],
        weights=[0.6, 0.4]
    )
    
    return "Documentos processados com sucesso!"

# Format the prompt for RAG ([INST] follows the Mistral chat convention;
# [SYS] and [CONTEXT] are custom delimiters used by this app, not official tokens)
def format_prompt(message, history, retriever=None, system_prompt=None):
    prompt = "<s>"
    # gr.Chatbot(type="messages") delivers history as {"role", "content"} dicts;
    # pair consecutive user/assistant messages into (user, bot) turns.
    # Tuple-style history from older Gradio versions passes through unchanged.
    if history and isinstance(history[0], dict):
        turns, pending_user = [], None
        for msg in history:
            if msg["role"] == "user":
                pending_user = msg["content"]
            elif msg["role"] == "assistant" and pending_user is not None:
                turns.append((pending_user, msg["content"]))
                pending_user = None
        history = turns
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    if system_prompt:
        prompt += f"[SYS] {system_prompt} [/SYS]"
    if retriever:
        docs = retriever.get_relevant_documents(message)
        context = "\n".join([f"[{doc.metadata.get('source', 'Unknown')}, Page {doc.metadata.get('page', 'N/A')}] {doc.page_content}" for doc in docs])
        prompt += f"[CONTEXT] {context} [/CONTEXT]"
    prompt += f"[INST] {message} [/INST]"
    return prompt

# Generation function with RAG
def generate(
    prompt, history, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0
):
    global global_retriever
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2  # the endpoint requires a strictly positive temperature
    top_p = float(top_p)

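    # Fixed seed keeps sampled generations reproducible for identical requests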
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history, global_retriever, system_prompt)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

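    # Yield the accumulated text so Gradio renders the reply as it streams in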
    for response in stream:
        output += response.token.text
        yield output

# Gradio interface with RAG
def create_demo():
    with gr.Blocks(css=css) as demo:
        status = gr.State(value="Nenhum documento carregado")

        gr.Markdown("<h1>RAG Chatbot</h1>")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Carregar Documentos")
                file_input = gr.Files(label="Upload PDFs", file_types=[".pdf"], file_count="multiple")
                process_btn = gr.Button("Processar Documentos")
                status_output = gr.Textbox(label="Status", value="Nenhum documento carregado")

        chat_interface = gr.ChatInterface(
            fn=generate,
            type="messages",  # history arrives as role/content dicts (handled in format_prompt)
            additional_inputs=[
                gr.Textbox(label="System Prompt", placeholder="Digite um prompt de sistema (opcional)", value=None)
            ],
            title="",
            chatbot=gr.Chatbot(type="messages")
        )

        process_btn.click(
            fn=initialize_retriever,
            inputs=[file_input],
            outputs=[status_output]
        )

    return demo

# Launch the application
demo = create_demo()
demo.queue().launch(share=False)