|
from huggingface_hub import InferenceClient |
|
import gradio as gr |
|
from langchain_community.document_loaders import PyPDFLoader |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain_community.vectorstores import Chroma |
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
from langchain_community.retrievers import BM25Retriever |
|
from langchain.retrievers import EnsembleRetriever |
|
import os |
|
import re |
|
from unidecode import unidecode |
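
# Pipeline: upload PDFs -> clean the text -> split into chunks -> index the
# chunks twice (dense vectors in Chroma, lexical BM25) -> a weighted hybrid
# retriever injects the best chunks into a Mistral-7B-Instruct prompt served
# through the Hugging Face Inference API, streamed back into a Gradio chat UI.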
|
|
|
|
|
css = ''' |
|
.gradio-container{max-width: 1000px !important} |
|
h1{text-align:center} |
|
footer {visibility: hidden} |
|
''' |
|
|
|
|
|
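# Hosted text-generation endpoint: no local GPU needed. Gated or rate-limited
# models may require a Hugging Face token (e.g. via `huggingface-cli login`).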
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3") |
|
|
|
|
|
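# Shared between the two Gradio callbacks: initialize_retriever() builds it,
# generate() reads it. Stays None until documents have been processed.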
global_retriever = None |
|
|
|
|
|
def preprocess_text(text): |
|
"""Pré-processa o texto removendo ruídos e normalizando.""" |
|
text = re.sub(r'(Página|Page)\s+\d+(?:\s+of\s+\d+)?', '', text, flags=re.IGNORECASE) |
|
text = re.sub(r'\s+', ' ', text).strip() |
|
text = unidecode(text.lower()) |
|
return text |
|
|
|
|
|
def initialize_retriever(file_objs, persist_directory="chroma_db"): |
|
"""Carrega documentos PDFs, pré-processa e cria um retriever híbrido.""" |
|
global global_retriever |
|
if not file_objs: |
|
return "Nenhum documento carregado." |
|
|
|
documents = [] |
|
for file_obj in file_objs: |
|
|
|
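        # Validate the extension up front so a bad upload fails with a clear message.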
if not file_obj.name.lower().endswith('.pdf'): |
|
return f"Erro: O arquivo '{file_obj.name}' não é um PDF válido. Apenas arquivos .pdf são aceitos." |
|
try: |
|
loader = PyPDFLoader(file_obj.name) |
|
raw_docs = loader.load() |
|
for doc in raw_docs: |
|
doc.page_content = preprocess_text(doc.page_content) |
|
doc.metadata.update({"source": os.path.basename(file_obj.name)}) |
|
documents.extend(raw_docs) |
|
except Exception as e: |
|
return f"Erro ao processar '{file_obj.name}': {str(e)}" |
|
|
|
if not documents: |
|
return "Nenhum conteúdo válido foi extraído dos PDFs." |
|
|
|
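    # 2048-char chunks with 128-char overlap: big enough to keep related
    # sentences together, overlapped so facts on a boundary survive the split.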
text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128) |
|
splits = text_splitter.split_documents(documents) |
|
|
|
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2") |
|
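    # Prefer appending to an existing persisted Chroma collection; fall back to
    # creating a fresh one (e.g. on the very first run).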
try: |
|
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings) |
|
vectorstore.add_documents(splits) |
|
    except Exception:
|
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings, persist_directory=persist_directory) |
|
|
|
semantic_retriever = vectorstore.as_retriever(search_kwargs={"k": 2}) |
|
bm25_retriever = BM25Retriever.from_documents(splits) |
|
bm25_retriever.k = 2 |
|
|
|
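    # Hybrid retrieval: blend semantic similarity with BM25 keyword matching,
    # weighted 60/40 toward the semantic side.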
global_retriever = EnsembleRetriever( |
|
retrievers=[semantic_retriever, bm25_retriever], |
|
weights=[0.6, 0.4] |
|
) |
|
|
|
return "Documentos processados com sucesso!" |
|
|
|
|
|
def format_prompt(message, history, retriever=None, system_prompt=None): |
|
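    # Mistral-instruct format: <s>[INST] ... [/INST] answer</s>. Note that
    # [SYS] and [CONTEXT] below are not special tokens, just plain-text markers
    # the model sees as part of the instruction.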
prompt = "<s>" |
|
    # gr.ChatInterface with chatbot type="messages" passes history as a list of
    # openai-style {"role", "content"} dicts, not (user, bot) tuples.
    for msg in history:
        if msg["role"] == "user":
            prompt += f"[INST] {msg['content']} [/INST]"
        else:
            prompt += f" {msg['content']}</s> "
|
if system_prompt: |
|
prompt += f"[SYS] {system_prompt} [/SYS]" |
|
if retriever: |
|
        docs = retriever.invoke(message)  # get_relevant_documents() is deprecated in recent LangChain
|
context = "\n".join([f"[{doc.metadata.get('source', 'Unknown')}, Page {doc.metadata.get('page', 'N/A')}] {doc.page_content}" for doc in docs]) |
|
prompt += f"[CONTEXT] {context} [/CONTEXT]" |
|
prompt += f"[INST] {message} [/INST]" |
|
return prompt |
|
|
|
|
|
def generate( |
|
prompt, history, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0 |
|
): |
|
global global_retriever |
|
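    # The inference endpoint rejects temperature == 0, so clamp to a small positive value.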
temperature = float(temperature) |
|
if temperature < 1e-2: |
|
temperature = 1e-2 |
|
top_p = float(top_p) |
|
|
|
generate_kwargs = dict( |
|
temperature=temperature, |
|
max_new_tokens=max_new_tokens, |
|
top_p=top_p, |
|
repetition_penalty=repetition_penalty, |
|
do_sample=True, |
|
seed=42, |
|
) |
|
|
|
formatted_prompt = format_prompt(prompt, history, global_retriever, system_prompt) |
|
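    # Stream tokens as they are generated; with details=True each chunk exposes
    # the decoded text at .token.text.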
stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False) |
|
output = "" |
|
|
|
for response in stream: |
|
output += response.token.text |
|
yield output |
|
|
|
|
|
def create_demo(): |
|
with gr.Blocks(css=css) as demo: |
|
        status = gr.State(value="No documents uploaded")
|
|
|
gr.Markdown("<h1>RAG Chatbot</h1>") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
gr.Markdown("### Carregar Documentos") |
|
                file_input = gr.Files(label="Upload PDFs", file_types=[".pdf"], file_count="multiple")
|
                process_btn = gr.Button("Process Documents")
|
                status_output = gr.Textbox(label="Status", value="No documents uploaded")
|
|
|
chat_interface = gr.ChatInterface( |
|
fn=generate, |
|
additional_inputs=[ |
|
gr.Textbox(label="System Prompt", placeholder="Digite um prompt de sistema (opcional)", value=None) |
|
], |
|
title="", |
|
chatbot=gr.Chatbot(type="messages") |
|
) |
|
|
|
process_btn.click( |
|
fn=initialize_retriever, |
|
inputs=[file_input], |
|
outputs=[status_output] |
|
) |
|
|
|
return demo |
|
|
|
|
|
demo = create_demo() |
|
if __name__ == "__main__":
    demo.queue().launch(share=False)