ecologist / helpers.py
mjwong's picture
Upload helpers.py
1f094d9 verified
from typing import List, Callable
import numpy as np
from PIL import Image
import torch
def l2_normalize(embedding: np.ndarray) -> np.ndarray:
    """Scale an array so that its L2 (Euclidean) norm equals one.

    Args:
        embedding (np.ndarray): Array to rescale.

    Returns:
        np.ndarray: ``embedding`` divided by its norm. When the norm is
        not strictly positive (e.g. an all-zero input), the input object
        itself is returned unchanged.
    """
    magnitude = np.linalg.norm(embedding)

    # Guard against division by zero: a non-positive norm means there is
    # no direction to preserve, so hand the input back untouched.
    if not magnitude > 0:
        return embedding

    return embedding / magnitude
def encode_image(
    image: Image.Image,
    preprocess: Callable[[Image.Image], torch.Tensor],
    model: torch.nn.Module,
    device: torch.device,
) -> List[float]:
    """Preprocess and encode an image using input model.

    This function performs the following steps:
        1. Preprocess the image to create a tensor.
        2. Add a batch dimension and move the tensor to ``device``.
        3. Generate image features with ``model.encode_image``.
        4. Convert the features to float32 and L2-normalize them.

    Args:
        image (Image.Image): Input image to encode.
        preprocess (Callable[[Image.Image], torch.Tensor]):
            A callable that turns the image into a tensor.
        model (torch.nn.Module): The model used for encoding. It must
            expose an ``encode_image`` method accepting the batched tensor.
        device (torch.device): The device to which the image tensor is sent.

    Returns:
        List[float]: A list representing the normalized embedding.
    """
    # Preprocess the input image and add a batch dimension
    image_input = preprocess(image).unsqueeze(0).to(device)

    # Use the model to encode the image without computing gradients
    with torch.no_grad():
        image_features = model.encode_image(image_input)

    # Extract the first (and only) embedding from the batch and move it to
    # the CPU. Cast to float32 before handing it to NumPy: Tensor.numpy()
    # cannot represent bfloat16, and normalizing float16 values would
    # accumulate the norm in half precision. For float32 features the cast
    # is a no-op.
    embedding = image_features[0].cpu().float().numpy()

    # Normalize the embedding using L2 normalization
    embedding_norm = l2_normalize(embedding)

    # Convert the normalized NumPy array to a list and return it
    return embedding_norm.tolist()