lucianotonet committed on
Commit 09984c6 · 1 Parent(s): 29950c3

Implement image support in the prediction endpoint


Updates message processing to handle different input types, such as images supplied via URL, base64, or a local file path. This change makes the model more flexible by letting it receive visual inputs directly, which can improve prediction quality in scenarios that require visual context. It also adjusts text processing to support multiple content items per message, ensuring a cleaner integration with the model's capabilities.
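For illustration, a request might look like the sketch below: a minimal Python client, assuming the server runs at http://localhost:8000 (the host, port, and image URL are placeholders, not part of the commit):

import requests

# Hypothetical payload: one user message mixing an image (URL, data URI,
# or local path) with a text prompt, using the "image" and "text"
# content types the endpoint now parses.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "https://example.com/photo.jpg"},  # placeholder URL
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

resp = requests.post("http://localhost:8000/predict", json=messages)  # assumed host/port
print(resp.json()["response"])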

Files changed (1)
  app.py +37 -7
app.py CHANGED

@@ -3,19 +3,45 @@ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
 import torch
 from typing import List, Dict
+import base64
+import requests
+from PIL import Image
+from io import BytesIO
 
 app = FastAPI()
 
 model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto", device_map="auto")
-processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
+min_pixels = 256 * 28 * 28
+max_pixels = 1280 * 28 * 28
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
+
+def process_image(image_data):
+    if image_data.startswith("http://") or image_data.startswith("https://"):
+        response = requests.get(image_data)
+        img = Image.open(BytesIO(response.content))
+    elif image_data.startswith("data:image"):
+        img_data = base64.b64decode(image_data.split(",")[1])
+        img = Image.open(BytesIO(img_data))
+    else:  # Assume it's a local file path
+        img = Image.open(image_data)
+    return img
 
 @app.post("/predict")
 async def predict(messages: List[Dict] = Body(...)):
     # Processing and inference
-    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    image_inputs, video_inputs = process_vision_info(messages)
+    texts = []
+    image_inputs = []
+    video_inputs = []
+
+    for message in messages:
+        for content in message["content"]:
+            if content["type"] == "text":
+                texts.append(processor.apply_chat_template(content["text"], tokenize=False, add_generation_prompt=True))
+            elif content["type"] == "image":
+                image_inputs.append(process_image(content["image"]))
+
     inputs = processor(
-        text=[text],
+        text=texts,
         images=image_inputs,
         videos=video_inputs,
         padding=True,
@@ -24,6 +50,10 @@ async def predict(messages: List[Dict] = Body(...)):
     inputs = inputs.to("cpu")  # Change to "cuda" if a GPU is available
 
     generated_ids = model.generate(**inputs, max_new_tokens=128)
-    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
-    output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-    return {"response": output_text}
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_texts = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+    return {"response": output_texts}
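As a usage note, a client could supply the base64 form that the new process_image branch decodes by building a data URI; a minimal sketch, with the file name as a placeholder:

import base64

# Read a local image and wrap it in a data URI; process_image() takes the
# segment after the comma and base64-decodes it, so the MIME type is not checked.
with open("example.png", "rb") as f:  # placeholder file name
    encoded = base64.b64encode(f.read()).decode("utf-8")
data_uri = f"data:image/png;base64,{encoded}"

image_content = {"type": "image", "image": data_uri}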