import os
from typing import List
import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.datasets.utils import download_url
from .longclip import longclip
from .viclip import get_viclip
from .video_utils import extract_frames
# All metrics.
__all__ = ["VideoCLIPXLScore"]
_MODELS = {
"ViClip-InternVid-10M-FLT": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViClip-InternVid-10M-FLT.pth",
"LongCLIP-L": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/longclip-L.pt",
"VideoCLIP-XL-v2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/VideoCLIP-XL-v2.bin",
}
_MD5 = {
"ViClip-InternVid-10M-FLT": "b1ebf538225438b3b75e477da7735cd0",
"LongCLIP-L": "5478b662f6f85ca0ebd4bb05f9b592f3",
"VideoCLIP-XL-v2": "cebda0bab14b677ec061a57e80791f35",
}
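# Note: torchvision's download_url verifies each checkpoint against its MD5 and
# skips the download when a matching copy already exists in the cache directory.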
def normalize(
    data: np.ndarray,
    mean: List[float] = [0.485, 0.456, 0.406],
    std: List[float] = [0.229, 0.224, 0.225],
):
    # Scale uint8 pixel values to [0, 1], then standardize with the ImageNet
    # per-channel mean and standard deviation.
    v_mean = np.array(mean).reshape(1, 1, 3)
    v_std = np.array(std).reshape(1, 1, 3)
    return (data / 255.0 - v_mean) / v_std
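
# Example (hypothetical shapes): the (1, 1, 3) statistics broadcast over a whole
# [T, H, W, 3] frame stack, so a clip can be normalized in a single call:
#   frames = np.random.randint(0, 256, (8, 224, 224, 3), dtype=np.uint8)
#   normalized = normalize(frames)  # float64, roughly zero-mean per channel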
class VideoCLIPXL(nn.Module):
    def __init__(self, root: str = "~/.cache/clip"):
        super().__init__()
        self.root = os.path.expanduser(root)
        os.makedirs(self.root, exist_ok=True)

        # LongCLIP-L provides the text tower.
        k = "LongCLIP-L"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.model = longclip.load(os.path.join(self.root, filename), device="cpu")[0].float()

        # ViCLIP provides the video tower.
        k = "ViClip-InternVid-10M-FLT"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.viclip_model = get_viclip("l", os.path.join(self.root, filename))["viclip"].float()

        # Delete the unused towers: LongCLIP's visual encoder and ViCLIP's text encoder.
        del self.model.visual
        del self.viclip_model.text_encoder
class VideoCLIPXLScore:
    def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"):
        self.root = os.path.expanduser(root)
        os.makedirs(self.root, exist_ok=True)

        # Download the VideoCLIP-XL-v2 checkpoint and load it into the
        # two-tower module defined above.
        k = "VideoCLIP-XL-v2"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.model = VideoCLIPXL()
        state_dict = torch.load(os.path.join(self.root, filename), map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()  # scoring is inference-only
        self.device = device
    def __call__(self, videos: List[List[Image.Image]], texts: List[str]):
        assert len(videos) == len(texts)

        # Use cv2.resize in accordance with the official demo. Convert the RGB
        # PIL frames to BGR, then resize and normalize => B * [T, 224, 224, 3].
        videos = [[cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR) for f in v] for v in videos]
        resize_videos = [[cv2.resize(f, (224, 224)) for f in v] for v in videos]
        resize_normalized_videos = [normalize(np.stack(v)) for v in resize_videos]
        video_inputs = torch.stack([torch.from_numpy(v) for v in resize_normalized_videos])
        video_inputs = video_inputs.float().permute(0, 1, 4, 2, 3).to(self.device, non_blocking=True)  # [B, T, C, H, W]

        with torch.no_grad():
            # Encode each clip separately; get_vid_features returns an
            # L2-normalized feature of shape [1, D] per clip.
            vid_features = torch.stack(
                [self.model.viclip_model.get_vid_features(x.unsqueeze(0)).float() for x in video_inputs]
            )
            vid_features.squeeze_()  # [B, 1, D] -> [B, D] (or [D] when B == 1)
            # vid_features = self.model.viclip_model.get_vid_features(video_inputs).float()

            # Encode and L2-normalize the texts with the LongCLIP text tower.
            text_inputs = longclip.tokenize(texts, truncate=True).to(self.device)
            text_features = self.model.model.encode_text(text_inputs)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # Cosine similarities; for a batch, only the matched (diagonal) pairs are returned.
        scores = text_features @ vid_features.T
        return scores.tolist() if len(videos) == 1 else scores.diagonal().tolist()
def __repr__(self):
return "videoclipxl_score"
if __name__ == "__main__":
    videos = ["your_video_path"] * 3
    texts = [
        "a joker",
        "glasses and flower",
        "The video opens with a view of a white building with multiple windows, partially obscured by leafless tree branches. The scene transitions to a closer view of the same building, with the tree branches more prominent in the foreground. The focus then shifts to a street sign that reads 'Abesses' in bold, yellow letters against a green background. The sign is attached to a metal structure, possibly a tram or bus stop. The sign is illuminated by a light source above it, and the background reveals a glimpse of the building and tree branches from earlier shots. The colors are muted, with the yellow sign standing out against the grey and green hues.",
    ]
    video_clip_xl_score = VideoCLIPXLScore(device="cuda" if torch.cuda.is_available() else "cpu")

    # Uniformly sample 8 frames from each video before scoring.
    batch_frames = []
    for v in videos:
        sampled_frames = extract_frames(v, sample_method="uniform", num_sampled_frames=8)[1]
        batch_frames.append(sampled_frames)
    print(video_clip_xl_score(batch_frames, texts))