import spaces
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
from huggingface_hub import InferenceClient
import io
from PIL import Image
import torch
import numpy as np
import subprocess
import os

# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
device = "cuda" if torch.cuda.is_available() else "cpu" | |
model_id = 'J-LAB/Florence-vl3' | |
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(device).eval() | |
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) | |
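# Note: Florence-2-style models are steered by special task tokens (e.g. '<OCR>',
# '<MORE_DETAILED_CAPTION>') rather than free-form prompts. run_example below
# passes one of these tokens as the text input and post-processes the raw
# generation into a {task_token: result} dict.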
def run_example(task_prompt, image):
    # Build model inputs from the task token and the PIL image.
    inputs = processor(text=task_prompt, images=image, return_tensors="pt", padding=True).to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    # post_process_generation parses the raw text into a {task_token: result} dict.
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )
    return parsed_answer
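# A minimal sketch of what run_example returns (illustrative only, not executed
# here; the exact result shape comes from the model's remote post-processing code):
#   run_example('<OCR>', img)
#   # -> {'<OCR>': 'text found in the image'}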
def process_image(image, task_prompt):
    if isinstance(image, str):  # A file path was provided
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)  # Convert NumPy array to PIL Image
    # Map the UI label to the model's task token.
    if task_prompt == 'Product Caption':
        task_prompt = '<MORE_DETAILED_CAPTION>'
    elif task_prompt == 'OCR':
        task_prompt = '<OCR>'
    results = run_example(task_prompt, image)
    # Strip the task-token key and return only the text value.
    if results and task_prompt in results:
        output_text = results[task_prompt]
    else:
        output_text = ""
    return output_text
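# Hypothetical local sanity check (the file name is an assumption, nothing is
# shipped with this Space):
#   img = Image.open("sample_product.jpg")
#   print(process_image(img, "OCR"))              # raw OCR text
#   print(process_image(img, "Product Caption"))  # detailed caption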
# Initialize the Inference API client (reads the token from the environment).
client = InferenceClient(api_key=os.getenv('YOUR_HF_TOKEN'))

# Chatbot response callback.
def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, image):
    image_result = ""
    if image is not None:
        try:
            image_result_caption = process_image(image, 'Product Caption')
            image_result_ocr = process_image(image, 'OCR')
            image_result = image_result_caption + " " + image_result_ocr  # Concatenate both results
        except Exception as e:
            image_result = f"An error occurred with image processing: {str(e)}"

    # Build the full message, embedding the image description when present.
    full_message = message
    if image_result:
        full_message = f"\n<image>{image_result}</image>\n\n{message}"

    # Add the system message and prior turns to the message list.
    messages = [{"role": "system", "content": f'{system_message} a descrição das imagens enviadas pelo usuário ficam dentro da tag <image> </image>'}]
    for user, assistant in history:
        if user:
            messages.append({"role": "user", "content": user})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": full_message})

    # Generate the response.
    response = ""
    try:
        stream = client.chat.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stream=True
        )
        for chunk in stream:
            if chunk.choices[0].delta.content is not None:
                response += chunk.choices[0].delta.content
    except Exception as e:
        # Surface the error as the reply so the return shape still matches
        # the outputs declared in submit_btn.click below.
        response = f"An error occurred: {str(e)}"

    # Update the history; image_result itself is never shown in the chat.
    history.append((message, response))
    return history, gr.update(value=None), gr.update(value="")
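# Note: although stream=True is used above, respond only returns the final
# accumulated text. A sketch of token-by-token UI streaming (assuming the same
# tuple-style Chatbot) would turn respond into a generator:
#   history.append((message, ""))
#   for chunk in stream:
#       delta = chunk.choices[0].delta.content
#       if delta:
#           history[-1] = (message, history[-1][1] + delta)
#           yield history, gr.update(value=None), gr.update(value="")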
# Set up the Gradio interface.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    chat_input = gr.Textbox(placeholder="Enter message...", show_label=False)
    image_input = gr.Image(type="filepath", label="Upload an image")
    submit_btn = gr.Button("Send Message")
    system_message = gr.Textbox(value="Você é um chatbot útil que sempre responde em português", label="System message")
    max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
    temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    submit_btn.click(
        respond,
        inputs=[chat_input, chatbot, system_message, max_tokens, temperature, top_p, image_input],
        outputs=[chatbot, image_input, chat_input]
    )

if __name__ == "__main__":
    demo.launch(debug=True, quiet=True)
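# To run locally (assumes the YOUR_HF_TOKEN environment variable holds a valid
# Hugging Face token with Inference API access; "app.py" is the Space's usual
# entry-point name, assumed here):
#   export YOUR_HF_TOKEN=hf_xxx
#   python app.py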