import io
import os
import cv2
import json
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from vbench.utils import load_video, load_dimension_info, dino_transform, dino_transform_Image
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def subject_consistency(model, video_list, device, read_frame):
    """Score how consistent frame features stay over the course of each video.

    Every frame is passed through ``model`` and its feature is L2-normalised.
    For each frame after the first, two cosine similarities are taken: one
    against the previous frame's feature and one against the first frame's
    feature. Each is clamped to be non-negative and the two are averaged to
    give that frame's score.

    Args:
        model: Callable applied to a single-frame batch (``frame.unsqueeze(0)``
            moved to ``device``); its output is used as the frame feature.
        video_list: Iterable of video paths.
        device: Device each frame batch is moved to before calling ``model``.
        read_frame: When truthy, frames are read as image files from a
            directory derived from the video path (last four characters
            dropped, ``'videos'`` replaced by ``'frames'``, spaces replaced by
            underscores). Otherwise the video is decoded with ``load_video``.

    Returns:
        A tuple ``(sim_per_frame, video_results)``. ``sim_per_frame`` is the
        sum of all frame scores divided by the number of scored frames, or
        ``0.0`` when no frame could be scored. ``video_results`` is a list of
        ``{'video_path': ..., 'video_results': ...}`` dicts, where
        ``'video_results'`` is the sum of that video's frame scores and
        ``'video_path'`` is the frame directory when ``read_frame`` is truthy.
    """
    sim = 0.0
    cnt = 0
    video_results = []

    if read_frame:
        image_transform = dino_transform_Image(224)
    else:
        image_transform = dino_transform(224)

    for video_path in tqdm(video_list):
        if read_frame:
            # The rewritten directory path is also what gets reported in
            # video_results, so the loop variable is deliberately rebound.
            # NOTE(review): [:-4] assumes a four-character suffix such as
            # ".mp4" - confirm against how the frame directories are named.
            video_path = video_path[:-4].replace('videos', 'frames').replace(' ', '_')
            images = []
            for frame_name in sorted(os.listdir(video_path)):
                # Close each image file once it has been transformed instead
                # of leaving the handle open until garbage collection.
                with Image.open(os.path.join(video_path, frame_name)) as frame_image:
                    images.append(image_transform(frame_image))
        else:
            images = image_transform(load_video(video_path))

        video_sim = 0.0
        first_image_features = None
        former_image_features = None
        with torch.no_grad():
            for frame in images:
                image_features = model(frame.unsqueeze(0).to(device))
                image_features = F.normalize(image_features, dim=-1, p=2)
                if first_image_features is None:
                    # The first frame is the reference; it gets no score.
                    first_image_features = image_features
                else:
                    sim_pre = max(0.0, F.cosine_similarity(former_image_features, image_features).item())
                    sim_fir = max(0.0, F.cosine_similarity(first_image_features, image_features).item())
                    video_sim += (sim_pre + sim_fir) / 2
                    cnt += 1
                former_image_features = image_features

        sim += video_sim
        video_results.append({'video_path': video_path, 'video_results': video_sim})

    if cnt == 0:
        # Nothing to average over (no videos, or every video had at most one
        # frame); report 0.0 rather than dividing by zero.
        logger.warning("subject_consistency: no frame pairs were scored; returning 0.0")
        return 0.0, video_results

    sim_per_frame = sim / cnt
    return sim_per_frame, video_results


def compute_subject_consistency(json_dir, device, submodules_list):
    """Load the model via ``torch.hub`` and run ``subject_consistency``.

    Args:
        json_dir: Passed to ``load_dimension_info`` to obtain the video list
            for the ``'subject_consistency'`` dimension.
        device: Device the model is moved to and frames are evaluated on.
        submodules_list: Mapping expanded as keyword arguments into
            ``torch.hub.load``; it must also contain a ``'read_frame'`` entry.

    Returns:
        The ``(sim_per_frame, video_results)`` tuple produced by
        ``subject_consistency``.
    """
    # NOTE(review): the whole mapping, including 'read_frame', is forwarded to
    # torch.hub.load - presumably the hub entry point tolerates the extra
    # keyword; confirm before reusing this with a different model.
    dino_model = torch.hub.load(**submodules_list).to(device)
    read_frame = submodules_list['read_frame']
    logger.info("Initialize DINO success")
    video_list, _ = load_dimension_info(json_dir, dimension='subject_consistency', lang='en')
    all_results, video_results = subject_consistency(dino_model, video_list, device, read_frame)
    return all_results, video_results