# handler.py
import io
from typing import Any, Dict, List, Tuple

import numpy as np
import requests
import torch
from PIL import Image
from transformers import (
    CLIPModel,
    CLIPProcessor,
    CLIPTokenizerFast,
    pipeline,
    AutoProcessor,
    AutoModelForCausalLM,
)
from huggingface_hub import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import timeit
import easyocr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# multi-model list
multi_model_list = [
    {"model_id": "openai/clip-vit-base-patch32", "task": "Custom"},
    {"model_id": "microsoft/git-large-coco", "task": "Custom"},
]


class EndpointHandler:
    def __init__(self, path=""):
        clip_model_id = "openai/clip-vit-base-patch32"
        # self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.processor_clip = CLIPProcessor.from_pretrained(clip_model_id)
        self.model_clip = CLIPModel.from_pretrained(clip_model_id).to(self.device)
        self.tokenizer_clip = CLIPTokenizerFast.from_pretrained(clip_model_id)

        self.processor_git = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        self.model_git = AutoModelForCausalLM.from_pretrained(
            "microsoft/git-large-coco"
        )
        self.model_git.to(device)
        self.model_clip.to(device)

        logging.set_verbosity_debug()
        self.logger = logging.get_logger(__name__)

        self.reader = easyocr.Reader(["de", "en"])

    def download_image(self, url: str) -> bytes:
        """
        Download an image from a given URL.

        Parameters:
        - url: str
            The URL from where the image needs to be downloaded.

        Returns:
        - bytes
            The downloaded image data in bytes.

        Raises:
        - Exception: If the image download request fails.
        """
        response = requests.get(url)
        if response.status_code == 200:
            return response.content
        else:
            self.logger.error(f"Error downloading image from: {url}")
            raise Exception(
                f"Failed to download image from {url}. "
                f"Status code: {response.status_code}"
            )

    def download_images_in_parallel(
        self, urls: List[str], images_metadata_list: List[dict]
    ) -> Tuple[List[bytes], List[dict]]:
        """
        Download multiple images in parallel and collect their metadata.

        Parameters:
        - urls: List[str]
            A list of URLs from where the images need to be downloaded.
        - images_metadata_list: List[dict]
            A list of metadata corresponding to each image URL.

        Returns:
        - Tuple[List[bytes], List[dict]]
            A tuple containing a list of downloaded image data in bytes and a
            list of metadata for the successfully downloaded images.
        """
        with ThreadPoolExecutor() as executor:
            # Start the load operations and mark each future with its URL and metadata
            future_to_metadata = {
                executor.submit(self.download_image, url): (url, metadata)
                for url, metadata in zip(urls, images_metadata_list)
            }
            results = []
            successful_metadata = []
            for future in as_completed(future_to_metadata):
                url, metadata = future_to_metadata[future]
                try:
                    data = future.result()
                    results.append(data)
                    metadata["url"] = url
                    successful_metadata.append(metadata)
                except Exception as exc:
                    self.logger.error("%r generated an exception: %s" % (url, exc))
        return results, successful_metadata

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data based on its type and return the embeddings.

        This method accepts a dictionary with a 'process_type' key that can be
        either 'images' or 'text'. If 'process_type' is 'images', the method
        expects a list of image URLs under the 'images_urls' key; it downloads
        and processes these images and returns their embeddings. If
        'process_type' is 'text', the method expects a string query under the
        'query' key.
        It processes this text and returns its embedding.

        Parameters:
        - data: Dict[str, Any]
            A dictionary containing the data to be processed. It must include a
            'process_type' key with value either 'images' or 'text'. If
            'process_type' is 'images', data should also include an
            'images_urls' key with a list of image URLs. If 'process_type' is
            'text', data should also include a 'query' key with a string query.

        Returns:
        - Dict[str, Any]
            A dictionary containing the embeddings of the processed data. If an
            error occurs during processing, it will include an 'error' key with
            the error message.

        Raises:
        - ValueError: If the 'process_type' key is not present in data, or if
            the required keys for 'images' or 'text' are not present or are of
            the wrong type.
        """
        if data["process_type"] == "images":
            try:
                # Check that 'images_urls' is in data and has the right type
                if "images_urls" not in data or not isinstance(
                    data["images_urls"], list
                ):
                    raise ValueError(
                        "Data must contain 'images_urls' key with a list of image urls."
                    )

                batch_size = 50
                if "batch_size" in data:
                    batch_size = int(data["batch_size"])

                # Download and process the images batch by batch
                images_batches = []
                processed_metadata = []
                for i in range(0, len(data["images_urls"]), batch_size):
                    # select a batch of images and their metadata
                    batches = data["images_urls"][i : i + batch_size]
                    batches_metadata = data["images_metadata"][i : i + batch_size]

                    download_start_time = timeit.default_timer()
                    # Download images in parallel along with their metadata
                    (
                        downloaded_images,
                        images_metadata,
                    ) = self.download_images_in_parallel(batches, batches_metadata)
                    download_end_time = timeit.default_timer()
                    self.logger.info(
                        f"Image downloading took {download_end_time - download_start_time} seconds"
                    )

                    processing_start_time = timeit.default_timer()
                    for image_content, image_metadata in zip(
                        downloaded_images, images_metadata
                    ):
                        try:
                            image = Image.open(io.BytesIO(image_content)).convert(
                                "RGB"
                            )
                            image_array = np.array(image)
                            images_batches.append(image_array)
                            complete_image_metadata = {
                                # "text": image_metadata["caption"],
                                # "source": image_metadata["url"],
                                "source_type": "images",
                                **image_metadata,
                            }

                            # Extract text from the image using easyocr
                            extracted_text = self.reader.readtext(
                                image_array, detail=0
                            )
                            complete_image_metadata["extracted_text"] = extracted_text

                            processed_metadata.append(complete_image_metadata)
                        except Exception as e:
                            self.logger.error(
                                f"Error during image processing: {str(e)}"
                            )

                    processing_end_time = timeit.default_timer()
                    self.logger.info(
                        f"Image processing took {processing_end_time - processing_start_time} seconds"
                    )

                # images_batches is now a list of images as np.arrays
                embedding_start_time = timeit.default_timer()
                with torch.no_grad():  # don't track gradients during inference
                    batch = self.processor_clip(
                        text=None,
                        images=images_batches,
                        return_tensors="pt",
                        padding=True,
                    )["pixel_values"].to(self.model_clip.device)

                    batch_git = self.processor_git(
                        images=images_batches,
                        return_tensors="pt",
                    )
                    git_pixel_values = batch_git.pixel_values.to(self.model_git.device)

                    # get image captions
                    generated_ids = self.model_git.generate(
                        pixel_values=git_pixel_values, max_length=35
                    )
                    generated_captions = self.processor_git.batch_decode(
                        generated_ids, skip_special_tokens=True
                    )

                    # get image embeddings, then detach them from the graph,
                    # move to CPU, and convert to a numpy array
                    batch_emb = self.model_clip.get_image_features(pixel_values=batch)
                    self.logger.info(
                        f"Shape of batch_emb after get_image_features: {batch_emb.shape}"
                    )
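                    # Note (assumption): for openai/clip-vit-base-patch32,
                    # get_image_features returns a (num_images, 512) tensor, so
                    # the 2D normalization branch below is the path that applies.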
                    # Check the shape of the tensor before squeezing
                    if batch_emb.shape[0] > 1:
                        batch_emb = batch_emb.squeeze(0)
                        self.logger.info(
                            f"Shape of batch_emb after squeeze: {batch_emb.shape}"
                        )

                    batch_emb = batch_emb.cpu().detach().numpy()
                    # NORMALIZE
                    if batch_emb.ndim > 1:
                        batch_emb = batch_emb.T / np.linalg.norm(batch_emb, axis=1)
                        self.logger.info(
                            f"Shape of batch_emb after normalization (2D case): {batch_emb.shape}"
                        )
                    # transpose back to (batch_size, 512)
                    batch_emb = batch_emb.T.tolist()

                embedding_end_time = timeit.default_timer()
                self.logger.info(
                    f"Embedding calculation took {embedding_end_time - embedding_start_time} seconds"
                )

                # Return the embeddings
                return {
                    "embeddings": batch_emb,
                    "metadata": processed_metadata,
                    "captions": generated_captions,
                }
            except Exception as e:
                self.logger.error(f"Error during images processing: {str(e)}")
                return {"embeddings": [], "error": str(e)}

        elif data["process_type"] == "text":
            if "query" not in data or not isinstance(data["query"], str):
                raise ValueError("Data must contain 'query' key which is a str.")
            query = data["query"]
            inputs = self.tokenizer_clip(query, return_tensors="pt").to(self.device)
            with torch.no_grad():
                text_emb = self.model_clip.get_text_features(**inputs)
            # detach text emb from graph, move to CPU, and convert to numpy array
            text_emb = text_emb.detach().cpu().numpy()

            # calculate the value to normalize each vector by, then normalize
            norm_factor = np.linalg.norm(text_emb, axis=1)
            text_emb = text_emb.T / norm_factor
            # transpose back to (batch_size, 512)
            text_emb = text_emb.T

            # Convert the array to a list for the JSON response
            text_emb_list = text_emb.tolist()
            return {"embeddings": text_emb_list}

        else:
            error_message = (
                f"Error during CLIP endpoint processing: data['process_type']: "
                f"{data['process_type']} is neither 'images' nor 'text'"
            )
            self.logger.error(error_message)
            return {"embeddings": [], "error": error_message}
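

# A minimal local usage sketch (not part of the endpoint contract). It assumes the
# handler is called the way Hugging Face Inference Endpoints invokes it, i.e. with a
# dict payload; the URL and metadata below are placeholders, not real inputs.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Text query: returns {"embeddings": [[...512 floats...]]}
    text_result = handler({"process_type": "text", "query": "a photo of a cat"})
    print(len(text_result["embeddings"][0]))

    # Image batch: 'images_metadata' must be provided alongside 'images_urls'.
    # With a placeholder URL the download fails and an 'error' key is returned.
    image_result = handler(
        {
            "process_type": "images",
            "images_urls": ["https://example.com/cat.jpg"],  # placeholder URL
            "images_metadata": [{"caption": "a cat"}],
            "batch_size": 50,
        }
    )
    print(image_result.get("error"), len(image_result["embeddings"]))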