# Source: VideoCrafterXtend/VBench/vbench/background_consistency.py
# Uploaded by ychenhq using huggingface_hub (commit 04fbff5, verified); file size 2.43 kB.
import os
import json
import logging
import numpy as np
import clip
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from vbench.utils import load_video, load_dimension_info, clip_transform
from tqdm import tqdm
def _load_frame_batch(video_path, preprocess, image_transform, read_frame):
    """Load every frame of one video as a single preprocessed batch.

    Returns ``(resolved_path, images)``. ``resolved_path`` is the frame
    directory when ``read_frame`` is truthy and ``video_path`` otherwise;
    ``images`` is ``None`` when the frame directory contains no files.
    """
    if not read_frame:
        return video_path, image_transform(load_video(video_path))

    # Frames are read from a directory derived from the video path: the last
    # four characters are dropped (presumably a ".mp4"-style extension),
    # 'videos' becomes 'frames', and spaces become underscores.
    frame_dir = video_path[:-4].replace('videos', 'frames').replace(' ', '_')
    frames = []
    for name in sorted(os.listdir(frame_dir)):
        # Close each file handle as soon as the frame has been preprocessed.
        with Image.open(os.path.join(frame_dir, name)) as frame:
            frames.append(preprocess(frame))
    if not frames:
        # torch.stack() raises on an empty list; let the caller skip instead.
        return frame_dir, None
    return frame_dir, torch.stack(frames)


def _sum_frame_similarity(image_features):
    """Return ``(summed_score, scored_frames)`` for one video's features.

    Every frame after the first is scored as the mean of its cosine
    similarity to the previous frame and to the first frame, each clamped
    to be non-negative. ``image_features`` must hold at least one frame.
    """
    total = 0.0
    first_feature = image_features[0].unsqueeze(0)
    previous_feature = first_feature
    for idx in range(1, len(image_features)):
        current_feature = image_features[idx].unsqueeze(0)
        sim_pre = max(0.0, F.cosine_similarity(previous_feature, current_feature).item())
        sim_fir = max(0.0, F.cosine_similarity(first_feature, current_feature).item())
        total += (sim_pre + sim_fir) / 2
        previous_feature = current_feature
    return total, len(image_features) - 1


def background_consistency(clip_model, preprocess, video_list, device, read_frame):
    """Score how consistent CLIP image features stay across video frames.

    Args:
        clip_model: Model exposing ``encode_image(images)``.
        preprocess: Callable applied to each PIL image when ``read_frame``
            is truthy.
        video_list: Iterable of video paths to evaluate.
        device: Device the frame batch is moved to before encoding.
        read_frame: If truthy, read pre-extracted frame images from a
            directory derived from each video path; otherwise decode the
            video with ``load_video`` and apply ``clip_transform(224)``.

    Returns:
        ``(sim_per_frame, video_results)``. ``sim_per_frame`` is the mean
        score over all scored frames of all videos (0.0 if nothing could be
        scored). ``video_results`` is a list of
        ``{'video_path': ..., 'video_results': ...}`` dicts holding each
        video's mean score. Videos with fewer than two frames cannot be
        scored; they are skipped with a warning.
    """
    logger = logging.getLogger(__name__)
    sim = 0.0
    cnt = 0
    video_results = []
    image_transform = clip_transform(224)
    for video_path in tqdm(video_list):
        video_path, images = _load_frame_batch(
            video_path, preprocess, image_transform, read_frame)

        image_features = None
        num_frames = 0
        if images is not None:
            # Inference only: skip autograd bookkeeping for the forward pass.
            with torch.no_grad():
                image_features = clip_model.encode_image(images.to(device))
                image_features = F.normalize(image_features, dim=-1, p=2)
            num_frames = len(image_features)

        if num_frames < 2:
            # No frame pair exists, so the per-video mean would divide by zero.
            logger.warning(
                "Skipping %s: background consistency needs at least two "
                "frames, got %d", video_path, num_frames)
            continue

        video_sim, scored_frames = _sum_frame_similarity(image_features)
        sim += video_sim
        cnt += scored_frames
        video_results.append({
            'video_path': video_path,
            'video_results': video_sim / scored_frames,
        })

    if cnt == 0:
        logger.warning(
            "No frame pairs were scored; returning 0.0 for background consistency")
        return 0.0, video_results
    return sim / cnt, video_results
def compute_background_consistency(json_dir, device, submodules_list):
    """Load CLIP and run the background-consistency metric.

    ``submodules_list[0]`` is passed to ``clip.load`` as the model to load
    and ``submodules_list[1]`` is forwarded as the ``read_frame`` flag. The
    videos to evaluate come from ``load_dimension_info`` for the
    'background_consistency' dimension.

    Returns the ``(all_results, video_results)`` pair produced by
    ``background_consistency``.
    """
    vit_path = submodules_list[0]
    read_frame = submodules_list[1]

    clip_model, preprocess = clip.load(vit_path, device=device)

    # Only the first element of the returned pair (the video list) is used.
    video_list, _unused = load_dimension_info(
        json_dir,
        dimension='background_consistency',
        lang='en',
    )

    overall_score, per_video = background_consistency(
        clip_model, preprocess, video_list, device, read_frame)
    return overall_score, per_video