VideoCrafterXtend / VBench /vbench /overall_consistency.py
ychenhq's picture
Upload folder using huggingface_hub
04fbff5 verified
raw
history blame
2.78 kB
import os
import json
import numpy as np
import torch
import clip
from tqdm import tqdm
from vbench.utils import load_video, load_dimension_info, clip_transform, read_frames_decord_by_fps, CACHE_DIR
from vbench.third_party.ViCLIP.viclip import ViCLIP
from vbench.third_party.ViCLIP.simple_tokenizer import SimpleTokenizer
def get_text_features(model, input_text, tokenizer, text_feature_dict=None):
    """Return the L2-normalised text embedding of ``input_text``, with caching.

    Args:
        model: Object exposing ``encode_text(text)`` that returns a tensor.
        input_text: Prompt to embed; it is also used as the cache key.
        tokenizer: Not used by this function; kept so existing callers that
            pass it keep working.
        text_feature_dict: Optional caller-owned cache mapping prompt to
            features. When omitted, a cache private to ``model`` is used so
            that embeddings computed by one model are never handed back for
            another model that happens to see the same prompt.

    Returns:
        The output of ``model.encode_text`` cast to float and divided by its
        norm along the last dimension.
    """
    if text_feature_dict is None:
        # Local import keeps this block self-contained.
        import weakref

        per_model = getattr(get_text_features, "_per_model_cache", None)
        if per_model is None:
            # Weak keys: a cached entry must not keep a discarded model alive.
            per_model = weakref.WeakKeyDictionary()
            get_text_features._per_model_cache = per_model
        try:
            text_feature_dict = per_model.setdefault(model, {})
        except TypeError:
            # Model cannot be hashed / weakly referenced: compute uncached.
            text_feature_dict = {}

    if input_text in text_feature_dict:
        return text_feature_dict[input_text]

    text_template = f"{input_text}"
    with torch.no_grad():
        text_features = model.encode_text(text_template).float()
        # Out-of-place division so the tensor returned by the model is never
        # modified behind its back when ``.float()`` is a no-op.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    text_feature_dict[input_text] = text_features
    return text_features
def get_vid_features(model, input_frames):
    """Encode frames with ``model.encode_vision`` and L2-normalise the result.

    Args:
        model: Object exposing ``encode_vision(frames, test=True)`` that
            returns a tensor.
        input_frames: Frame tensor forwarded unchanged to the model.
            NOTE(review): the caller in this file passes a batch of one clip
            (``images.unsqueeze(0)``); confirm the layout the model expects.

    Returns:
        The float embedding divided in place by its norm along the last
        dimension. Computed without gradient tracking.
    """
    with torch.no_grad():
        video_embedding = model.encode_vision(input_frames, test=True).float()
        last_dim_norm = video_embedding.norm(dim=-1, keepdim=True)
        video_embedding.div_(last_dim_norm)
    return video_embedding
def get_predict_label(clip_feature, text_feats_tensor, top=5):
    """Rank candidate texts for each visual feature by softmax probability.

    The similarity ``clip_feature @ text_feats_tensor.T`` is scaled by 100
    before a softmax over the candidate axis.

    Args:
        clip_feature: Visual feature tensor.
            NOTE(review): assumed 2-D ``(n_items, dim)`` — confirm with callers.
        text_feats_tensor: Candidate text features, one row per candidate
            (2-D is required by the ``.T`` / matmul used here).
        top: Maximum number of candidates to return per item. If fewer
            candidates than ``top`` are supplied, all of them are returned
            instead of letting ``topk`` raise.

    Returns:
        ``(top_probs, top_labels)`` CPU tensors: the highest probabilities in
        descending order and the indices of the matching candidates.
    """
    label_probs = (100.0 * clip_feature @ text_feats_tensor.T).softmax(dim=-1)
    # topk raises when k exceeds the size of the reduced dimension.
    k = min(top, label_probs.shape[-1])
    top_probs, top_labels = label_probs.cpu().topk(k, dim=-1)
    return top_probs, top_labels
def overall_consistency(clip_model, video_dict, tokenizer, device, sample="middle"):
    """Score how well each video matches its prompt and average the scores.

    For every entry of ``video_dict`` the prompt is embedded once with
    ``get_text_features``; each video of that entry is then decoded to 8
    frames, transformed, embedded with ``get_vid_features`` and compared with
    the prompt embedding by a dot product of the two normalised features.

    Args:
        clip_model: Model passed to ``get_vid_features`` / ``get_text_features``.
        video_dict: Iterable of mappings with a ``'prompt'`` string and a
            ``'video_list'`` iterable of video paths.
        tokenizer: Forwarded to ``get_text_features``.
        device: Device the transformed frames are moved to.
        sample: Frame sampling strategy forwarded to
            ``read_frames_decord_by_fps``.

    Returns:
        ``(avg_score, video_results)`` where ``video_results`` is a list of
        ``{'video_path': ..., 'video_results': score}`` dicts in processing
        order and ``avg_score`` is the mean score (NaN when no video was
        scored).
    """
    sim = []
    video_results = []
    image_transform = clip_transform(224)
    for info in tqdm(video_dict):
        query = info['prompt']
        video_list = info['video_list']
        # The prompt embedding does not depend on the video, so compute it
        # once per prompt instead of once per video.
        text_feat = get_text_features(clip_model, query, tokenizer)
        for video_path in video_list:
            with torch.no_grad():
                images = read_frames_decord_by_fps(video_path, num_frames=8, sample=sample)
                images = image_transform(images)
                images = images.to(device)
                # Add a leading batch dimension of one clip.
                clip_feat = get_vid_features(clip_model, images.unsqueeze(0))
                logit_per_text = clip_feat @ text_feat.T
                score_per_video = float(logit_per_text[0][0].cpu())
            sim.append(score_per_video)
            video_results.append({'video_path': video_path, 'video_results': score_per_video})
    # np.mean([]) already yields NaN but also emits a RuntimeWarning; return
    # the same value quietly when nothing was scored.
    avg_score = np.mean(sim) if sim else float("nan")
    return avg_score, video_results
def compute_overall_consistency(json_dir, device, submodules_list):
    """Build the ViCLIP model and run the overall-consistency evaluation.

    Args:
        json_dir: Passed to ``load_dimension_info`` to locate the prompt/video
            listing for the ``'overall_consistency'`` dimension.
        device: Device the model is moved to and on which frames are scored.
        submodules_list: Mapping expanded as keyword arguments into the
            ``ViCLIP`` constructor.

    Returns:
        The ``(average score, per-video results)`` pair produced by
        ``overall_consistency``.
    """
    vocab_path = os.path.join(CACHE_DIR, "ViCLIP/bpe_simple_vocab_16e6.txt.gz")
    tokenizer = SimpleTokenizer(vocab_path)

    model = ViCLIP(tokenizer=tokenizer, **submodules_list)
    model = model.to(device)

    # Only the second item returned by load_dimension_info is needed here.
    _, prompt_entries = load_dimension_info(
        json_dir, dimension='overall_consistency', lang='en'
    )

    mean_score, per_video = overall_consistency(model, prompt_entries, tokenizer, device)
    return mean_score, per_video