juanpablomesa committed
Commit · e8faf30 · Parent(s): e30ecf9

Initial commit of XCLIP video processing endpoint (video frames only)

Files changed:
- handler.py (+245, -0)
- requirements.txt (+25, -0)
handler.py
ADDED
@@ -0,0 +1,245 @@
import tempfile
from typing import Any, Dict, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed

import cv2
import numpy as np
import requests
import torch
from transformers import AutoTokenizer, XCLIPModel, XCLIPProcessor
from huggingface_hub import logging

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    def __init__(self, path=""):
        # Preload everything needed at inference time. The model is loaded from
        # `path`; the intended checkpoint is "microsoft/xclip-base-patch16-zero-shot".
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = XCLIPProcessor.from_pretrained(path)
        self.model = XCLIPModel.from_pretrained(path).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        logging.set_verbosity_debug()
        self.logger = logging.get_logger(__name__)
        # Check if CUDA (GPU support) is available
        if torch.cuda.is_available():
            self.logger.info("GPU is available for inference.")
            self.logger.info(f"Using {torch.cuda.get_device_name(0)}")
        else:
            self.logger.info("GPU is not available, using CPU for inference.")

    def download_video_as_bytes(self, url: str) -> Tuple[Optional[bytes], Optional[dict]]:
        """
        Download a video from a given URL, load it in RAM, and return it as bytes.

        Parameters:
        - url (str): The URL of the video to download.

        Returns:
        - bytes or None: The video content as bytes if successful, None otherwise.
        - dict or None: The video download headers if successful, None otherwise.
        """
        try:
            response = requests.get(url)
            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.content, response.headers
        except requests.RequestException as e:
            print(f"Error downloading the video: {e}")
            return None, None

    def extract_evenly_spaced_frames_from_bytes(
        self, video_bytes: bytes, num_frames: int = 32
    ) -> list:
        # Write the bytes to a temporary file so OpenCV can open it by name
        with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
            temp_video.write(video_bytes)
            temp_video.flush()

            # Create a VideoCapture object using the temporary file's name
            vidcap = cv2.VideoCapture(temp_video.name)

            # Get the total number of frames in the video
            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Calculate the interval at which frames should be extracted,
            # guarding against videos with fewer frames than num_frames
            interval = max(total_frames // num_frames, 1)

            frames = []

            for i in range(num_frames):
                # Set the video position to the next frame to be captured
                vidcap.set(cv2.CAP_PROP_POS_FRAMES, i * interval)

                # Read the frame
                success, image = vidcap.read()

                # If successfully read, add the frame to the list
                if success:
                    frames.append(image)

            vidcap.release()
            return frames

    def preprocess_frames(self, video_frames):
        """
        Convert video frames into a format suitable for the model.
        """
        frames = np.array(video_frames)
        # Use the XCLIP processor to preprocess the frames
        inputs = self.processor(
            text=None, videos=list(frames), return_tensors="pt", padding=True
        ).to(self.device)

        return inputs

    def embed_frames_with_xclip_processing(self, frames):
        # Initialize an empty list to store the frame embeddings
        frame_embeddings = []

        frame_preprocessed = self.preprocess_frames(frames)

        # Pass the preprocessed frames through the model to get the video embedding
        frame_embedding = self.model.get_video_features(**frame_preprocessed)

        # Add the embedding to the list
        frame_embeddings.append(frame_embedding)

        # Stack the list of embeddings into a single tensor
        tensor = torch.stack(frame_embeddings)

        # Detach the video embedding from the graph, move it to the CPU,
        # and convert it to a numpy array
        batch_emb = tensor.squeeze(0)
        batch_emb = batch_emb.cpu().detach().numpy()

        # L2-normalize each embedding vector ...
        batch_emb = batch_emb.T / np.linalg.norm(batch_emb, axis=1)
        # ... and transpose back to (batch, 512)
        batch_emb = batch_emb.T

        batch_emb = batch_emb.tolist()

        return batch_emb

    def process_video(self, video_url, video_metadata):
        try:
            video_bytes, video_headers = self.download_video_as_bytes(video_url)
            frames = self.extract_evenly_spaced_frames_from_bytes(
                video_bytes, num_frames=32
            )
            frame_embeddings = self.embed_frames_with_xclip_processing(frames)
            video_metadata["url"] = video_url
            return frame_embeddings, video_metadata
        except Exception as e:
            print(e)
            return None, None

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data based on its type and return the embeddings.

        This method accepts a dictionary with a 'process_type' key that can be either 'videos' or 'text'.
        If 'process_type' is 'videos', the method expects a list of video URLs under the 'videos_urls' key
        and a parallel list of metadata dicts under the 'videos_metadata' key.
        It downloads and processes these videos and returns their embeddings.
        If 'process_type' is 'text', the method expects a string query under the 'query' key.
        It processes this text and returns its embedding.

        Parameters:
        - data: Dict[str, Any]
            A dictionary containing the data to be processed.
            It must include a 'process_type' key with value either 'videos' or 'text'.
            If 'process_type' is 'videos', data must also include a 'videos_urls' key with a list of video URLs.
            If 'process_type' is 'text', data must also include a 'query' key with a string query.

        Returns:
        - Dict[str, Any]
            A dictionary containing the embeddings of the processed data
            (and, for videos, the processed metadata).
            If an error occurs during processing, the dictionary will include an 'error' key with the error message.

        Raises:
        - ValueError: If the required keys for 'videos' or 'text' are not present or are of the wrong type.
        """

        if data["process_type"] == "videos":
            try:
                if "videos_urls" not in data or not isinstance(
                    data["videos_urls"], list
                ):
                    raise ValueError(
                        "Data must contain 'videos_urls' key with a list of video URLs."
                    )

                batch_size = 10
                if "batch_size" in data:
                    batch_size = int(data["batch_size"])
                # Download and process the videos in batches
                processed_video_embeddings = []
                processed_videos_metadata = []

                for i in range(0, len(data["videos_urls"]), batch_size):
                    videos_urls = data["videos_urls"][i : i + batch_size]
                    videos_metadata = data["videos_metadata"][i : i + batch_size]

                    with ThreadPoolExecutor() as executor:
                        futures = [
                            executor.submit(self.process_video, url, metadata)
                            for url, metadata in zip(videos_urls, videos_metadata)
                        ]
                        for future in as_completed(futures):
                            frame_embeddings, video_metadata = future.result()
                            if frame_embeddings is not None:
                                processed_video_embeddings.append(frame_embeddings)

                                processed_metadata = {
                                    "text": video_metadata["caption"],
                                    "source": video_metadata["url"],
                                    "source_type": "video_frames",
                                    **video_metadata,
                                }
                                processed_videos_metadata.append(processed_metadata)
                # Return the embeddings
                return {
                    "embeddings": processed_video_embeddings,
                    "metadata": processed_videos_metadata,
                }

            except Exception as e:
                print(f"Error during videos processing: {str(e)}")
                return {"embeddings": [], "error": str(e)}

        elif data["process_type"] == "text":
            if "query" not in data or not isinstance(data["query"], str):
                raise ValueError("Data must contain 'query' key which is a str.")
            query = data["query"]
            inputs = self.tokenizer(query, return_tensors="pt").to(self.device)
            text_emb = self.model.get_text_features(**inputs)
            # Detach the text embedding from the graph, move it to the CPU,
            # and convert it to a numpy array
            text_emb = text_emb.detach().cpu().numpy()

            # Calculate the value to normalize each vector by, and normalize
            norm_factor = np.linalg.norm(text_emb, axis=1)

            text_emb = text_emb.T / norm_factor
            # Transpose back to (batch, 512)
            text_emb = text_emb.T

            # Convert the array to a list for the JSON response
            text_emb_list = text_emb.tolist()

            return {"embeddings": text_emb_list}

        else:
            error_msg = (
                "Error during XCLIP endpoint processing: data['process_type'] "
                f"is '{data['process_type']}', which is neither 'videos' nor 'text'"
            )
            print(error_msg)
            return {"embeddings": [], "error": error_msg}
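
For local smoke-testing, the handler can be exercised directly; a minimal sketch follows. The checkpoint id comes from the handler itself, but the video URL, caption, and batch size are placeholder assumptions, and the video call only succeeds if the URL actually serves a downloadable video:

# Minimal local smoke test (sketch). Assumes handler.py is on the import path
# and the checkpoint can be pulled from the Hub; the URL and caption below are
# placeholders, not values from this repository.
from handler import EndpointHandler

handler = EndpointHandler(path="microsoft/xclip-base-patch16-zero-shot")

# Video embeddings: one metadata dict per URL; each dict needs a 'caption'
# key because the handler copies it into the returned metadata.
video_result = handler(
    {
        "process_type": "videos",
        "videos_urls": ["https://example.com/sample.mp4"],  # placeholder URL
        "videos_metadata": [{"caption": "a person playing guitar"}],
        "batch_size": 10,
    }
)
print(len(video_result["embeddings"]), "video embedding(s)")

# Text embedding for retrieval against the video embeddings.
text_result = handler({"process_type": "text", "query": "a person playing guitar"})
print(len(text_result["embeddings"][0]), "dimensions")  # 512 for this checkpoint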
requirements.txt
ADDED
@@ -0,0 +1,25 @@
certifi==2023.7.22
charset-normalizer==3.3.2
colorama==0.4.6
filelock==3.13.1
fsspec==2023.10.0
huggingface-hub==0.17.3
idna==3.4
Jinja2==3.1.2
MarkupSafe==2.1.3
mpmath==1.3.0
networkx==3.2.1
numpy==1.24.4
opencv-python==4.8.1.78
packaging==23.2
Pillow==10.1.0
PyYAML==6.0.1
regex==2023.10.3
requests==2.31.0
safetensors==0.4.0
sympy==1.12
tokenizers==0.13.3
tqdm==4.66.1
transformers==4.27.2
typing_extensions==4.8.0
urllib3==2.0.7