import base64
import logging
from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    def __init__(self, path=""):
        logger.debug("Initializing model and processor.")
        self.model = BlipForConditionalGeneration.from_pretrained(
            "quadranttechnologies/qhub-blip-image-captioning-finetuned"
        ).to(device)
        self.processor = BlipProcessor.from_pretrained(
            "quadranttechnologies/qhub-blip-image-captioning-finetuned"
        )
        self.model.eval()
    def __call__(self, data: Any) -> Dict[str, Any]:
        """
        Args:
            data (:obj:`dict`):
                Includes the input data and the parameters for the inference.
        Return:
            A :obj:`dict` with a single key, e.g. {"description": ["Description of the image"]},
            where "description" is a list of generated caption strings.
        """
        logger.debug(f"Received data keys: {data.keys()}")

        # Decode the base64-encoded image payload into a PIL image.
        image_base64 = data["inputs"].get("image")
        image_data = base64.b64decode(image_base64)
        image = Image.open(BytesIO(image_data))

        # Optional text prompt used to condition the caption.
        text = data["inputs"].get("text", "")

        # Extra generation parameters (e.g. max_new_tokens) passed through to generate().
        parameters = data.pop("parameters", {})

        # Preprocess, then move every tensor (pixel values and any text tokens) to the device.
        inputs = self.processor(images=image, text=text, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            out = self.model.generate(**inputs, **parameters)

        description = self.processor.batch_decode(out, skip_special_tokens=True)
        return {"description": description}