lucianotonet commited on
Commit
f3c27dd
1 Parent(s): 09984c6

Refine image processing and input handling

Browse files

Melhora a função de processamento de imagens para incluir verificação de erros durante as requisições HTTP e clarifica o tratamento de entradas na função de previsão. Agora suporta tanto strings quanto listas para o conteúdo, disponibilizando feedback para formatação inválida. Essas alterações aumentam a robustez do sistema e melhoram a experiência do usuário ao prevenir falhas inesperadas.

Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import FastAPI, Body
2
  from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
3
  from qwen_vl_utils import process_vision_info
4
  import torch
5
- from typing import List, Dict
6
  import base64
7
  import requests
8
  from PIL import Image
@@ -15,31 +15,41 @@ min_pixels = 256 * 28 * 28
15
  max_pixels = 1280 * 28 * 28
16
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
17
 
18
def process_image(image_data):
    """Load a PIL image from a URL, a base64 data URI, or a local file path.

    Args:
        image_data: An http(s) URL, a ``data:image/...;base64,...`` URI,
            or a local filesystem path.

    Returns:
        The decoded PIL ``Image``.

    Raises:
        requests.HTTPError: If the URL fetch returns a 4xx/5xx status.
    """
    if image_data.startswith(("http://", "https://")):
        # Bound the request so a stalled remote server cannot hang the endpoint.
        response = requests.get(image_data, timeout=10)
        # Fail fast on HTTP errors instead of feeding an error page to PIL.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
    elif image_data.startswith("data:image"):
        # Strip the "data:image/...;base64," prefix; split only once so a
        # comma inside the payload cannot truncate the data.
        img_data = base64.b64decode(image_data.split(",", 1)[1])
        img = Image.open(BytesIO(img_data))
    else:  # Assume it's a local file path.
        img = Image.open(image_data)
    return img
28
 
29
  @app.post("/predict")
30
- async def predict(messages: List[Dict] = Body(...)):
31
  # Processamento e inferência
32
  texts = []
33
  image_inputs = []
34
  video_inputs = []
35
 
36
  for message in messages:
37
- for content in message["content"]:
38
- if content["type"] == "text":
39
- texts.append(processor.apply_chat_template(content["text"], tokenize=False, add_generation_prompt=True))
40
- elif content["type"] == "image":
41
- image_inputs.append(process_image(content["image"]))
42
-
 
 
 
 
 
 
 
 
 
43
  inputs = processor(
44
  text=texts,
45
  images=image_inputs,
 
2
  from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
3
  from qwen_vl_utils import process_vision_info
4
  import torch
5
+ from typing import List, Dict, Union
6
  import base64
7
  import requests
8
  from PIL import Image
 
15
  max_pixels = 1280 * 28 * 28
16
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
17
 
18
def process_image(image_data: str) -> Image.Image:
    """Load a PIL image from a URL, a base64 data URI, or a local file path.

    Args:
        image_data: An http(s) URL, a ``data:image/...;base64,...`` URI,
            or a local filesystem path.

    Returns:
        The decoded PIL ``Image``.

    Raises:
        requests.HTTPError: If the URL fetch returns a 4xx/5xx status.
    """
    if image_data.startswith(("http://", "https://")):
        # Bound the request so a stalled remote server cannot hang the endpoint.
        response = requests.get(image_data, timeout=10)
        response.raise_for_status()  # Fail fast on HTTP errors (4xx/5xx).
        img = Image.open(BytesIO(response.content))
    elif image_data.startswith("data:image"):
        # Strip the "data:image/...;base64," prefix; split only once so a
        # comma inside the payload cannot truncate the data.
        img_data = base64.b64decode(image_data.split(",", 1)[1])
        img = Image.open(BytesIO(img_data))
    else:  # Assume it's a local file path.
        img = Image.open(image_data)
    return img
29
 
30
  @app.post("/predict")
31
+ async def predict(messages: List[Dict[str, Union[str, List[Dict[str, str]]]]] = Body(...)):
32
  # Processamento e inferência
33
  texts = []
34
  image_inputs = []
35
  video_inputs = []
36
 
37
  for message in messages:
38
+ content = message.get("content")
39
+ if isinstance(content, str):
40
+ texts.append(processor.apply_chat_template(content, tokenize=False, add_generation_prompt=True))
41
+ elif isinstance(content, list):
42
+ for item in content:
43
+ if isinstance(item, dict) and "type" in item:
44
+ if item["type"] == "text":
45
+ texts.append(processor.apply_chat_template(item["text"], tokenize=False, add_generation_prompt=True))
46
+ elif item["type"] == "image":
47
+ image_inputs.append(process_image(item["image"]))
48
+ else:
49
+ raise ValueError(f"Formato inválido para o item: {item}")
50
+ else:
51
+ raise ValueError(f"Formato inválido para o conteúdo: {content}")
52
+
53
  inputs = processor(
54
  text=texts,
55
  images=image_inputs,