juanpablomesa committed
Commit e8faf30 · Parent(s): e30ecf9

Initial commit of the XCLIP video processing endpoint (video frames only)

Files changed (2):
  1. handler.py +245 -0
  2. requirements.txt +25 -0
handler.py ADDED
@@ -0,0 +1,245 @@
+ import tempfile
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from typing import Any, Dict, Optional, Tuple
+
+ import cv2
+ import numpy as np
+ import requests
+ import torch
+ from huggingface_hub import logging
+ from transformers import AutoTokenizer, XCLIPModel, XCLIPProcessor
+
+
+ class EndpointHandler:
+     # Intended checkpoint: microsoft/xclip-base-patch16-zero-shot
+     def __init__(self, path=""):
+         # Preload everything needed at inference time
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.processor = XCLIPProcessor.from_pretrained(path)
+         self.model = XCLIPModel.from_pretrained(path).to(self.device)
+         self.tokenizer = AutoTokenizer.from_pretrained(path)
+
+         logging.set_verbosity_debug()
+         self.logger = logging.get_logger(__name__)
+         # Check if CUDA (GPU support) is available
+         if torch.cuda.is_available():
+             self.logger.info("GPU is available for inference.")
+             self.logger.info(f"Using {torch.cuda.get_device_name(0)}")
+         else:
+             self.logger.info("GPU is not available, using CPU for inference.")
+
+     def download_video_as_bytes(self, url: str) -> Tuple[Optional[bytes], Optional[dict]]:
+         """
+         Download a video from a given URL, load it in RAM, and return it as bytes.
+
+         Parameters:
+         - url (str): The URL of the video to download.
+
+         Returns:
+         - bytes or None: The video content as bytes if successful, None otherwise.
+         - dict or None: The response headers if successful, None otherwise.
+         """
+         try:
+             response = requests.get(url)
+             response.raise_for_status()  # Raise an exception for HTTP errors
+             return response.content, response.headers
+         except requests.RequestException as e:
+             self.logger.error(f"Error downloading the video: {e}")
+             return None, None
+
+     def extract_evenly_spaced_frames_from_bytes(
+         self, video_bytes: bytes, num_frames: int = 32
+     ) -> list:
+         # Write bytes to a temporary file so OpenCV can open it by name
+         with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
+             temp_video.write(video_bytes)
+             temp_video.flush()
+
+             # Create a VideoCapture object using the temporary file's name
+             vidcap = cv2.VideoCapture(temp_video.name)
+
+             # Get the total number of frames in the video
+             total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+             # Calculate the sampling interval; keep it at least 1 so short
+             # videos still yield frames
+             interval = max(total_frames // num_frames, 1)
+
+             frames = []
+
+             for i in range(num_frames):
+                 # Seek to the next frame to be captured
+                 vidcap.set(cv2.CAP_PROP_POS_FRAMES, i * interval)
+
+                 # Read the frame (a BGR ndarray, as OpenCV returns it)
+                 success, image = vidcap.read()
+
+                 # If successfully read, add the frame to the list
+                 if success:
+                     frames.append(image)
+
+             # Release the capture before the temporary file is deleted
+             vidcap.release()
+
+         return frames
+
+     def preprocess_frames(self, video_frames):
+         """
+         Convert video frames into a format suitable for the model.
+         """
+         frames = np.array(video_frames)
+         # Use the XCLIP processor to preprocess the frames
+         inputs = self.processor(
+             text=None, videos=list(frames), return_tensors="pt", padding=True
+         ).to(self.device)
+
+         return inputs
+
+     def embed_frames_with_xclip_processing(self, frames):
+         frame_preprocessed = self.preprocess_frames(frames)
+
+         # Pass the preprocessed frames through the model to get the video embedding
+         with torch.no_grad():
+             frame_embedding = self.model.get_video_features(**frame_preprocessed)
+
+         # Move to CPU and convert to a numpy array of shape (batch, 512)
+         batch_emb = frame_embedding.cpu().numpy()
+
+         # L2-normalize each embedding, then transpose back to (batch, 512)
+         batch_emb = (batch_emb.T / np.linalg.norm(batch_emb, axis=1)).T
+
+         return batch_emb.tolist()
+
+     def process_video(self, video_url, video_metadata):
+         try:
+             video_bytes, video_headers = self.download_video_as_bytes(video_url)
+             frames = self.extract_evenly_spaced_frames_from_bytes(
+                 video_bytes, num_frames=32
+             )
+             frame_embeddings = self.embed_frames_with_xclip_processing(frames)
+             video_metadata["url"] = video_url
+             return frame_embeddings, video_metadata
+         except Exception as e:
+             self.logger.error(f"Error processing video {video_url}: {e}")
+             # Callers unpack two values, so the failure path must match
+             return None, None
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Process the input data based on its type and return the embeddings.
+
+         This method accepts a dictionary with a 'process_type' key that can be either 'videos' or 'text'.
+         If 'process_type' is 'videos', the method expects a list of video URLs under the 'videos_urls' key
+         and a parallel list of metadata dictionaries under the 'videos_metadata' key. It downloads and
+         processes these videos and returns their frame embeddings.
+         If 'process_type' is 'text', the method expects a string query under the 'query' key.
+         It processes this text and returns its embedding.
+
+         Parameters:
+         - data: Dict[str, Any]
+             A dictionary containing the data to be processed.
+             It must include a 'process_type' key with value either 'videos' or 'text'.
+             If 'process_type' is 'videos', data must also include a 'videos_urls' key with a list of
+             video URLs and a 'videos_metadata' key with a list of metadata dictionaries; an optional
+             'batch_size' key controls how many videos are processed in parallel (default 10).
+             If 'process_type' is 'text', data must also include a 'query' key with a string query.
+
+         Returns:
+         - Dict[str, Any]
+             A dictionary with the embeddings of the processed data (and, for videos, their metadata).
+             If an error occurs during processing, the dictionary will include an 'error' key with the
+             error message.
+
+         Raises:
+         - ValueError: If the 'process_type' key is not present in data, or if the required keys for
+           'videos' or 'text' are not present or are of the wrong type.
+         """
+
+         if data["process_type"] == "videos":
+             try:
+                 if "videos_urls" not in data or not isinstance(
+                     data["videos_urls"], list
+                 ):
+                     raise ValueError(
+                         "Data must contain 'videos_urls' key with a list of video urls."
+                     )
+
+                 batch_size = int(data.get("batch_size", 10))
+
+                 # Download and process the videos in parallel batches
+                 processed_video_embeddings = []
+                 processed_videos_metadata = []
+
+                 for i in range(0, len(data["videos_urls"]), batch_size):
+                     videos_urls = data["videos_urls"][i : i + batch_size]
+                     videos_metadata = data["videos_metadata"][i : i + batch_size]
+
+                     with ThreadPoolExecutor() as executor:
+                         futures = [
+                             executor.submit(self.process_video, url, metadata)
+                             for url, metadata in zip(videos_urls, videos_metadata)
+                         ]
+                         for future in as_completed(futures):
+                             frame_embeddings, video_metadata = future.result()
+                             if frame_embeddings is not None:
+                                 processed_video_embeddings.append(frame_embeddings)
+
+                                 processed_metadata = {
+                                     "text": video_metadata["caption"],
+                                     "source": video_metadata["url"],
+                                     "source_type": "video_frames",
+                                     **video_metadata,
+                                 }
+                                 processed_videos_metadata.append(processed_metadata)
+
+                 # Return the embeddings and their metadata
+                 return {
+                     "embeddings": processed_video_embeddings,
+                     "metadata": processed_videos_metadata,
+                 }
+
+             except Exception as e:
+                 self.logger.error(f"Error during videos processing: {e}")
+                 return {"embeddings": [], "error": str(e)}
+
+         elif data["process_type"] == "text":
+             if "query" not in data or not isinstance(data["query"], str):
+                 raise ValueError("Data must contain 'query' key which is a str.")
+             query = data["query"]
+             inputs = self.tokenizer(query, return_tensors="pt").to(self.device)
+             with torch.no_grad():
+                 text_emb = self.model.get_text_features(**inputs)
+             # Move to CPU and convert to a numpy array of shape (batch, 512)
+             text_emb = text_emb.cpu().numpy()
+
+             # L2-normalize each vector, then transpose back to (batch, 512)
+             text_emb = (text_emb.T / np.linalg.norm(text_emb, axis=1)).T
+
+             # Convert to a list for the JSON response
+             return {"embeddings": text_emb.tolist()}
+
+         else:
+             error_message = (
+                 f"data['process_type'] is '{data['process_type']}'; "
+                 "expected 'videos' or 'text'."
+             )
+             self.logger.error(f"Error during XCLIP endpoint processing: {error_message}")
+             return {"embeddings": [], "error": error_message}
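A minimal local smoke test for the handler before deploying it. This is a sketch: the `./model` directory, sample URL, and metadata are placeholders, and it assumes `torch` is already installed locally (requirements.txt does not pin it, presumably because the inference container ships its own build).

# Local smoke test (sketch); paths and URLs below are placeholders
from handler import EndpointHandler

# Assumes the XCLIP checkpoint has been downloaded to ./model
handler = EndpointHandler(path="./model")

# Text embedding: returns {"embeddings": [[...512 floats...]]}
text_out = handler({"process_type": "text", "query": "a dog catching a frisbee"})
print(len(text_out["embeddings"][0]))  # 512

# Video embeddings: one entry per successfully processed video
video_out = handler(
    {
        "process_type": "videos",
        "videos_urls": ["https://example.com/sample.mp4"],  # placeholder URL
        "videos_metadata": [{"caption": "a sample clip"}],
        "batch_size": 10,
    }
)
print(list(video_out.keys()))  # ['embeddings', 'metadata']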
requirements.txt ADDED
@@ -0,0 +1,25 @@
+ certifi==2023.7.22
+ charset-normalizer==3.3.2
+ colorama==0.4.6
+ filelock==3.13.1
+ fsspec==2023.10.0
+ huggingface-hub==0.17.3
+ idna==3.4
+ Jinja2==3.1.2
+ MarkupSafe==2.1.3
+ mpmath==1.3.0
+ networkx==3.2.1
+ numpy==1.24.4
+ opencv-python==4.8.1.78
+ packaging==23.2
+ Pillow==10.1.0
+ PyYAML==6.0.1
+ regex==2023.10.3
+ requests==2.31.0
+ safetensors==0.4.0
+ sympy==1.12
+ tokenizers==0.13.3
+ tqdm==4.66.1
+ transformers==4.27.2
+ typing_extensions==4.8.0
+ urllib3==2.0.7
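Once deployed, the same payload shape can be POSTed over HTTP. A sketch assuming a Hugging Face Inference Endpoint that passes the parsed JSON body straight to `__call__`; the endpoint URL and token are placeholders.

# HTTP call to the deployed endpoint (sketch); URL and token are placeholders
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HEADERS = {
    "Authorization": "Bearer <hf_token>",  # placeholder token
    "Content-Type": "application/json",
}

payload = {
    "process_type": "videos",
    "videos_urls": ["https://example.com/sample.mp4"],  # placeholder URL
    "videos_metadata": [{"caption": "a sample clip"}],
    "batch_size": 10,
}

response = requests.post(ENDPOINT_URL, headers=HEADERS, json=payload)
response.raise_for_status()
print(response.json().get("metadata"))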