Thomasboosinger's picture
Update handler.py
03dec79 verified
raw
history blame
1.58 kB
from transformers import pipeline
from PIL import Image
from io import BytesIO
import base64
from typing import Dict, List, Any
class EndpointHandler:
    """Request handler that wraps a zero-shot object detection pipeline."""

    def __init__(self, model_path=""):
        """Build the zero-shot object detection pipeline for *model_path*.

        The pipeline is placed on the first CUDA device when PyTorch reports
        one; otherwise it is placed on the CPU so the handler can still load
        on hosts without a GPU.

        Args:
            model_path: Model identifier or local path forwarded to
                ``transformers.pipeline``.
        """
        try:
            # Local import: torch is only needed here to probe for a GPU.
            import torch

            device = 0 if torch.cuda.is_available() else -1
        except ImportError:
            # Cannot probe for CUDA; keep the previous default of GPU 0.
            device = 0

        self.pipeline = pipeline(
            task="zero-shot-object-detection",
            model=model_path,
            device=device,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run zero-shot object detection on a base64-encoded image.

        Args:
            data: Request payload. ``data["inputs"]`` must be a dict holding
                ``"image"`` (the base64-encoded image bytes) and
                ``"candidates"`` (the candidate labels handed to the
                pipeline; defaults to an empty list when absent).

        Returns:
            The detection results exactly as returned by the pipeline.

        Raises:
            ValueError: If ``"inputs"`` is not a dict, or if ``"image"`` is
                missing, empty, of the wrong type, or not decodable base64.
        """
        # Read the payload once instead of re-fetching it for every field.
        inputs = data.get("inputs", {})
        if not isinstance(inputs, dict):
            raise ValueError(
                '"inputs" must be a dict with "image" and "candidates" keys'
            )

        image_data = inputs.get("image", "")
        # Fail early with a clear message rather than letting an empty
        # payload surface later as an opaque image-decoding error.
        if not image_data or not isinstance(image_data, (str, bytes)):
            raise ValueError(
                '"inputs.image" must be a non-empty base64-encoded string'
            )

        try:
            raw_bytes = base64.b64decode(image_data)
        except ValueError as err:  # binascii.Error is a ValueError subclass
            raise ValueError('"inputs.image" is not valid base64 data') from err

        image = Image.open(BytesIO(raw_bytes))
        candidate_labels = inputs.get("candidates", [])

        # The pipeline output is returned unchanged.
        return self.pipeline(image=image, candidate_labels=candidate_labels)