# qwen2-vl-inference/handler.py: custom Hugging Face Inference Endpoints handler
from typing import Dict, List, Any
import json
import torch
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
class EndpointHandler:
    """Custom inference handler wrapping the Qwen2-VL vision-language model."""

    def __init__(self, model_name: str = "morthens/qwen2-vl-inference"):
        # Load the model and processor; device_map="auto" places the weights
        # on a GPU when one is available
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(model_name)
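
    # Expected request payload (a sketch; the content layout follows the
    # Qwen2-VL chat format, and the text entry is what create_text_prompt
    # renders into the prompt):
    #   {
    #       "image_path": "<path to a local image file>",
    #       "messages": [
    #           {"role": "user", "content": [
    #               {"type": "image"},
    #               {"type": "text", "text": "<question about the image>"},
    #           ]}
    #       ],
    #   }
    # The handler returns a one-element list holding either the parsed JSON
    # object produced by the model or an {"error": ...} dict.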

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Extract the image path and chat messages from the request payload
        image_path = data.get("image_path", "")
        messages = data.get("messages", [])

        # Load the image; converting to RGB guards against palette or RGBA
        # images that the processor may not accept
        try:
            image = Image.open(image_path).convert("RGB")
        except FileNotFoundError:
            return [{"error": "Image file not found."}]
        except Exception as e:
            return [{"error": str(e)}]

        # Render the chat messages into a text prompt
        text_prompt = self.create_text_prompt(messages)

        # Preprocess the prompt and image into model-ready tensors
        inputs = self.processor(
            text=[text_prompt],
            images=[image],
            padding=True,
            return_tensors="pt",
        )

        # Move inputs to the GPU when one is available
        inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")

        # Generate up to 128 new tokens
        output_ids = self.model.generate(**inputs, max_new_tokens=128)

        # Strip the prompt tokens so only the newly generated tokens are decoded
        # (renamed loop variables to avoid shadowing output_ids)
        generated_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, output_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # Remove Markdown code fences, then parse the model output as JSON
        cleaned_data = self.clean_output(output_text[0])
        try:
            json_data = json.loads(cleaned_data)
        except json.JSONDecodeError:
            return [{"error": "Failed to parse JSON output."}]

        return [json_data]

    def create_text_prompt(self, messages: List[Dict[str, Any]]) -> str:
        """Render the chat messages into a generation-ready prompt string."""
        # apply_chat_template inserts the image placeholders and the
        # generation prompt tokens that Qwen2-VL expects; the previous
        # version also built an unused text_content string, removed here
        return self.processor.apply_chat_template(messages, add_generation_prompt=True)

    def clean_output(self, output: str) -> str:
        """Strip Markdown ```json fences so the output can be parsed as JSON."""
        return output.replace("```json\n", "").replace("```", "").strip()
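

# Minimal local smoke test, not part of the Inference Endpoints contract.
# "example.jpg" and the question are hypothetical placeholders; running this
# downloads the model weights from the Hub.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "image_path": "example.jpg",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "Describe this image as JSON."},
                ],
            }
        ],
    }
    print(handler(payload))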