Spaces:
Runtime error
Runtime error
import gradio as gr | |
import spaces | |
import subprocess | |
import os | |
import shutil | |
import string | |
import random | |
import glob | |
from pypdf import PdfReader | |
from sentence_transformers import SentenceTransformer | |
# Model / chunking configuration; each value may be overridden via an environment variable.
MODEL_NAME = os.getenv("MODEL", "Snowflake/snowflake-arctic-embed-m")
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "128"))
DEFAULT_MAX_CHARACTERS = int(os.getenv("DEFAULT_MAX_CHARACTERS", "258"))

# Load the embedding model once at import time; the embedding/search code below shares it.
model = SentenceTransformer(MODEL_NAME)
# Score queries against document chunks with the shared embedding model.
def embed(queries, chunks):
    """Score every chunk against every query by embedding dot product.

    Args:
        queries: sequence of query strings. A single ``str`` is treated as a
            one-element list.
        chunks: sequence of document chunk strings.

    Returns:
        dict mapping each query to a list of ``(chunk_index, score)`` tuples,
        in chunk order. Queries are embedded with ``prompt_name="query"``;
        chunks are embedded without a prompt. If a query string repeats, the
        last occurrence wins, because dict keys are unique.
    """
    # A bare string would be encoded as one vector and then zipped
    # character-by-character against scalar scores, which crashes.
    if isinstance(queries, str):
        queries = [queries]
    # len() rather than truthiness so array-like inputs are also accepted.
    if len(queries) == 0:
        return {}
    if len(chunks) == 0:
        # Nothing to score against; skip encoding an empty batch.
        return {query: [] for query in queries}

    query_embeddings = model.encode(queries, prompt_name="query")
    document_embeddings = model.encode(chunks)
    scores = query_embeddings @ document_embeddings.T  # one row of scores per query
    return {
        query: list(enumerate(query_scores))
        for query, query_scores in zip(queries, scores)
    }
# Extract the text of a PDF, one labelled section per page.
def extract_text_from_pdf(reader):
    """Concatenate the text of every page that yields non-empty text.

    Args:
        reader: a ``PdfReader``, or any object exposing ``.pages`` whose items
            have an ``extract_text()`` method.

    Returns:
        The page texts, each preceded by a ``---- Página {idx} ----`` header.
        ``idx`` is the 0-based position from ``enumerate``. Sections are
        separated by blank lines, and surrounding whitespace is stripped.
        Pages whose extracted text is empty are skipped.
    """
    sections = []
    for idx, page in enumerate(reader.pages):
        # Extract once per page; extraction is the costly step.
        # `if text` also tolerates a None result, should a reader return one.
        text = page.extract_text()
        if text:
            sections.append(f"---- Página {idx} ----\n{text}\n\n")
    # join avoids quadratic string concatenation on long documents.
    return "".join(sections).strip()
# Convert a supported document file into plain text.
def convert(filename):
    """Return the plain-text contents of a supported document.

    Plain-text formats are read verbatim as UTF-8. PDFs are passed through
    ``extract_text_from_pdf``. Extensions are matched case-insensitively.

    Args:
        filename: path to the file, as ``str`` or any path-like object.

    Returns:
        The document text.

    Raises:
        ValueError: if the file extension is not supported.
    """
    plain_text_filetypes = (
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    )
    lowered = os.fspath(filename).lower()
    if lowered.endswith(plain_text_filetypes):
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()
    if lowered.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))
    raise ValueError(f"Tipo de arquivo não suportado: {filename}")
# Split text into fixed-width character chunks.
def chunk_to_length(text, max_length=512):
    """Split *text* into consecutive pieces of at most *max_length* characters.

    Every piece except possibly the last has exactly ``max_length``
    characters. An empty string yields ``[""]``, so callers always get at
    least one chunk.

    Raises:
        ValueError: if ``max_length`` is not positive, since such a width can
            never advance through the text.
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")
    # Slicing at fixed offsets is linear. Repeatedly re-slicing the remainder
    # would copy the tail on every step, which is quadratic in len(text).
    pieces = [
        text[start:start + max_length]
        for start in range(0, len(text), max_length)
    ]
    return pieces or [text]
# Retrieve the chunks most similar to a query, within a character budget.
def predict(query, max_characters):
    """Return the best-matching document chunks for *query*.

    The query is embedded with ``prompt_name="query"`` and scored against
    the precomputed chunk embeddings of every entry in ``docs``. Chunks are
    taken in descending score order and kept while the running character
    total stays within ``max_characters``. Selection stops at the first
    chunk that would exceed the budget.

    Returns:
        ``{"relevant_chunks": {filename: [chunk, ...]}}``
    """
    query_vector = model.encode(query, prompt_name="query")

    # One (filename, chunk, score) triple for every chunk of every document.
    scored = [
        (name, piece, score)
        for name, entry in docs.items()
        for piece, score in zip(entry["chunks"], entry["embeddings"] @ query_vector.T)
    ]
    ranked = sorted(scored, key=lambda item: item[2], reverse=True)

    selected = {}
    used = 0
    for name, piece, _score in ranked:
        size = len(piece)
        if not (used + size <= max_characters):
            break
        selected.setdefault(name, []).append(piece)
        used += size
    return {"relevant_chunks": selected}
# Convert, chunk and embed the documents once at startup.
def _load_documents(pattern="src/*"):
    """Build the search index for every regular file matching *pattern*.

    Returns:
        dict mapping each file path to ``{"chunks": [...], "embeddings": ...}``,
        where ``chunks`` comes from ``chunk_to_length(convert(path), CHUNK_SIZE)``
        and ``embeddings`` is ``model.encode(chunks)``.

    Raises:
        ValueError: propagated from ``convert`` for unsupported file types.
    """
    loaded = {}
    # Sorted so the load order (and hence tie order in search results) is
    # deterministic; glob's own order depends on the filesystem.
    for path in sorted(glob.glob(pattern)):
        # "add_your_files_here" is presumably the template's placeholder entry.
        # Sub-directories cannot be converted, so they are skipped too.
        if path.endswith("add_your_files_here") or not os.path.isfile(path):
            continue
        chunks = chunk_to_length(convert(path), CHUNK_SIZE)
        loaded[path] = {
            "chunks": chunks,
            "embeddings": model.encode(chunks),
        }
    return loaded


docs = _load_documents()
# Build and launch the Gradio UI; each request calls predict(query, max_characters).
demo = gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Consulta feita sobre os documentos"),
        gr.Number(label="Máximo de caracteres de saída", value=DEFAULT_MAX_CHARACTERS),
    ],
    # gr.JSON is Gradio's component for displaying the dict that predict returns.
    outputs=[gr.JSON(label="Pedaços relevantes")],
    title="Demonstração do modelo de ferramenta da comunidade ",
    description=(
        "Para usar como ferramenta no HuggingChat com seus próprios documentos, "
        "comece clonando este espaço, adicione seus documentos à pasta `src` e "
        "então crie uma ferramenta comunitária com este espaço!"
    ),
)
demo.launch()