# Source: VideoCrafterXtend / VBench/vbench/subject_consistency.py
# Uploaded by ychenhq using huggingface_hub (commit 04fbff5, verified), 2.67 kB.
import io
import os
import cv2
import json
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from vbench.utils import load_video, load_dimension_info, dino_transform, dino_transform_Image
import logging
# Configure root logging once at import time so the INFO-level progress
# messages emitted by this module are visible with a timestamped prefix.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
# Module-scoped logger; its name follows the importing module's __name__.
logger = logging.getLogger(__name__)
def subject_consistency(model, video_list, device, read_frame):
    """Score how stable the model's frame features stay across each video.

    For every frame after the first, the frame's L2-normalised feature
    vector is compared (cosine similarity, clamped at 0) with the previous
    frame and with the first frame of the same video; the two values are
    averaged to give that frame's score.

    Args:
        model: Callable feature extractor; invoked as ``model(batch)`` on a
            single-frame batch already moved to ``device``.
        video_list: Iterable of video file paths.
        device: Device the frames are moved to before calling ``model``.
        read_frame: If truthy, frames are read as image files from a
            directory derived from the video path (last four characters
            dropped, ``'videos'`` -> ``'frames'``, spaces -> underscores);
            otherwise the video file is decoded with ``load_video``.

    Returns:
        A tuple ``(sim_per_frame, video_results)``:

        * ``sim_per_frame`` - sum of all frame scores divided by the number
          of scored frames over the whole list, or ``0.0`` when no frame
          could be scored (empty list, or no video with at least two frames).
        * ``video_results`` - one dict per video with keys ``'video_path'``
          (the frames directory when ``read_frame`` is set) and
          ``'video_results'`` (the SUM of that video's frame scores, not
          their mean).
    """
    total_sim = 0.0
    scored_frames = 0
    video_results = []

    image_transform = dino_transform_Image(224) if read_frame else dino_transform(224)

    for video_path in tqdm(video_list):
        if read_frame:
            # NOTE(review): the [:-4] slice assumes a three-letter extension
            # such as ".mp4" - confirm against the dataset layout.
            video_path = video_path[:-4].replace('videos', 'frames').replace(' ', '_')
            images = []
            for frame_name in sorted(os.listdir(video_path)):
                # Close each file handle deterministically; the transform's
                # output is used as a tensor below, so the pixel data has
                # already been read by the time the file is closed.
                with Image.open(os.path.join(video_path, frame_name)) as frame:
                    images.append(image_transform(frame))
        else:
            images = image_transform(load_video(video_path))

        video_sim = 0.0
        first_features = None
        previous_features = None
        with torch.no_grad():
            for image in images:
                features = model(image.unsqueeze(0).to(device))
                features = F.normalize(features, dim=-1, p=2)
                if first_features is None:
                    # The first frame is only a reference; it gets no score.
                    first_features = features
                else:
                    sim_pre = max(0.0, F.cosine_similarity(previous_features, features).item())
                    sim_fir = max(0.0, F.cosine_similarity(first_features, features).item())
                    video_sim += (sim_pre + sim_fir) / 2
                    scored_frames += 1
                previous_features = features

        total_sim += video_sim
        video_results.append({'video_path': video_path, 'video_results': video_sim})

    # The former, unused `sim / (len(video_list) - 1)` computation was dropped:
    # it raised ZeroDivisionError whenever exactly one video was evaluated.
    if scored_frames == 0:
        logger.warning("subject_consistency: no frame pairs to compare; returning 0.0")
        return 0.0, video_results

    sim_per_frame = total_sim / scored_frames
    return sim_per_frame, video_results
def compute_subject_consistency(json_dir, device, submodules_list):
    """Load the DINO model and run the subject-consistency evaluation.

    Args:
        json_dir: Passed to ``load_dimension_info`` to obtain the list of
            videos for the ``'subject_consistency'`` dimension.
        device: Device the model and the frames are placed on.
        submodules_list: Mapping of keyword arguments for ``torch.hub.load``,
            plus an optional ``'read_frame'`` entry (defaults to ``False``)
            selecting whether pre-extracted frame images are read instead of
            decoding video files.

    Returns:
        The ``(sim_per_frame, video_results)`` tuple produced by
        ``subject_consistency``.
    """
    # Work on a copy so the caller's mapping is left untouched, and strip the
    # one entry that configures this module rather than torch.hub.load, so it
    # is not forwarded to the model entrypoint as a stray keyword argument.
    hub_kwargs = dict(submodules_list)
    read_frame = hub_kwargs.pop('read_frame', False)

    dino_model = torch.hub.load(**hub_kwargs).to(device)
    # Inference only: switch off training-time behaviour (dropout, batch-norm
    # statistic updates) for any layers of the loaded model that have it.
    dino_model.eval()
    logger.info("Initialize DINO success")

    video_list, _ = load_dimension_info(json_dir, dimension='subject_consistency', lang='en')
    all_results, video_results = subject_consistency(dino_model, video_list, device, read_frame)
    return all_results, video_results