import os
from typing import List

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.datasets.utils import download_url

from .longclip import longclip
from .viclip import get_viclip
from .video_utils import extract_frames


__all__ = ["VideoCLIPXLScore"]

_MODELS = {
    "ViClip-InternVid-10M-FLT": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViClip-InternVid-10M-FLT.pth",
    "LongCLIP-L": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/longclip-L.pt",
    "VideoCLIP-XL-v2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/VideoCLIP-XL-v2.bin",
}
_MD5 = {
    "ViClip-InternVid-10M-FLT": "b1ebf538225438b3b75e477da7735cd0",
    "LongCLIP-L": "5478b662f6f85ca0ebd4bb05f9b592f3",
    "VideoCLIP-XL-v2": "cebda0bab14b677ec061a57e80791f35",
}


def normalize(
    data: np.ndarray,
    mean: List[float] = [0.485, 0.456, 0.406],
    std: List[float] = [0.229, 0.224, 0.225],
):
    # Scale uint8 pixel values to [0, 1], then standardize with ImageNet channel statistics.
    v_mean = np.array(mean).reshape(1, 1, 3)
    v_std = np.array(std).reshape(1, 1, 3)
    return (data / 255.0 - v_mean) / v_std
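
# Illustrative only: `normalize` broadcasts over any leading dimensions, so a
# stacked clip works the same as a single frame, e.g.
#   clip = np.zeros((8, 224, 224, 3), dtype=np.uint8)
#   normalize(clip).shape  # -> (8, 224, 224, 3), dtype float64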


class VideoCLIPXL(nn.Module):
    def __init__(self, root: str = "~/.cache/clip"):
        super().__init__()

        self.root = os.path.expanduser(root)
        if not os.path.exists(self.root):
            os.makedirs(self.root)

        # LongCLIP provides the text encoder.
        k = "LongCLIP-L"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.model = longclip.load(os.path.join(self.root, filename), device="cpu")[0].float()

        # ViCLIP provides the video encoder.
        k = "ViClip-InternVid-10M-FLT"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.viclip_model = get_viclip("l", os.path.join(self.root, filename))["viclip"].float()

        # Drop the unused half of each backbone: only LongCLIP's text tower and
        # ViCLIP's video tower are kept.
        del self.model.visual
        del self.viclip_model.text_encoder


class VideoCLIPXLScore:
    def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"):
        self.root = os.path.expanduser(root)
        if not os.path.exists(self.root):
            os.makedirs(self.root)

        k = "VideoCLIP-XL-v2"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.model = VideoCLIPXL()
        state_dict = torch.load(os.path.join(self.root, filename), map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()  # scoring only; disable dropout and other train-time behavior

        self.device = device

    def __call__(self, videos: List[List[Image.Image]], texts: List[str]):
        assert len(videos) == len(texts)

        # PIL frames are RGB; the ViCLIP preprocessing used here expects OpenCV's BGR order.
        videos = [[cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR) for f in v] for v in videos]
        resized_videos = [[cv2.resize(f, (224, 224)) for f in v] for v in videos]
        normalized_videos = [normalize(np.stack(v)) for v in resized_videos]

        # (B, T, H, W, C) -> (B, T, C, H, W)
        video_inputs = torch.stack([torch.from_numpy(v) for v in normalized_videos])
        video_inputs = video_inputs.float().permute(0, 1, 4, 2, 3).to(self.device, non_blocking=True)

        with torch.no_grad():
            vid_features = torch.stack(
                [self.model.viclip_model.get_vid_features(x.unsqueeze(0)).float() for x in video_inputs]
            )
            vid_features.squeeze_()  # (B, 1, D) -> (B, D); a single video collapses to (D,)

            text_inputs = longclip.tokenize(texts, truncate=True).to(self.device)
            text_features = self.model.model.encode_text(text_inputs)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)
            scores = text_features @ vid_features.T

        # With one video the batch dimension was squeezed away, so `scores` is
        # already the per-text similarity vector; otherwise take the pairwise diagonal.
        return scores.tolist() if len(videos) == 1 else scores.diagonal().tolist()

    def __repr__(self):
        return "videoclipxl_score"


if __name__ == "__main__":
    videos = ["your_video_path"] * 3
    texts = [
        "a joker",
        "glasses and flower",
        "The video opens with a view of a white building with multiple windows, partially obscured by leafless tree branches. The scene transitions to a closer view of the same building, with the tree branches more prominent in the foreground. The focus then shifts to a street sign that reads 'Abesses' in bold, yellow letters against a green background. The sign is attached to a metal structure, possibly a tram or bus stop. The sign is illuminated by a light source above it, and the background reveals a glimpse of the building and tree branches from earlier shots. The colors are muted, with the yellow sign standing out against the grey and green hues.",
    ]

    video_clip_xl_score = VideoCLIPXLScore(device="cuda")
    batch_frames = []
    for v in videos:
        sampled_frames = extract_frames(v, sample_method="uniform", num_sampled_frames=8)[1]
        batch_frames.append(sampled_frames)
    print(video_clip_xl_score(batch_frames, texts))
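
    # A minimal sketch of calling the scorer with pre-decoded frames instead of
    # `extract_frames` (illustrative; `frame_paths` is a hypothetical list of
    # image paths):
    #   frames = [Image.open(p).convert("RGB") for p in frame_paths]
    #   print(video_clip_xl_score([frames], ["a short caption"]))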