ychenhq's picture
Upload folder using huggingface_hub
04fbff5 verified
import io
import os
import cv2
import json
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from vbench2_beta_i2v.utils import load_video, load_i2v_dimension_info, dino_transform, dino_transform_Image
import logging
# Configure root logging once at import time (INFO and above, timestamped lines),
# then obtain this module's named logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
@torch.no_grad()
def _embed(model, image, device):
    """Embed one preprocessed image tensor (no batch dim) and L2-normalise the features."""
    features = model(image.unsqueeze(0).to(device))
    return F.normalize(features, dim=-1, p=2)


def i2v_subject(model, video_pair_list, device, *, max_weight=0.5, mean_weight=0.5, min_weight=0.0):
    """Score how well each generated video preserves the subject of its conditioning image.

    For every ``(image_path, video_path)`` pair, the conditioning image and every
    video frame are embedded with *model* and L2-normalised. Two cosine-similarity
    series are collected, each clamped below at 0:

    * conformity  -- each frame vs. the conditioning image;
    * consecutive -- each frame vs. the previous frame.

    The per-video score is
    ``max_weight * max(conformity) + mean_weight * mean(consecutive)
    + min_weight * min(consecutive)``.

    Args:
        model: Feature extractor called as ``model(batch)`` on a single-image batch.
            Its output must reduce to one value under ``F.cosine_similarity`` (dim=1),
            i.e. a ``(1, D)`` feature tensor, because ``.item()`` is taken.
        video_pair_list: Iterable of ``(image_path, video_path)`` pairs.
        device: Torch device that inputs are moved to before calling *model*.
        max_weight: Weight of the best frame-to-image similarity (default 0.5).
        mean_weight: Weight of the mean consecutive-frame similarity (default 0.5).
        min_weight: Weight of the worst consecutive-frame similarity (default 0.0).

    Returns:
        ``(mean_score, video_results)``. ``mean_score`` is the mean of the per-video
        scores (NaN if *video_pair_list* is empty). ``video_results`` is a list of
        dicts with keys ``'image_path'``, ``'video_path'`` and ``'video_results'``
        (the per-video score).

    Raises:
        ValueError: If a video yields fewer than two frames. The consecutive-frame
            statistics are undefined in that case.
    """
    image_transform = dino_transform_Image(224)
    frames_transform = dino_transform(224)
    video_results = []
    sim_list = []
    for image_path, video_path in tqdm(video_pair_list):
        # The context manager releases the file handle right away. convert('RGB')
        # ensures palette / RGBA / greyscale inputs reach the transform as 3 channels.
        with Image.open(image_path) as img:
            input_image = image_transform(img.convert('RGB'))
        input_image_features = _embed(model, input_image, device)

        frames = frames_transform(load_video(video_path))
        if len(frames) < 2:
            # np.max / np.min on the empty score lists below would raise an opaque
            # numpy error. Fail with context instead.
            raise ValueError(
                f"i2v_subject needs at least 2 frames per video, "
                f"got {len(frames)} from {video_path}"
            )

        conformity_scores = []  # frame vs. conditioning image
        consec_scores = []      # frame vs. previous frame
        former_image_features = None
        for frame in frames:
            image_features = _embed(model, frame, device)
            if former_image_features is not None:
                sim_consec = F.cosine_similarity(former_image_features, image_features).item()
                consec_scores.append(max(0.0, sim_consec))
            sim_to_input = F.cosine_similarity(input_image_features, image_features).item()
            conformity_scores.append(max(0.0, sim_to_input))
            former_image_features = image_features

        video_score = (max_weight * np.max(conformity_scores)
                       + mean_weight * np.mean(consec_scores)
                       + min_weight * np.min(consec_scores))
        sim_list.append(video_score)
        video_results.append({'image_path': image_path, 'video_path': video_path, 'video_results': video_score})
    # np.mean([]) is also NaN, but it emits a RuntimeWarning; return NaN explicitly.
    mean_score = np.mean(sim_list) if sim_list else float('nan')
    return mean_score, video_results
def compute_i2v_subject(json_dir, device, submodules_list):
    """Load the DINO backbone through ``torch.hub`` and run the ``i2v_subject`` evaluation.

    Args:
        json_dir: Forwarded to ``load_i2v_dimension_info`` to locate the evaluation pairs.
        device: Torch device the model is moved to and inference runs on.
        submodules_list: Mapping unpacked as keyword arguments into ``torch.hub.load``.
            It must also contain a ``'resolution'`` key, which is forwarded to
            ``load_i2v_dimension_info``.

    Returns:
        The ``(mean_score, video_results)`` tuple produced by ``i2v_subject``.
    """
    # NOTE(review): 'resolution' is also unpacked into torch.hub.load, so the hub
    # entrypoint must accept or ignore that extra keyword.
    # eval(): torch.hub.load does not switch the model to inference mode. Without it,
    # any dropout / batch-norm layers would behave as in training during scoring.
    dino_model = torch.hub.load(**submodules_list).to(device).eval()
    resolution = submodules_list['resolution']
    logger.info("Initialize DINO success")
    video_pair_list, _ = load_i2v_dimension_info(json_dir, dimension='i2v_subject', lang='en', resolution=resolution)
    all_results, video_results = i2v_subject(dino_model, video_pair_list, device)
    return all_results, video_results