|
from typing import List, Callable |
|
|
|
import numpy as np |
|
from PIL import Image |
|
import torch |
|
|
|
def l2_normalize(embedding: np.ndarray) -> np.ndarray:
    """Scale a vector to unit Euclidean (L2) length.

    When the norm is not positive (a zero vector), there is no direction to
    preserve. In that case the input object is returned unchanged rather
    than being divided by zero.

    Args:
        embedding (np.ndarray): Vector to rescale.

    Returns:
        np.ndarray: ``embedding`` divided by its L2 norm, or ``embedding``
        itself when that norm is not positive.
    """
    magnitude = np.linalg.norm(embedding)

    # Guard clause: skip the division for a zero (or otherwise non-positive) norm.
    if not magnitude > 0:
        return embedding

    return embedding / magnitude
|
|
|
def encode_image(
    image: Image.Image,
    preprocess: Callable[[Image.Image], torch.Tensor],
    model: torch.nn.Module,
    device: torch.device,
) -> List[float]:
    """Preprocess and encode an image using input model.

    This function performs the following steps:
    1. Preprocess the image to create a tensor.
    2. Move the tensor to the specified device (CPU or GPU).
    3. Generate image features using the model (gradients disabled).
    4. Convert the features to float32 and L2-normalize the embedding.

    Args:
        image (Image.Image): Input image to encode.
        preprocess (Callable[[Image.Image], torch.Tensor]):
            A callable that turns the image into a single, unbatched tensor
            (assumed; a batch dimension of size 1 is added here).
        model (torch.nn.Module): The model used for encoding. It must
            expose an ``encode_image`` method that accepts the batched
            tensor.
        device (torch.device): The device to which the image tensor is sent.

    Returns:
        List[float]: The L2-normalized embedding of the image as plain
        Python floats.
    """
    # Wrap the single preprocessed sample in a batch of size 1.
    image_input = preprocess(image).unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        image_features = model.encode_image(image_input)

    # Cast to float32 before leaving torch, for two reasons:
    # - NumPy has no bfloat16 dtype, so .numpy() raises on bfloat16 features.
    # - float16 features would otherwise be normalized in half precision.
    embedding = image_features[0].float().cpu().numpy()

    return l2_normalize(embedding).tolist()