from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    def __init__(self, path=""):
        # Load the fine-tuned BLIP captioning model and its processor, and put the model in eval mode.
        self.model = BlipForConditionalGeneration.from_pretrained(
            "quadranttechnologies/qhub-blip-image-captioning-finetuned"
        ).to(device)
        self.processor = BlipProcessor.from_pretrained(
            "quadranttechnologies/qhub-blip-image-captioning-finetuned"
        )
        self.model.eval()

    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Args:
            data (:obj:`dict`):
                Includes the input image bytes under "inputs", an optional "text" prompt,
                and optional generation "parameters".
        Return:
            A :obj:`dict` of the form {"description": ["Description of the image"]},
            where the value is the list of generated captions, one per input image.
        """
        images = data.pop("inputs", data)
        text = data.get("text", "")
        parameters = data.pop("parameters", {})

        try:
            # Ensure raw_images is a list of PIL images, even if only a single image is provided.
            if isinstance(images, bytes):  # Single image as bytes
                raw_images = [Image.open(BytesIO(images))]
            elif isinstance(images, list):  # Multiple images as a list of bytes
                raw_images = [Image.open(BytesIO(_img)) for _img in images if isinstance(_img, bytes)]
            else:
                raise ValueError("Invalid image input format. Expected bytes or a list of bytes.")
        except Exception as e:
            return {"error": f"Error fetching or processing image: {str(e)}"}

        # Preprocess the decoded images (plus the optional text prompt) and move all tensors to the model's device.
        processed_image = self.processor(images=raw_images, text=text, return_tensors="pt").to(device)
        processed_image = {**processed_image, **parameters}

        with torch.no_grad():
            out = self.model.generate(**processed_image)

        description = self.processor.batch_decode(out, skip_special_tokens=True)
        return {"description": description}
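

# --- Illustrative local usage sketch (not part of the original handler) ---
# Assumes an image file named "example.jpg" exists next to this script; "max_new_tokens"
# is just one example of a generation parameter that can be forwarded via "parameters".
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("example.jpg", "rb") as f:
        payload = {"inputs": f.read(), "parameters": {"max_new_tokens": 30}}
    print(handler(payload))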