# fashion-embedder/handler.py
import base64
from io import BytesIO
from typing import Any, Dict

import numpy as np
import requests
from PIL import Image

from fashion_clip.fashion_clip import FashionCLIP


class EndpointHandler:
    def __init__(self, path: str = "."):
        # Preload everything needed at inference time. FashionCLIP resolves its
        # weights by model name, so `path` is not used here.
        self.fclip = FashionCLIP("fashion-clip")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Args:
data: A dictionary with keys:
- "text": List[str] (required) - The list of text descriptions.
- "image": List[Union[str, PIL.Image.Image, np.ndarray]] (required) - The list of images. Images can be URLs, PIL Images, or numpy arrays.
Returns:
List[Dict[str, Any]]: The embeddings for the text and/or images.
"""
        # Extract text and images from the request payload
        inputs = data.get("inputs", {})
        texts = inputs.get("text", [])
        images = inputs.get("image", [])
        # Normalize every image input (URL, base64 string, bytes, array) to a PIL Image
        images = [self._load_image(img) for img in images]
results = {}
        if images:
            image_embeddings = self.fclip.encode_images(images, batch_size=32)
            # L2-normalize so a dot product between vectors equals cosine similarity
            image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, ord=2, axis=-1, keepdims=True)
            results["image_embeddings"] = image_embeddings.tolist()
        if texts:
            text_embeddings = self.fclip.encode_text(texts, batch_size=32)
            text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, ord=2, axis=-1, keepdims=True)
            results["text_embeddings"] = text_embeddings.tolist()
return results

    def _load_image(self, img):
        """Load an image from a URL, base64 string, raw bytes, numpy array, or PIL Image."""
        if isinstance(img, str):
            if img.startswith("http"):
                # The image is a URL: download it and decode the bytes
                response = requests.get(img, timeout=10)
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
            else:
                # The image is a base64-encoded string
                img = Image.open(BytesIO(base64.b64decode(img)))
        elif isinstance(img, bytes):
            # The image is raw bytes
            img = Image.open(BytesIO(img))
        elif isinstance(img, Image.Image):
            # The image is already a PIL Image
            pass
        elif isinstance(img, np.ndarray):
            # The image is a numpy array
            img = Image.fromarray(img)
        else:
            raise ValueError(f"Unsupported image type: {type(img)!r}")
        return img
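

# Illustrative local usage sketch, not part of the deployed handler. It shows the
# request payload shape this handler expects ({"inputs": {"text": [...], "image": [...]}}).
# The image URL below is a hypothetical placeholder; substitute a real, reachable
# image URL (or a base64-encoded image string) before running.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": {
            "text": ["a red floral summer dress", "black leather ankle boots"],
            "image": ["https://example.com/sample-product.jpg"],  # hypothetical URL
        }
    }
    embeddings = handler(payload)
    # Embeddings are L2-normalized, so a dot product between a text vector and an
    # image vector gives their cosine similarity.
    print(len(embeddings["text_embeddings"]), len(embeddings["text_embeddings"][0]))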