|
|
|
from transformers import CLIPModel, CLIPProcessor |
|
from PIL import Image |
|
from typing import Dict, Any, List |
|
import requests |
|
import numpy as np |
|
from fashion_clip.fashion_clip import FashionCLIP |
|
from io import BytesIO |
|
import base64 |
|
|
|
|
|
|
|
class EndpointHandler:
    """Inference handler that embeds texts and/or images with FashionCLIP.

    Returned embeddings are L2-normalised along the last axis, so the cosine
    similarity between a text and an image embedding is their dot product.
    """

    # Number of items passed to FashionCLIP per encoding batch.
    batch_size: int = 32
    # Seconds to wait on an image URL before the download fails.
    request_timeout: float = 10.0

    def __init__(self, path="."):
        # NOTE(review): ``path`` is accepted (presumably to satisfy the hosting
        # framework's handler signature) but unused; the model is always
        # loaded by the name 'fashion-clip'.
        self.fclip = FashionCLIP('fashion-clip')

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the texts and/or images contained in a request.

        Args:
            data: The request body. Inputs are read from ``data["inputs"]``,
                or from ``data`` itself when it has no ``"inputs"`` key:

                - ``"text"``: a ``str`` or a list of ``str`` (optional).
                - ``"image"``: one image or a list of images (optional). An
                  image may be an ``http(s)`` URL, a base64 string (optionally
                  a ``data:`` URL), raw ``bytes``, a ``PIL.Image.Image`` or a
                  ``numpy.ndarray``.

        Returns:
            A dict with ``"image_embeddings"`` and/or ``"text_embeddings"``,
            each a list of L2-normalised vectors as plain Python lists. A key
            is present only when the corresponding input was non-empty.

        Raises:
            ValueError: If the inputs are not a dict or an image has an
                unsupported type.
        """
        payload = data.get("inputs", data)
        if payload is None:
            payload = {}
        if not isinstance(payload, dict):
            raise ValueError("'inputs' must be a dict with 'text' and/or 'image' keys.")

        texts = self._as_list(payload.get("text"))
        images = [self._load_image(img) for img in self._as_list(payload.get("image"))]

        results = {}

        if images:
            image_embeddings = self.fclip.encode_images(images, batch_size=self.batch_size)
            results["image_embeddings"] = self._l2_normalize(image_embeddings).tolist()

        if texts:
            text_embeddings = self.fclip.encode_text(texts, batch_size=self.batch_size)
            results["text_embeddings"] = self._l2_normalize(text_embeddings).tolist()

        return results

    @staticmethod
    def _as_list(value):
        """Normalise an optional input field into something iterable per item.

        ``None`` becomes ``[]``; a lone ``str``/``bytes``/``bytearray`` is wrapped
        in a list (iterating it would otherwise yield single characters/ints);
        anything else is returned unchanged.
        """
        if value is None:
            return []
        if isinstance(value, (str, bytes, bytearray)):
            return [value]
        return value

    @staticmethod
    def _l2_normalize(embeddings):
        """Return ``embeddings`` scaled to unit L2 norm along the last axis.

        All-zero rows are left as zeros rather than turning into NaN.
        """
        embeddings = np.asarray(embeddings)
        norms = np.linalg.norm(embeddings, ord=2, axis=-1, keepdims=True)
        norms[norms == 0] = 1.0
        return embeddings / norms

    def _load_image(self, img):
        """Convert one image input into a ``PIL.Image.Image``.

        Accepted forms:
            - ``str`` starting with ``http://`` or ``https://``: downloaded.
            - any other ``str``: base64-encoded image bytes; a leading
              ``data:<mime>;base64,`` prefix is stripped first.
            - ``bytes`` / ``bytearray``: encoded image bytes.
            - ``PIL.Image.Image``: returned unchanged.
            - ``numpy.ndarray``: converted with ``Image.fromarray``.

        Raises:
            ValueError: If ``img`` is none of the above.
            requests.HTTPError: If the URL responds with an error status.
        """
        if isinstance(img, str):
            if img.startswith(("http://", "https://")):
                # Read the whole body with a timeout instead of streaming
                # ``.raw``: ``.raw`` skips HTTP content decoding, and error
                # pages must not be handed to PIL as image data.
                response = requests.get(img, timeout=self.request_timeout)
                response.raise_for_status()
                return Image.open(BytesIO(response.content))
            if img.startswith("data:"):
                # "data:image/png;base64,<payload>" -> "<payload>"
                img = img.partition(",")[2]
            return Image.open(BytesIO(base64.b64decode(img)))
        if isinstance(img, (bytes, bytearray)):
            return Image.open(BytesIO(img))
        if isinstance(img, Image.Image):
            return img
        if isinstance(img, np.ndarray):
            return Image.fromarray(img)
        raise ValueError(f"Unsupported image type: {type(img).__name__}.")
|
|