juanpablomesa committed
Commit 16f4f18 · 1 Parent(s): 5cd682d

Initial commit of simultaneous CLIP and GIT models

Files changed (2):
  1. handler.py +267 -0
  2. requirements.txt +24 -0
handler.py ADDED
@@ -0,0 +1,267 @@
+# handler.py
+import io
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+import requests
+import torch
+from PIL import Image
+from transformers import (
+    CLIPModel,
+    CLIPProcessor,
+    CLIPTokenizerFast,
+    pipeline,
+    AutoProcessor,
+    AutoModelForCausalLM,
+)
+from huggingface_hub import logging
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import timeit
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+# multi-model list
+multi_model_list = [
+    {"model_id": "openai/clip-vit-base-patch32", "task": "Custom"},
+    {"model_id": "microsoft/git-large-coco", "task": "Custom"},
+]
+
+
+class EndpointHandler:
+    def __init__(self, path=""):
+        clip_model_id = "openai/clip-vit-base-patch32"
+        # self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.processor_clip = CLIPProcessor.from_pretrained(clip_model_id)
+        self.model_clip = CLIPModel.from_pretrained(clip_model_id).to(self.device)
+        self.tokenizer_clip = CLIPTokenizerFast.from_pretrained(clip_model_id)
+        self.processor_git = AutoProcessor.from_pretrained("microsoft/git-large-coco")
+        self.model_git = AutoModelForCausalLM.from_pretrained(
+            "microsoft/git-large-coco"
+        )
+
+        self.model_git.to(device)
+        self.model_clip.to(device)
+
+        logging.set_verbosity_debug()
+        self.logger = logging.get_logger(__name__)
+
+    def download_image(self, url: str) -> bytes:
+        """
+        Download an image from a given URL.
+
+        Parameters:
+        - url: str
+            The URL from where the image needs to be downloaded.
+
+        Returns:
+        - bytes
+            The downloaded image data in bytes.
+
+        Raises:
+        - Exception: If the image download request fails.
+        """
+        response = requests.get(url)
+        if response.status_code == 200:
+            return response.content
+        else:
+            raise Exception(
+                f"Failed to download image from {url}. Status code: {response.status_code}"
+            )
+
+    def download_images_in_parallel(
+        self, urls: List[str], images_metadata_list: List[dict]
+    ) -> Tuple[List[bytes], List[dict]]:
+        """
+        Download multiple images in parallel and collect their metadata.
+
+        Parameters:
+        - urls: List[str]
+            A list of URLs from where the images need to be downloaded.
+        - images_metadata_list: List[dict]
+            A list of metadata corresponding to each image URL.
+
+        Returns:
+        - Tuple[List[bytes], List[dict]]
+            A tuple containing a list of downloaded image data in bytes and
+            a list of metadata for the successfully downloaded images.
+        """
+        with ThreadPoolExecutor() as executor:
+            # Start the download operations and mark each future with its URL and metadata
+            future_to_metadata = {
+                executor.submit(self.download_image, url): (url, metadata)
+                for url, metadata in zip(urls, images_metadata_list)
+            }
+
+            results = []
+            successful_metadata = []
+            for future in as_completed(future_to_metadata):
+                url, metadata = future_to_metadata[future]
+                try:
+                    data = future.result()
+                    results.append(data)
+                    metadata["url"] = url
+                    successful_metadata.append(metadata)
+                except Exception as exc:
+                    self.logger.error("%r generated an exception: %s" % (url, exc))
+        return results, successful_metadata
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Process the input data based on its type and return the embeddings.
+
+        This method accepts a dictionary with a 'process_type' key that can be either 'images' or 'text'.
+        If 'process_type' is 'images', the method expects a list of image URLs under the 'images_urls' key.
+        It downloads and processes these images, and returns their embeddings.
+        If 'process_type' is 'text', the method expects a string query under the 'query' key.
+        It processes this text and returns its embedding.
+
+        Parameters:
+        - data: Dict[str, Any]
+            A dictionary containing the data to be processed.
+            It must include a 'process_type' key with value either 'images' or 'text'.
+            If 'process_type' is 'images', data should also include an 'images_urls' key with a list of image URLs.
+            If 'process_type' is 'text', data should also include a 'query' key with a string query.
+
+        Returns:
+        - Dict[str, Any]
+            A dictionary containing the embeddings of the processed data.
+            If an error occurs during processing, the dictionary will include an 'error' key with the error message.
+
+        Raises:
+        - ValueError: If the 'process_type' key is not present in data, or if the required keys for 'images' or 'text' are not present or are of the wrong type.
+        """
+
+        if data["process_type"] == "images":
+            try:
+                # Check that the 'images_urls' key is in data and has the right type
+                if "images_urls" not in data or not isinstance(
+                    data["images_urls"], list
+                ):
+                    raise ValueError(
+                        "Data must contain 'images_urls' key with a list of image urls."
+                    )
+
+                batch_size = 100
+                if "batch_size" in data:
+                    batch_size = int(data["batch_size"])
+                # Download and process the images batch by batch
+                images_batches = []
+                processed_metadata = []
+                for i in range(0, len(data["images_urls"]), batch_size):
+                    # select a batch of image URLs and their metadata
+                    batches = data["images_urls"][i : i + batch_size]
+                    batches_metadata = data["images_metadata"][i : i + batch_size]
+
+                    download_start_time = timeit.default_timer()
+
+                    # Download images in parallel along with their metadata
+                    (
+                        downloaded_images,
+                        images_metadata,
+                    ) = self.download_images_in_parallel(batches, batches_metadata)
+
+                    download_end_time = timeit.default_timer()
+                    self.logger.info(
+                        f"Image downloading took {download_end_time - download_start_time} seconds"
+                    )
+                    processing_start_time = timeit.default_timer()
+
+                    for image_content, image_metadata in zip(
+                        downloaded_images, images_metadata
+                    ):
+                        try:
+                            image = Image.open(io.BytesIO(image_content)).convert("RGB")
+                            image_array = np.array(image)
+                            images_batches.append(image_array)
+                            complete_image_metadata = {
+                                "text": image_metadata["caption"],
+                                "source": image_metadata["url"],
+                                "source_type": "images",
+                                **image_metadata,
+                            }
+                            processed_metadata.append(complete_image_metadata)
+
+                        except Exception as e:
+                            print(e)
+                    # images_batches now holds the decoded images as np.arrays
+                    processing_end_time = timeit.default_timer()
+                    self.logger.info(
+                        f"Image processing took {processing_end_time - processing_start_time} seconds"
+                    )
+
+                embedding_start_time = timeit.default_timer()
+                with torch.no_grad():  # no gradients are tracked inside this block
+                    batch = self.processor_clip(
+                        text=None,
+                        images=images_batches,
+                        return_tensors="pt",
+                        padding=True,
+                    )["pixel_values"].to(self.model_clip.device)
+                    batch_git = self.processor_git(
+                        images=images_batches,
+                        return_tensors="pt",
+                    )
+                    git_pixel_values = batch_git.pixel_values.to(self.model_git.device)
+                    # generate image captions with GIT
+                    generated_ids = self.model_git.generate(
+                        pixel_values=git_pixel_values, max_length=50
+                    )
+
+                    generated_captions = self.processor_git.batch_decode(
+                        generated_ids, skip_special_tokens=True
+                    )
+
+                    # get image embeddings from CLIP
+                    batch_emb = self.model_clip.get_image_features(pixel_values=batch)
+                    # detach image embeddings from graph, move to CPU, and convert to numpy array
+                    batch_emb = batch_emb.cpu().detach().numpy()
+                    # normalize each embedding to unit length
+                    batch_emb = batch_emb.T / np.linalg.norm(batch_emb, axis=1)
+                    # transpose back to (num_images, 512) and convert to a list
+                    batch_emb = batch_emb.T.tolist()
+                embedding_end_time = timeit.default_timer()
+                self.logger.info(
+                    f"Embedding calculation took {embedding_end_time - embedding_start_time} seconds"
+                )
+
+                # Return the embeddings, metadata, and captions
+                return {
+                    "embeddings": batch_emb,
+                    "metadata": processed_metadata,
+                    "captions": generated_captions,
+                }
+
+            except Exception as e:
+                print(f"Error during images processing: {str(e)}")
+                return {"embeddings": [], "error": str(e)}
+
+        elif data["process_type"] == "text":
+            if "query" not in data or not isinstance(data["query"], str):
+                raise ValueError("Data must contain 'query' key which is a str.")
+            query = data["query"]
+            inputs = self.tokenizer_clip(query, return_tensors="pt").to(self.device)
+            text_emb = self.model_clip.get_text_features(**inputs)
+            # detach text emb from graph, move to CPU, and convert to numpy array
+            text_emb = text_emb.detach().cpu().numpy()
+
+            # calculate the value to normalize the vector by and normalize it
+            norm_factor = np.linalg.norm(text_emb, axis=1)
+
+            text_emb = text_emb.T / norm_factor
+            # transpose back to (1, 512)
+            text_emb = text_emb.T
+
+            # Convert the array to a list for the JSON response
+            text_emb_list = text_emb.tolist()
+
+            return {"embeddings": text_emb_list}
+
+        else:
+            error_message = (
+                f"data['process_type'] is {data['process_type']!r}, "
+                "but it must be either 'images' or 'text'"
+            )
+            print(f"Error during CLIP endpoint processing: {error_message}")
+            return {"embeddings": [], "error": error_message}
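For reference, a minimal local smoke test of this handler could look like the sketch below. It is not part of the commit: the file name local_test.py, the example image URL, the caption, and the query are all placeholder assumptions, and it presumes handler.py is importable from the working directory. On a Hugging Face Inference Endpoint the same payloads would instead arrive as the JSON request body.

# local_test.py -- minimal sketch, not part of this commit.
# Assumes handler.py is in the working directory; the URL, caption,
# and query below are placeholder values, not part of the original code.
from handler import EndpointHandler

handler = EndpointHandler()

# Image branch: returns CLIP embeddings, GIT captions, and per-image metadata.
image_payload = {
    "process_type": "images",
    "images_urls": ["http://images.cocodataset.org/val2017/000000039769.jpg"],
    "images_metadata": [{"caption": "two cats on a couch"}],
    "batch_size": 100,
}
image_result = handler(image_payload)
print(len(image_result["embeddings"]), image_result["captions"])

# Text branch: returns a normalized CLIP text embedding for the query.
text_result = handler({"process_type": "text", "query": "a photo of a cat"})
print(len(text_result["embeddings"][0]))  # 512-dimensional vector for ViT-B/32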
requirements.txt ADDED
@@ -0,0 +1,24 @@
+certifi==2023.7.22
+charset-normalizer==3.3.2
+colorama==0.4.6
+filelock==3.13.1
+fsspec==2023.10.0
+huggingface-hub==0.17.3
+idna==3.4
+Jinja2==3.1.2
+MarkupSafe==2.1.3
+mpmath==1.3.0
+networkx==3.2.1
+numpy==1.24.4
+packaging==23.2
+Pillow==10.1.0
+PyYAML==6.0.1
+regex==2023.10.3
+requests==2.31.0
+safetensors==0.4.0
+sympy==1.12
+tokenizers==0.13.3
+tqdm==4.66.1
+transformers==4.27.2
+typing_extensions==4.8.0
+urllib3==2.0.7