from fastapi import FastAPI, Body, HTTPException
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import torch
from typing import List, Dict, Union
import base64
import requests
from PIL import Image
from io import BytesIO
from qwen_vl_utils import process_vision_info

app = FastAPI()

# Load the model and the processor
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

def process_image(image_data: str) -> Image.Image:
    """Load an image from a URL, a base64 data URI, or a local file path.

    Standalone helper; qwen_vl_utils.process_vision_info (used below) resolves
    the same kinds of sources when given the full message list.
    """
    if image_data.startswith(("http://", "https://")):
        response = requests.get(image_data)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
    elif image_data.startswith("data:image"):
        img_data = base64.b64decode(image_data.split(",")[1])
        img = Image.open(BytesIO(img_data))
    else:
        img = Image.open(image_data)
    return img
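
# Illustrative calls process_image accepts (hypothetical values):
#   process_image("https://example.com/cat.png")            # remote URL
#   process_image("data:image/png;base64,iVBORw0KGgo...")   # base64 data URI
#   process_image("/tmp/cat.png")                           # local file path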

@app.post("/predict")
async def predict(messages: List[Dict[str, Union[str, List[Dict[str, Union[str, None]]]]]] = Body(...)):
    """
    Endpoint para prever respostas com base nas mensagens fornecidas.
    """
    # Processa as mensagens para texto e imagens
    texts = []
    image_inputs = []
    video_inputs = []

    # Utiliza o qwen_vl_utils para processar as informações visuais
    for message in messages:
        content = message.get("content")
        if isinstance(content, str):
            texts.append(processor.apply_chat_template(content, tokenize=False, add_generation_prompt=True))
        elif isinstance(content, list):
            for item in content:
                if item.get("type") == "text":
                    texts.append(processor.apply_chat_template(item["text"], tokenize=False, add_generation_prompt=True))
                elif item.get("type") == "image":
                    image = process_image(item["image"])
                    image_inputs.append(image)
                else:
                    raise ValueError(f"Formato inválido para o item: {item}")
        else:
            raise ValueError(f"Formato inválido para o conteúdo: {content}")

    # Prepara inputs para o modelo
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text for text in texts],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    ).to("cpu")

    # Generate the responses and strip the prompt tokens from each output
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return {"response": output_texts}
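
# Example client call (a sketch; assumes this file is saved as app.py and the
# server was started with `uvicorn app:app --port 8000`):
#
#   import requests
#
#   messages = [{
#       "role": "user",
#       "content": [
#           {"type": "image", "image": "https://example.com/demo.jpg"},
#           {"type": "text", "text": "Describe this image."},
#       ],
#   }]
#   print(requests.post("http://localhost:8000/predict", json=messages).json())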