weibo1903 committed
Commit 52ef253 · verified · 1 Parent(s): 0b30541

Update main.py

Files changed (1)
  1. main.py +34 -11
main.py CHANGED
@@ -1,13 +1,20 @@
-from fastapi import FastAPI, Query
+from fastapi import FastAPI
+from pydantic import BaseModel
+import torch
+import base64
+from io import BytesIO
+from PIL import Image
 from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
-import torch
 
+# Initialize FastAPI
 app = FastAPI()
 
+# Load the model and processor
 checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
-min_pixels = 256*28*28
-max_pixels = 1280*28*28
+min_pixels = 256 * 28 * 28
+max_pixels = 1280 * 28 * 28
+
 processor = AutoProcessor.from_pretrained(
     checkpoint,
     min_pixels=min_pixels,
@@ -17,22 +24,33 @@ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     checkpoint,
     torch_dtype=torch.bfloat16,
     device_map="auto",
-    # attn_implementation="flash_attention_2",
 )
 
+# Define the request schema
+class ImageRequest(BaseModel):
+    image_base64: str  # Base64 encoded image
+    prompt: str  # Text prompt
+
 @app.get("/")
 def read_root():
     return {"message": "API is live. Use the /predict endpoint."}
 
-@app.get("/predict")
-def predict(
-    image_url: str = Query(..., description="URL of the image"),
-    prompt: str = Query(..., description="Prompt for the image")
-):
+@app.post("/predict")  # Changed from GET to POST
+async def predict(request: ImageRequest):
+    # Decode the base64 image
+    try:
+        image_data = base64.b64decode(request.image_base64)
+        image = Image.open(BytesIO(image_data)).convert("RGB")
+    except Exception as e:
+        return {"error": f"Invalid base64 image data: {str(e)}"}
+
+    # Create message structure
     messages = [
         {"role": "system", "content": "You are a helpful assistant with vision abilities."},
-        {"role": "user", "content": [{"type": "image", "image": image_url}, {"type": "text", "text": prompt}]},
+        {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": request.prompt}]},
     ]
+
+    # Process inputs
     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
     inputs = processor(
@@ -42,10 +60,15 @@ def predict(
         padding=True,
         return_tensors="pt",
     ).to(model.device)
+
+    # Run inference
     with torch.no_grad():
         generated_ids = model.generate(**inputs, max_new_tokens=128)
+
+    # Process output
     generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
     output_texts = processor.batch_decode(
         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )
+
     return {"response": output_texts[0]}
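After this commit, /predict no longer takes image_url and prompt as GET query parameters; it expects a POST request whose JSON body matches the ImageRequest schema (a base64-encoded image plus a text prompt). A minimal client sketch under that assumption; the host, port, and image filename below are placeholders, not part of the commit:

# Hypothetical client for the updated /predict endpoint.
# Assumes the app is served locally, e.g. `uvicorn main:app --port 8000`.
import base64
import requests

with open("example.jpg", "rb") as f:      # placeholder image file
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "image_base64": image_base64,         # field names match ImageRequest
    "prompt": "Describe this image.",
}
response = requests.post("http://localhost:8000/predict", json=payload)
print(response.json()["response"])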