#import torch
import base64
from io import BytesIO
from typing import Any, Dict, List, Union

import numpy as np
import requests
from fashion_clip.fashion_clip import FashionCLIP
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


class EndpointHandler:
    """Inference handler that returns L2-normalised FashionCLIP embeddings.

    The handler is callable: it receives a request payload and returns the
    embeddings of the texts and/or images found in it.
    """

    # Seconds to wait for a remote image before giving up; without a timeout
    # a stalled server would block the worker indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, path="."):
        """Create the FashionCLIP encoder once so every call can reuse it.

        Args:
            path: Accepted for interface compatibility; not used here.
        """
        self.fclip = FashionCLIP('fashion-clip')

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[List[float]]]:
        """Embed the texts and images contained in a request payload.

        Args:
            data: Request payload. ``data["inputs"]`` must be a mapping with
                the optional keys:

                - ``"text"``: a string or a list of strings.
                - ``"image"``: an image or a list of images. Each image may
                  be an http(s) URL, a base64 string (optionally with a
                  ``data:...;base64,`` prefix), raw bytes, a PIL image, or a
                  numpy array.

        Returns:
            A dict with ``"image_embeddings"`` and/or ``"text_embeddings"``,
            each a list of L2-normalised vectors. A key is omitted when the
            corresponding input is missing or empty.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
            ValueError: If an image has an unsupported type.
        """
        inputs = data['inputs']

        texts = inputs.get("text")
        if texts is None:
            texts = []
        elif isinstance(texts, str):
            # A bare string would otherwise be treated as a sequence of
            # characters; an empty string still means "no text".
            texts = [texts] if texts else []

        images = inputs.get("image")
        if images is None:
            images = []
        elif isinstance(images, (str, bytes, Image.Image)):
            # Single image given without the surrounding list. A numpy array
            # is deliberately not wrapped: it may be a batch that is meant
            # to be iterated image by image.
            images = [images]

        # Turn every supported representation into a PIL image.
        images = [self._load_image(img) for img in images]

        results = {}
        if images:
            image_embeddings = self.fclip.encode_images(images, batch_size=32)
            results["image_embeddings"] = self._l2_normalize(image_embeddings).tolist()
        if texts:
            text_embeddings = self.fclip.encode_text(texts, batch_size=32)
            results["text_embeddings"] = self._l2_normalize(text_embeddings).tolist()
        return results

    @staticmethod
    def _l2_normalize(embeddings):
        """Scale each row of ``embeddings`` to unit L2 norm along the last axis."""
        norms = np.linalg.norm(embeddings, ord=2, axis=-1, keepdims=True)
        # Leave all-zero rows untouched instead of dividing by zero, which
        # would put NaN values into the response.
        norms = np.where(norms == 0, 1.0, norms)
        return embeddings / norms

    def _load_image(self, img: Union[str, bytes, Image.Image, np.ndarray]) -> Image.Image:
        """Load an image from a URL, base64 string, bytes, PIL image, or numpy array.

        Args:
            img: The image in any of the supported representations.

        Returns:
            The image as a PIL image.

        Raises:
            ValueError: If ``img`` has an unsupported type.
            requests.RequestException: If a URL cannot be fetched or the
                server answers with an error status.
        """
        if isinstance(img, Image.Image):
            return img
        if isinstance(img, str):
            # ':' is not part of the base64 alphabet, so this prefix test
            # cannot mistake a base64 payload for a URL.
            if img.startswith(('http://', 'https://')):
                response = requests.get(img, timeout=self.REQUEST_TIMEOUT)
                # Fail loudly on 4xx/5xx rather than trying to parse an
                # error page as an image.
                response.raise_for_status()
                # `.content` is fully read and transfer-decoded, unlike the
                # raw socket stream.
                return Image.open(BytesIO(response.content))
            if img.startswith('data:'):
                # Drop a "data:image/png;base64," style prefix.
                img = img.partition(',')[2]
            return Image.open(BytesIO(base64.b64decode(img)))
        if isinstance(img, bytes):
            return Image.open(BytesIO(img))
        if isinstance(img, np.ndarray):
            return Image.fromarray(img)
        raise ValueError("Unsupported image type.")