# handler.py
import io
from typing import Any, Dict, List, Tuple
import numpy as np
import requests
import torch
from PIL import Image
from transformers import (
    CLIPModel,
    CLIPProcessor,
    CLIPTokenizerFast,
    pipeline,
    AutoProcessor,
    AutoModelForCausalLM,
)
from huggingface_hub import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import timeit
import easyocr
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# multi-model list
multi_model_list = [
    {"model_id": "openai/clip-vit-base-patch32", "task": "Custom"},
    {"model_id": "microsoft/git-large-coco", "task": "Custom"},
]


class EndpointHandler:
    def __init__(self, path=""):
        clip_model_id = "openai/clip-vit-base-patch32"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # CLIP model for image and text embeddings
        self.processor_clip = CLIPProcessor.from_pretrained(clip_model_id)
        self.model_clip = CLIPModel.from_pretrained(clip_model_id).to(self.device)
        self.tokenizer_clip = CLIPTokenizerFast.from_pretrained(clip_model_id)
        # GIT model for image captioning
        self.processor_git = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        self.model_git = AutoModelForCausalLM.from_pretrained(
            "microsoft/git-large-coco"
        )
        self.model_git.to(self.device)
        self.model_clip.to(self.device)
        logging.set_verbosity_debug()
        self.logger = logging.get_logger(__name__)
        # EasyOCR reader for German and English text extraction
        self.reader = easyocr.Reader(["de", "en"])

    def download_image(self, url: str) -> bytes:
        """
        Download an image from a given URL.

        Parameters:
        - url: str
            The URL from where the image needs to be downloaded.

        Returns:
        - bytes
            The downloaded image data in bytes.

        Raises:
        - Exception: If the image download request fails.
        """
        response = requests.get(url)
        if response.status_code == 200:
            return response.content
        else:
            self.logger.error(f"Error downloading image from: {url}")
            raise Exception(
                f"Failed to download image from {url}. Status code: {response.status_code}"
            )
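
    # Example usage (illustrative only; the URL is a placeholder):
    #   handler = EndpointHandler()
    #   raw_bytes = handler.download_image("https://example.com/cat.jpg")
    #   image = Image.open(io.BytesIO(raw_bytes))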

    def download_images_in_parallel(
        self, urls: List[str], images_metadata_list: List[dict]
    ) -> Tuple[List[bytes], List[dict]]:
        """
        Download multiple images in parallel and collect their metadata.

        Parameters:
        - urls: List[str]
            A list of URLs from where the images need to be downloaded.
        - images_metadata_list: List[dict]
            A list of metadata dicts corresponding to each image URL.

        Returns:
        - Tuple[List[bytes], List[dict]]
            A tuple containing a list of downloaded image data in bytes and
            a list of metadata for the successfully downloaded images.
        """
        with ThreadPoolExecutor() as executor:
            # Start the download operations and map each future to its URL and metadata
            future_to_metadata = {
                executor.submit(self.download_image, url): (url, metadata)
                for url, metadata in zip(urls, images_metadata_list)
            }
            results = []
            successful_metadata = []
            for future in as_completed(future_to_metadata):
                url, metadata = future_to_metadata[future]
                try:
                    data = future.result()
                    results.append(data)
                    metadata["url"] = url
                    successful_metadata.append(metadata)
                except Exception as exc:
                    self.logger.error("%r generated an exception: %s" % (url, exc))
        return results, successful_metadata
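
    # Example usage (illustrative only; URLs and metadata values are placeholders):
    #   images, metadata = handler.download_images_in_parallel(
    #       ["https://example.com/a.jpg", "https://example.com/b.jpg"],
    #       [{"caption": "first image"}, {"caption": "second image"}],
    #   )
    #   # images   -> raw bytes for the URLs that downloaded successfully
    #   # metadata -> the matching metadata dicts, each enriched with a "url" key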

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data based on its type and return the embeddings.

        This method accepts a dictionary with a 'process_type' key that can be either 'images' or 'text'.
        If 'process_type' is 'images', the method expects a list of image URLs under the 'images_urls' key
        and a matching list of metadata dicts under the 'images_metadata' key. It downloads and processes
        these images, and returns their embeddings, GIT captions, and OCR-extracted text.
        If 'process_type' is 'text', the method expects a string query under the 'query' key.
        It processes this text and returns its embedding.

        Parameters:
        - data: Dict[str, Any]
            A dictionary containing the data to be processed.
            It must include a 'process_type' key with value either 'images' or 'text'.
            If 'process_type' is 'images', data should also include an 'images_urls' key with a list of
            image URLs, an 'images_metadata' key with one metadata dict per URL, and optionally a
            'batch_size' key (defaults to 50).
            If 'process_type' is 'text', data should also include a 'query' key with a string query.

        Returns:
        - Dict[str, Any]
            A dictionary containing the embeddings of the processed data.
            If an error occurs during processing, the dictionary will include an 'error' key with the error message.

        Raises:
        - ValueError: If the required keys for 'images' or 'text' are not present or are of the wrong type.
        """
        if data["process_type"] == "images":
            try:
                # Check that 'images_urls' is present and has the right type
                if "images_urls" not in data or not isinstance(
                    data["images_urls"], list
                ):
                    raise ValueError(
                        "Data must contain 'images_urls' key with a list of image urls."
                    )
                batch_size = 50
                if "batch_size" in data:
                    batch_size = int(data["batch_size"])
                # Download and preprocess the images batch by batch
                images_batches = []
                processed_metadata = []
                for i in range(0, len(data["images_urls"]), batch_size):
                    # Select a batch of image URLs and their metadata
                    batches = data["images_urls"][i : i + batch_size]
                    batches_metadata = data["images_metadata"][i : i + batch_size]
                    download_start_time = timeit.default_timer()
                    # Download images in parallel along with their metadata
                    (
                        downloaded_images,
                        images_metadata,
                    ) = self.download_images_in_parallel(batches, batches_metadata)
                    download_end_time = timeit.default_timer()
                    self.logger.info(
                        f"Image downloading took {download_end_time - download_start_time} seconds"
                    )
                    processing_start_time = timeit.default_timer()
                    for image_content, image_metadata in zip(
                        downloaded_images, images_metadata
                    ):
                        try:
                            image = Image.open(io.BytesIO(image_content)).convert("RGB")
                            image_array = np.array(image)
                            images_batches.append(image_array)
                            complete_image_metadata = {
                                # "text": image_metadata["caption"],
                                # "source": image_metadata["url"],
                                "source_type": "images",
                                **image_metadata,
                            }
                            # Extract text from the image using easyocr
                            extracted_text = self.reader.readtext(image_array, detail=0)
                            complete_image_metadata["extracted_text"] = extracted_text
                            processed_metadata.append(complete_image_metadata)
                        except Exception as e:
                            self.logger.error(f"Error during image processing: {str(e)}")
                            print(e)
                    # images_batches now holds the images as np.arrays
                    processing_end_time = timeit.default_timer()
                    self.logger.info(
                        f"Image processing took {processing_end_time - processing_start_time} seconds"
                    )
                embedding_start_time = timeit.default_timer()
                with torch.no_grad():  # don't track gradients during inference
                    # Preprocess the accumulated images for CLIP
                    batch = self.processor_clip(
                        text=None,
                        images=images_batches,
                        return_tensors="pt",
                        padding=True,
                    )["pixel_values"].to(self.model_clip.device)
                    # Preprocess the same images for GIT captioning
                    batch_git = self.processor_git(
                        images=images_batches,
                        return_tensors="pt",
                    )
                    git_pixel_values = batch_git.pixel_values.to(self.model_git.device)
                    # Generate image captions with GIT
                    generated_ids = self.model_git.generate(
                        pixel_values=git_pixel_values, max_length=35
                    )
                    generated_captions = self.processor_git.batch_decode(
                        generated_ids, skip_special_tokens=True
                    )
                    # Get CLIP image embeddings
                    batch_emb = self.model_clip.get_image_features(pixel_values=batch)
                # Detach the image embeddings from the graph, move to CPU, and convert to numpy
                self.logger.info(
                    f"Shape of batch_emb after get_image_features: {batch_emb.shape}"
                )
                # Check the shape of the tensor before squeezing
                if batch_emb.shape[0] > 1:
                    batch_emb = batch_emb.squeeze(0)
                    self.logger.info(
                        f"Shape of batch_emb after squeeze: {batch_emb.shape}"
                    )
                batch_emb = batch_emb.cpu().detach().numpy()
                # Normalize each embedding to unit length: dividing the transposed matrix
                # by the per-row norms broadcasts the norm across the feature dimension
                if batch_emb.ndim > 1:
                    batch_emb = batch_emb.T / np.linalg.norm(batch_emb, axis=1)
                    self.logger.info(
                        f"Shape of batch_emb after normalization (2D case): {batch_emb.shape}"
                    )
                    # Transpose back to (batch_size, 512)
                    batch_emb = batch_emb.T.tolist()
                embedding_end_time = timeit.default_timer()
                self.logger.info(
                    f"Embedding calculation took {embedding_end_time - embedding_start_time} seconds"
                )
                # Return the embeddings, metadata, and generated captions
                return {
                    "embeddings": batch_emb,
                    "metadata": processed_metadata,
                    "captions": generated_captions,
                }
            except Exception as e:
                print(f"Error during Images processing: {str(e)}")
                self.logger.error(f"Error during Images processing: {str(e)}")
                return {"embeddings": [], "error": str(e)}
        elif data["process_type"] == "text":
            if "query" not in data or not isinstance(data["query"], str):
                raise ValueError("Data must contain 'query' key which is a str.")
            query = data["query"]
            inputs = self.tokenizer_clip(query, return_tensors="pt").to(self.device)
            text_emb = self.model_clip.get_text_features(**inputs)
            # Detach text emb from graph, move to CPU, and convert to numpy array
            text_emb = text_emb.detach().cpu().numpy()
            # Calculate the norm of each vector and normalize by it (broadcast over rows)
            norm_factor = np.linalg.norm(text_emb, axis=1)
            text_emb = text_emb.T / norm_factor
            # Transpose back to (batch_size, 512)
            text_emb = text_emb.T
            # Convert the array to a list for the JSON response
            text_emb_list = text_emb.tolist()
            return {"embeddings": text_emb_list}
        else:
            error_message = (
                f"Error during CLIP endpoint processing: data['process_type'] "
                f"'{data['process_type']}' is neither 'images' nor 'text'"
            )
            print(error_message)
            self.logger.error(error_message)
            return {"embeddings": [], "error": error_message}
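

# Minimal local smoke test (a sketch: the image URL below is a placeholder, and
# running this will download the CLIP, GIT, and EasyOCR weights).
if __name__ == "__main__":
    handler = EndpointHandler()
    # Text query -> a single 512-dimensional CLIP embedding
    text_result = handler({"process_type": "text", "query": "a photo of a dog"})
    print(len(text_result["embeddings"][0]))
    # Image URLs -> embeddings, GIT captions, and OCR text per image
    image_result = handler(
        {
            "process_type": "images",
            "images_urls": ["https://example.com/sample.jpg"],
            "images_metadata": [{"caption": "sample image"}],
        }
    )
    print(image_result.keys())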