File size: 5,684 Bytes
c1772f8 f0340cd c1772f8 f0340cd c1772f8 ff20866 f0340cd ff20866 f0340cd ff20866 c1772f8 ff20866 c1772f8 41b022e c1772f8 f0340cd ff20866 f0340cd c1772f8 f0340cd ff20866 f0340cd ff20866 f0340cd ff20866 c1772f8 f0340cd c1772f8 ff20866 c1772f8 ff20866 c1772f8 ff20866 c1772f8 41b022e c1772f8 ff20866 c1772f8 f0340cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import hashlib
import os
import re

import gradio as gr
from huggingface_hub import InferenceClient
from langchain.retrievers import EnsembleRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import Chroma
from unidecode import unidecode
# Stylesheet passed to gr.Blocks: caps the container width, centres <h1>
# headings and hides the Gradio footer.
css = (
    "\n"
    ".gradio-container{max-width: 1000px !important}\n"
    "h1{text-align:center}\n"
    "footer {visibility: hidden}\n"
)

# Hugging Face inference client bound to the Mistral-7B-Instruct-v0.3 model id.
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

# Hybrid retriever built by initialize_retriever; None until documents are processed.
# NOTE(review): module-level state, so it is shared by every session of the app.
global_retriever = None
# Text preprocessing
def preprocess_text(text):
    """Normalize raw PDF page text before it is chunked and indexed.

    Steps:
      1. Remove page markers such as "Page 3", "Page 3 of 10", "Página 3",
         "Pagina 3" or "Página 3 de 10" (case-insensitive).
      2. Collapse every whitespace run to a single space and strip the ends.
      3. Lowercase and transliterate to ASCII with ``unidecode``.

    Args:
        text: Raw text extracted from one PDF page.

    Returns:
        The normalized text.
    """
    # The leading \b keeps "Page N" from matching inside words ("homepage 3").
    # "de" handles Portuguese footers ("Página 3 de 10"); "P[aá]gina" also
    # accepts the unaccented spelling.
    text = re.sub(
        r'\b(?:P[aá]gina|Page)\s+\d+(?:\s+(?:of|de)\s+\d+)?',
        '',
        text,
        flags=re.IGNORECASE,
    )
    text = re.sub(r'\s+', ' ', text).strip()
    return unidecode(text.lower())
# Configurar o retriever
def initialize_retriever(file_objs, persist_directory="chroma_db"):
"""Carrega documentos PDFs, pré-processa e cria um retriever híbrido."""
global global_retriever
if not file_objs:
return "Nenhum documento carregado."
documents = []
for file_obj in file_objs:
# Validar se é um PDF
if not file_obj.name.lower().endswith('.pdf'):
return f"Erro: O arquivo '{file_obj.name}' não é um PDF válido. Apenas arquivos .pdf são aceitos."
try:
loader = PyPDFLoader(file_obj.name)
raw_docs = loader.load()
for doc in raw_docs:
doc.page_content = preprocess_text(doc.page_content)
doc.metadata.update({"source": os.path.basename(file_obj.name)})
documents.extend(raw_docs)
except Exception as e:
return f"Erro ao processar '{file_obj.name}': {str(e)}"
if not documents:
return "Nenhum conteúdo válido foi extraído dos PDFs."
text_splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)
splits = text_splitter.split_documents(documents)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
try:
vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
vectorstore.add_documents(splits)
except:
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings, persist_directory=persist_directory)
semantic_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
bm25_retriever = BM25Retriever.from_documents(splits)
bm25_retriever.k = 2
global_retriever = EnsembleRetriever(
retrievers=[semantic_retriever, bm25_retriever],
weights=[0.6, 0.4]
)
return "Documentos processados com sucesso!"
# Build the RAG prompt
def format_prompt(message, history, retriever=None, system_prompt=None):
    """Build a Mistral-instruct style prompt with optional system prompt and retrieved context.

    Args:
        message: The new user message. It is also the retrieval query.
        history: Previous turns, either as ``(user, bot)`` pairs (Gradio
            "tuples" format) or as ``{"role": ..., "content": ...}`` dicts
            (Gradio "messages" format). ``None`` is treated as no history.
        retriever: Optional retriever. Its documents are inlined in a
            ``[CONTEXT] ... [/CONTEXT]`` section, each tagged with its
            ``source`` and ``page`` metadata.
        system_prompt: Optional text wrapped in ``[SYS] ... [/SYS]``.

    Returns:
        The assembled prompt string.
    """
    prompt = "<s>"
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages" format: one dict per message. Extra keys (e.g.
            # metadata) and roles other than user/assistant are ignored.
            role = turn.get("role")
            if role == "user":
                prompt += f"[INST] {turn.get('content')} [/INST]"
            elif role == "assistant":
                prompt += f" {turn.get('content')}</s> "
        else:
            user_prompt, bot_response = turn
            prompt += f"[INST] {user_prompt} [/INST]"
            prompt += f" {bot_response}</s> "
    if system_prompt:
        prompt += f"[SYS] {system_prompt} [/SYS]"
    if retriever:
        # `invoke` is the current LangChain retriever API. The deprecated
        # `get_relevant_documents` is used only for objects lacking `invoke`.
        if hasattr(retriever, "invoke"):
            docs = retriever.invoke(message)
        else:
            docs = retriever.get_relevant_documents(message)
        context = "\n".join(
            f"[{doc.metadata.get('source', 'Unknown')}, Page {doc.metadata.get('page', 'N/A')}] {doc.page_content}"
            for doc in docs
        )
        prompt += f"[CONTEXT] {context} [/CONTEXT]"
    prompt += f"[INST] {message} [/INST]"
    return prompt
