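"""Custom inference endpoint handler: embeds images and text with CLIP,
captions images with GIT (microsoft/git-large-coco), and extracts visible
text with EasyOCR."""
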
import io
import timeit
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Tuple

import easyocr
import numpy as np
import requests
import torch
from huggingface_hub import logging
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    CLIPModel,
    CLIPProcessor,
    CLIPTokenizerFast,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


multi_model_list = [
    {"model_id": "openai/clip-vit-base-patch32", "task": "Custom"},
    {"model_id": "microsoft/git-large-coco", "task": "Custom"},
]


class EndpointHandler:
    def __init__(self, path=""):
        clip_model_id = "openai/clip-vit-base-patch32"
        git_model_id = "microsoft/git-large-coco"

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.processor_clip = CLIPProcessor.from_pretrained(clip_model_id)
        self.model_clip = CLIPModel.from_pretrained(clip_model_id).to(self.device)
        self.tokenizer_clip = CLIPTokenizerFast.from_pretrained(clip_model_id)

        self.processor_git = AutoProcessor.from_pretrained(git_model_id)
        self.model_git = AutoModelForCausalLM.from_pretrained(git_model_id).to(
            self.device
        )

        logging.set_verbosity_debug()
        self.logger = logging.get_logger(__name__)

        # OCR reader for German and English text found in images.
        self.reader = easyocr.Reader(["de", "en"])

    def download_image(self, url: str) -> bytes:
        """
        Download an image from a given URL.

        Parameters:
        - url: str
            The URL to download the image from.

        Returns:
        - bytes
            The downloaded image data.

        Raises:
        - Exception: If the image download request fails.
        """
        # A timeout keeps a single slow host from stalling the whole batch
        # (the original call had none; 30 s is an assumed value).
        response = requests.get(url, timeout=30)
        if response.status_code == 200:
            return response.content
        self.logger.error(f"Error downloading image from {url}")
        raise Exception(
            f"Failed to download image from {url}. Status code: {response.status_code}"
        )

    def download_images_in_parallel(
        self, urls: List[str], images_metadata_list: List[dict]
    ) -> Tuple[List[bytes], List[dict]]:
        """
        Download multiple images in parallel and collect their metadata.

        Parameters:
        - urls: List[str]
            A list of URLs to download the images from.
        - images_metadata_list: List[dict]
            A list of metadata dicts, one per image URL.

        Returns:
        - Tuple[List[bytes], List[dict]]
            A tuple of the downloaded image data (bytes) and the metadata of
            the images that were downloaded successfully.
        """
        with ThreadPoolExecutor() as executor:
            # Map each future back to its URL and metadata so failures can be
            # logged and successes keep their pairing.
            future_to_metadata = {
                executor.submit(self.download_image, url): (url, metadata)
                for url, metadata in zip(urls, images_metadata_list)
            }

            results = []
            successful_metadata = []
            for future in as_completed(future_to_metadata):
                url, metadata = future_to_metadata[future]
                try:
                    data = future.result()
                    results.append(data)
                    metadata["url"] = url
                    successful_metadata.append(metadata)
                except Exception as exc:
                    self.logger.error("%r generated an exception: %s", url, exc)
        return results, successful_metadata

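    # Illustrative sketch (not executed here) of how the download helpers are
    # used together; the URLs and metadata below are placeholders.
    #
    #   images_bytes, ok_metadata = handler.download_images_in_parallel(
    #       ["https://example.com/a.jpg", "https://example.com/b.jpg"],
    #       [{"id": 1}, {"id": 2}],
    #   )
    #   images_bytes  -> raw bytes of each successfully downloaded image
    #   ok_metadata   -> the matching input dicts, each with a "url" key added
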
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data based on its type and return the embeddings.

        This method expects a dictionary with a 'process_type' key that is
        either 'images' or 'text'. If 'process_type' is 'images', it expects a
        list of image URLs under 'images_urls' and a matching list of dicts
        under 'images_metadata'; the images are downloaded, OCR'd, captioned,
        and embedded, with an optional 'batch_size' key (default 50)
        controlling how many images are downloaded per batch. If
        'process_type' is 'text', it expects a string under 'query', which is
        embedded directly.

        Parameters:
        - data: Dict[str, Any]
            A dictionary containing the data to be processed. It must include
            a 'process_type' key with the value 'images' or 'text'. For
            'images', it must also include 'images_urls' (a list of URLs) and
            'images_metadata' (a list of dicts); for 'text', a 'query' string.

        Returns:
        - Dict[str, Any]
            A dictionary with the embeddings of the processed data, plus
            'metadata' and 'captions' for images. If an error occurs during
            processing, the dictionary includes an 'error' key with the error
            message instead.

        Raises:
        - KeyError: If the 'process_type' key is not present in data.
        - ValueError: If the keys required for 'images' or 'text' are not
          present or are of the wrong type.
        """
        if data["process_type"] == "images":
            try:
                if "images_urls" not in data or not isinstance(
                    data["images_urls"], list
                ):
                    raise ValueError(
                        "Data must contain 'images_urls' key with a list of image URLs."
                    )
                if "images_metadata" not in data or not isinstance(
                    data["images_metadata"], list
                ):
                    raise ValueError(
                        "Data must contain 'images_metadata' key with a list of metadata dicts."
                    )

                # Default batch size of 50; callers can override it per request.
                batch_size = int(data.get("batch_size", 50))

                images_batches = []
                processed_metadata = []
                for i in range(0, len(data["images_urls"]), batch_size):
                    batch_urls = data["images_urls"][i : i + batch_size]
                    batch_metadata = data["images_metadata"][i : i + batch_size]

                    download_start_time = timeit.default_timer()
                    (
                        downloaded_images,
                        images_metadata,
                    ) = self.download_images_in_parallel(batch_urls, batch_metadata)
                    download_end_time = timeit.default_timer()
                    self.logger.info(
                        f"Image downloading took {download_end_time - download_start_time} seconds"
                    )

                    processing_start_time = timeit.default_timer()
                    for image_content, image_metadata in zip(
                        downloaded_images, images_metadata
                    ):
                        try:
                            image = Image.open(io.BytesIO(image_content)).convert(
                                "RGB"
                            )
                            image_array = np.array(image)
                            images_batches.append(image_array)

                            complete_image_metadata = {
                                "source_type": "images",
                                **image_metadata,
                            }

                            # Attach any text EasyOCR finds in the image.
                            extracted_text = self.reader.readtext(
                                image_array, detail=0
                            )
                            complete_image_metadata["extracted_text"] = extracted_text

                            processed_metadata.append(complete_image_metadata)
                        except Exception as e:
                            self.logger.error(f"Error processing image: {str(e)}")

                    processing_end_time = timeit.default_timer()
                    self.logger.info(
                        f"Image processing took {processing_end_time - processing_start_time} seconds"
                    )

                embedding_start_time = timeit.default_timer()
                with torch.no_grad():
                    # CLIP pixel values for every successfully processed image.
                    batch = self.processor_clip(
                        text=None,
                        images=images_batches,
                        return_tensors="pt",
                        padding=True,
                    )["pixel_values"].to(self.model_clip.device)

                    # GIT captions for the same images.
                    batch_git = self.processor_git(
                        images=images_batches,
                        return_tensors="pt",
                    )
                    git_pixel_values = batch_git.pixel_values.to(self.model_git.device)
                    generated_ids = self.model_git.generate(
                        pixel_values=git_pixel_values, max_length=35
                    )
                    generated_captions = self.processor_git.batch_decode(
                        generated_ids, skip_special_tokens=True
                    )

                    batch_emb = self.model_clip.get_image_features(pixel_values=batch)
                    self.logger.info(
                        f"Shape of batch_emb after get_image_features: {batch_emb.shape}"
                    )

                batch_emb = batch_emb.cpu().detach().numpy()
                # L2-normalise each embedding row. keepdims=True keeps the
                # shapes aligned and, unlike the transpose trick it replaces,
                # also covers the single-image (1, dim) case.
                batch_emb = batch_emb / np.linalg.norm(batch_emb, axis=1, keepdims=True)
                batch_emb = batch_emb.tolist()

                embedding_end_time = timeit.default_timer()
                self.logger.info(
                    f"Embedding calculation took {embedding_end_time - embedding_start_time} seconds"
                )

                return {
                    "embeddings": batch_emb,
                    "metadata": processed_metadata,
                    "captions": generated_captions,
                }

            except Exception as e:
                self.logger.error(f"Error during image processing: {str(e)}")
                return {"embeddings": [], "error": str(e)}

        elif data["process_type"] == "text":
            if "query" not in data or not isinstance(data["query"], str):
                raise ValueError("Data must contain 'query' key which is a str.")
            query = data["query"]

            inputs = self.tokenizer_clip(query, return_tensors="pt").to(self.device)
            with torch.no_grad():
                text_emb = self.model_clip.get_text_features(**inputs)
            text_emb = text_emb.detach().cpu().numpy()

            # L2-normalise the query embedding so it is directly comparable to
            # the normalised image embeddings.
            text_emb = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)

            return {"embeddings": text_emb.tolist()}

        else:
            error_message = (
                "Error during CLIP endpoint processing: data['process_type'] "
                f"{data['process_type']!r} is neither 'images' nor 'text'"
            )
            self.logger.error(error_message)
            return {"embeddings": [], "error": error_message}
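
# Minimal manual smoke test: a sketch, not part of the endpoint contract.
# Instantiating the handler downloads the CLIP and GIT weights, and the image
# URL below is a placeholder that must be replaced with a reachable image.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Text path: one L2-normalised CLIP ViT-B/32 vector (512 dimensions).
    text_result = handler({"process_type": "text", "query": "a red bicycle"})
    print(len(text_result["embeddings"][0]))

    # Image path: downloads, OCRs, captions, and embeds the images.
    image_result = handler(
        {
            "process_type": "images",
            "images_urls": ["https://example.com/image.jpg"],  # placeholder URL
            "images_metadata": [{"id": "demo-1"}],
            "batch_size": 50,
        }
    )
    print(image_result.get("captions"))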
|