|
import io
import tempfile
import timeit
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple

import cv2
import numpy as np
import requests
import torch
from decord import VideoReader
from decord import cpu
from huggingface_hub import logging
from PIL import Image
from transformers import AutoTokenizer, XCLIPModel, XCLIPProcessor
|
|
|
# Module-level default torch device: the first CUDA device when one is
# visible, otherwise the CPU. (EndpointHandler selects its own device too.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
class EndpointHandler:
    """Inference endpoint that embeds videos and text queries with X-CLIP.

    Video requests are downloaded, sampled to a fixed number of frames and
    encoded with ``XCLIPModel.get_video_features``; text requests are encoded
    with ``XCLIPModel.get_text_features``. Both kinds of embedding are
    L2-normalised and returned as nested Python lists.
    """

    def __init__(self, path=""):
        """Load the X-CLIP processor, model and tokenizer and set up logging.

        Parameters:
        - path (str): Local directory or Hub id of the checkpoint. When empty,
          the checkpoint named by ``model_id`` below is loaded instead.
        """
        model_id = "microsoft/xclip-base-patch16-zero-shot"
        # Use the default checkpoint when the caller does not supply a path.
        source = path or model_id

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = XCLIPProcessor.from_pretrained(source)
        # eval() makes the inference-only usage explicit (disables dropout).
        self.model = XCLIPModel.from_pretrained(source).to(self.device).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(source)

        logging.set_verbosity_debug()
        self.logger = logging.get_logger(__name__)

        if torch.cuda.is_available():
            self.logger.info("GPU is available for inference.")
            self.logger.info(f"Using {torch.cuda.get_device_name(0)}")
        else:
            self.logger.info("GPU is not available, using CPU for inference.")

    def download_video_as_bytes(
        self, url: str, timeout: float = 60.0
    ) -> Tuple[Optional[bytes], Any]:
        """Download a video from a given URL into memory.

        Parameters:
        - url (str): The URL of the video to download.
        - timeout (float): Seconds ``requests`` waits when connecting and
          between received bytes before giving up; without a timeout a stalled
          server would block the worker thread forever.

        Returns:
        - bytes or None: The video content if successful, None otherwise.
        - headers or None: The response headers (a mapping) if successful,
          None otherwise.
        """
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            self.logger.error(f"Error downloading the video: {e}")
            return None, None
        return response.content, response.headers

    @staticmethod
    def _frame_indices(total_frames: int, num_frames: int) -> List[int]:
        """Return ``num_frames`` frame indices spread over ``total_frames`` frames.

        When the video has at least ``num_frames`` frames the indices are
        ``0, k, 2k, ...`` with ``k = total_frames // num_frames`` (this keeps
        the established sampling so embeddings stay comparable). Shorter videos
        would give ``k == 0`` (every index 0), so instead each available frame
        is repeated evenly across the clip. Returns ``[]`` when either count is
        not positive.
        """
        if total_frames <= 0 or num_frames <= 0:
            return []
        interval = total_frames // num_frames
        if interval == 0:
            return [i * total_frames // num_frames for i in range(num_frames)]
        return [i * interval for i in range(num_frames)]

    def extract_evenly_spaced_frames_from_bytes_cv2(
        self, video_bytes: bytes, num_frames: int = 32
    ) -> list:
        """OpenCV alternative to ``extract_evenly_spaced_frames_from_bytes``.

        The bytes are written to a temporary ``.mp4`` file because
        ``cv2.VideoCapture`` is opened by file name here. Frames are converted
        from OpenCV's BGR channel order to RGB so they match the decord-based
        extractor. Frames that fail to decode are skipped, so fewer than
        ``num_frames`` frames may be returned; an unknown/zero frame count
        yields ``[]``.
        """
        frames = []
        with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
            temp_video.write(video_bytes)
            temp_video.flush()

            vidcap = cv2.VideoCapture(temp_video.name)
            try:
                total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
                for frame_index in self._frame_indices(total_frames, num_frames):
                    vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
                    success, image = vidcap.read()
                    if success:
                        frames.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            finally:
                # Release the decoder before the temporary file is deleted.
                vidcap.release()
        return frames

    def extract_evenly_spaced_frames_from_bytes(
        self, video_bytes: bytes, num_frames: int = 32
    ) -> list:
        """Decode ``num_frames`` evenly spaced frames from in-memory video bytes.

        Always returns exactly ``num_frames`` numpy frames (frames are repeated
        for videos shorter than ``num_frames``).

        Raises:
        - ValueError: if the video has no frames or ``num_frames`` < 1.
        """
        vr = VideoReader(io.BytesIO(video_bytes), ctx=cpu(0))
        indices = self._frame_indices(len(vr), num_frames)
        if not indices:
            raise ValueError(
                f"Cannot sample {num_frames} frames from a video with {len(vr)} frames."
            )
        return [vr[frame_index].asnumpy() for frame_index in indices]

    def preprocess_frames(self, video_frames):
        """Convert a list of frames into model inputs on ``self.device``.

        The frames are stacked with ``np.array`` (so they must share a shape)
        and handed to the processor as one list of frames; presumably the
        processor treats that list as a single clip — TODO confirm.
        """
        frames = np.array(video_frames)
        inputs = self.processor(
            text=None, videos=list(frames), return_tensors="pt", padding=True
        ).to(self.device)
        return inputs

    def embed_frames_with_xclip_processing(self, frames):
        """Embed sampled video frames and return L2-normalised nested lists."""
        model_inputs = self.preprocess_frames(frames)

        # Inference only: no_grad avoids building an autograd graph. Grad mode
        # is per-thread, so this is correct inside the worker threads too.
        with torch.no_grad():
            video_features = self.model.get_video_features(**model_inputs)

        if video_features.dim() == 2:
            embeddings = torch.nn.functional.normalize(video_features, p=2, dim=1)
        else:
            # Drop the leading batch axis, then normalise along the feature axis
            # so every path returns unit-length embeddings.
            embeddings = torch.nn.functional.normalize(
                video_features.squeeze(0), p=2, dim=-1
            )
        return embeddings.cpu().numpy().tolist()

    def process_video(self, video_url, video_metadata):
        """Download, sample and embed a single video.

        Returns:
        - (embeddings, metadata) on success, where ``metadata`` is a copy of
          ``video_metadata`` with ``"url"`` set to ``video_url`` (the caller's
          dict is not modified).
        - (None, None) on any failure; the error is logged.
        """
        try:
            self.logger.info("Downloading video as bytes.")
            download_start_time = timeit.default_timer()
            video_bytes, _video_headers = self.download_video_as_bytes(video_url)
            download_end_time = timeit.default_timer()
            self.logger.info(
                f"Video downloading took {download_end_time - download_start_time} seconds"
            )
            if video_bytes is None:
                self.logger.error(f"Skipping video that could not be downloaded: {video_url}")
                return None, None

            self.logger.info("Extracting frames.")
            processing_start_time = timeit.default_timer()
            frames = self.extract_evenly_spaced_frames_from_bytes(
                video_bytes, num_frames=32
            )
            processing_end_time = timeit.default_timer()
            self.logger.info(
                f"Extracting video frames took {processing_end_time - processing_start_time} seconds"
            )

            self.logger.info("Embedding frames with Xclip.")
            embedding_start_time = timeit.default_timer()
            frame_embeddings = self.embed_frames_with_xclip_processing(frames)
            embedding_end_time = timeit.default_timer()
            self.logger.info(
                f"Embedding calculation took {embedding_end_time - embedding_start_time} seconds"
            )

            metadata = dict(video_metadata)
            metadata["url"] = video_url
            self.logger.info("Returning embeddings and metadata.")
            return frame_embeddings, metadata
        except Exception:
            # Best-effort per video: log and let the caller skip this one.
            self.logger.exception(f"Failed to process video {video_url}")
            return None, None

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed videos or a text query, depending on ``data["process_type"]``.

        Accepted requests:
        - ``process_type == "videos"``: requires ``videos_urls`` (list of URLs)
          and ``videos_metadata`` (list with one metadata dict per URL, each
          containing a ``"caption"`` key). Optional ``batch_size`` (int,
          default 4) sets how many videos are processed concurrently.
        - ``process_type == "text"``: requires ``query`` (str).

        Returns:
        - Dict[str, Any]
          Videos: ``{"embeddings": [...], "metadata": [...]}`` in input order;
          videos that fail are omitted from both lists. Request-level errors
          return ``{"embeddings": [], "error": message}``.
          Text: ``{"embeddings": [[...]]}`` with the L2-normalised embedding.
          Unknown ``process_type``: ``{"embeddings": [], "error": message}``.

        Raises:
        - ValueError: if ``process_type`` is missing, or for text requests if
          ``query`` is missing or not a string. Video validation errors are
          reported in the ``"error"`` field instead.
        """
        if "process_type" not in data:
            raise ValueError(
                "Data must contain a 'process_type' key with value 'videos' or 'text'."
            )
        process_type = data["process_type"]

        if process_type == "videos":
            return self._embed_videos(data)
        if process_type == "text":
            return self._embed_text(data)

        message = (
            "Error during CLIP endpoint processing: "
            f"data['process_type']: {process_type} neither 'videos' or 'text'"
        )
        self.logger.error(message)
        return {"embeddings": [], "error": message}

    def _embed_videos(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle a ``process_type == "videos"`` request (see ``__call__``)."""
        try:
            videos_urls = data.get("videos_urls")
            if not isinstance(videos_urls, list):
                raise ValueError(
                    "Data must contain 'videos_urls' key with a list of videos urls."
                )
            videos_metadata = data.get("videos_metadata")
            # zip() would silently drop videos if the lists differ in length.
            if not isinstance(videos_metadata, list) or len(videos_metadata) != len(videos_urls):
                raise ValueError(
                    "Data must contain 'videos_metadata' key with a list holding "
                    "one metadata dict per video url."
                )
            batch_size = int(data.get("batch_size", 4))
            if batch_size < 1:
                raise ValueError("'batch_size' must be a positive integer.")

            processed_video_embeddings = []
            processed_videos_metadata = []

            for i in range(0, len(videos_urls), batch_size):
                batch_urls = videos_urls[i : i + batch_size]
                batch_metadata = videos_metadata[i : i + batch_size]

                with ThreadPoolExecutor() as executor:
                    futures = [
                        executor.submit(self.process_video, url, metadata)
                        for url, metadata in zip(batch_urls, batch_metadata)
                    ]

                # Read results in submission order so output order matches input.
                for future in futures:
                    frame_embeddings, video_metadata = future.result()
                    if frame_embeddings is None:
                        continue
                    processed_video_embeddings.append(frame_embeddings)
                    self.logger.info("Finished appending video embedding")
                    processed_metadata = {
                        "text": video_metadata["caption"],
                        "source": video_metadata["url"],
                        "source_type": "video_frames",
                        **video_metadata,
                    }
                    processed_videos_metadata.append(processed_metadata)
                    self.logger.info("Finished appending video metadata")
                self.logger.info(f"Finished processing batch {i}")

            self.logger.info("Returning embeddings and metadata of all batches")
            return {
                "embeddings": processed_video_embeddings,
                "metadata": processed_videos_metadata,
            }
        except Exception as e:
            self.logger.exception(f"Error during videos processing: {str(e)}")
            return {"embeddings": [], "error": str(e)}

    def _embed_text(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle a ``process_type == "text"`` request (see ``__call__``)."""
        if "query" not in data or not isinstance(data["query"], str):
            raise ValueError("Data must contain 'query' key which is a str.")

        # truncation=True keeps over-long queries within the tokenizer's limit
        # instead of failing inside the model.
        inputs = self.tokenizer(
            data["query"], return_tensors="pt", truncation=True
        ).to(self.device)
        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)

        text_emb = text_features.cpu().numpy()
        # Row-wise L2 normalisation.
        text_emb = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
        return {"embeddings": text_emb.tolist()}
|
|
|
|
|
|
|
|