# Generation with RAG
def generate(
    prompt, history, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0
):
    """Stream a model answer to ``prompt``, grounded on the uploaded documents.

    Used as the ``fn`` of ``gr.ChatInterface``. The prompt is built by
    ``format_prompt`` using the current ``global_retriever`` (None means no
    retrieval context).

    Args:
        prompt: The new user message.
        history: Previous chat turns, forwarded to ``format_prompt``.
        system_prompt: Optional system prompt from the additional input.
        temperature: Sampling temperature, clamped to at least 0.01.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Repetition penalty passed to the backend.

    Yields:
        The accumulated answer text after each streamed, non-special token.
    """
    generate_kwargs = dict(
        # Sampling is enabled (do_sample=True), so keep the temperature
        # strictly positive.
        temperature=max(float(temperature), 1e-2),
        max_new_tokens=int(max_new_tokens),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=42,
    )
    formatted_prompt = format_prompt(prompt, history, global_retriever, system_prompt)
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        # details=True exposes per-token metadata. Skip special tokens such as
        # the end-of-sequence "</s>" so they never show up in the chat.
        if response.token.special:
            continue
        output += response.token.text
        yield output
# Gradio interface with RAG
def create_demo():
    """Build the Gradio Blocks app.

    A side column uploads PDFs and runs ``initialize_retriever``, which writes
    its status message to a textbox. Next to it, a ``ChatInterface`` streams
    answers from ``generate`` and takes an optional system prompt as an
    additional input.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("<h1>RAG Chatbot</h1>")
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Carregar Documentos")
                # Gradio's file_types expects extensions with a leading dot.
                file_input = gr.Files(label="Upload PDFs", file_types=[".pdf"], file_count="multiple")
                process_btn = gr.Button("Processar Documentos")
                status_output = gr.Textbox(label="Status", value="Nenhum documento carregado")
            gr.ChatInterface(
                fn=generate,
                additional_inputs=[
                    gr.Textbox(label="System Prompt", placeholder="Digite um prompt de sistema (opcional)", value=None)
                ],
                title="",
                # "messages" format: history items are role/content dicts.
                chatbot=gr.Chatbot(type="messages"),
            )
        process_btn.click(
            fn=initialize_retriever,
            inputs=[file_input],
            outputs=[status_output],
        )
    return demo
# Launch the application. `demo` stays at module level so tools that look up
# the Blocks object by name (e.g. the `gradio` CLI reload mode) can find it.
demo = create_demo()

if __name__ == "__main__":
    demo.queue().launch(share=False)