morthens committed on
Commit
4c169bc
·
verified ·
1 Parent(s): 4ea75f0

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +81 -0
handler.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ import json
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
6
+
7
class EndpointHandler:
    """Inference endpoint handler for a Qwen2-VL vision-language model.

    Loads the model and processor once at construction, then serves requests
    that carry an image path plus chat-style messages, returning the model's
    output parsed as JSON.
    """

    def __init__(self, model_name: str):
        """Load the Qwen2-VL model and its processor.

        Args:
            model_name: Hugging Face model id or local path.
        """
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_name)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one inference request.

        Args:
            data: Request payload; expects keys "image_path" (str) and
                "messages" (a chat-template-compatible message list).

        Returns:
            A single-element list containing either the parsed JSON output
            or an ``{"error": ...}`` dict describing the failure.
        """
        image_path = data.get("image_path", "")
        messages = data.get("messages", [])

        # Load the image; report failures as structured errors rather than
        # raising, so the endpoint always returns a well-formed response.
        try:
            image = Image.open(image_path)
        except FileNotFoundError:
            return [{"error": "Image file not found."}]
        except Exception as e:
            return [{"error": str(e)}]

        # Prepare the text prompt from messages.
        text_prompt = self.create_text_prompt(messages)

        # Process inputs for the model.
        inputs = self.processor(
            text=[text_prompt],
            images=[image],
            padding=True,
            return_tensors="pt",
        )

        # Fix: move inputs to the model's actual device instead of assuming
        # CUDA — with device_map="auto" the model may sit on CPU or on a
        # specific GPU, and a blanket "cuda" placement can mismatch it.
        inputs = inputs.to(self.model.device)

        # Generate output using the model.
        output_ids = self.model.generate(**inputs, max_new_tokens=128)

        # Strip the prompt tokens from each sequence, keeping only the newly
        # generated tokens. (Renamed the comprehension variables — the
        # original rebound `output_ids` inside the comprehension, shadowing
        # the outer name it was iterating over.)
        generated_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, output_ids)
        ]

        output_text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # Clean and parse JSON from the decoded output.
        cleaned_data = self.clean_output(output_text[0])

        try:
            json_data = json.loads(cleaned_data)
        except json.JSONDecodeError:
            return [{"error": "Failed to parse JSON output."}]

        return [json_data]

    def create_text_prompt(self, messages: List[Dict[str, Any]]) -> str:
        """Render the chat messages into the model's prompt format."""
        # Fix: removed a dead loop that concatenated text parts into an
        # unused local (`text_content`); the chat template alone produces the
        # prompt, and the dead loop could raise KeyError on malformed content.
        return self.processor.apply_chat_template(messages, add_generation_prompt=True)

    def clean_output(self, output: str) -> str:
        """Strip Markdown ```json fences so the text can be JSON-parsed."""
        return output.replace("```json\n", "").replace("```", "").strip()