File size: 2,870 Bytes
04fbff5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import io
import os
import cv2
import json
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from vbench2_beta_i2v.utils import load_video, load_i2v_dimension_info, dino_transform, dino_transform_Image
import logging
# Import-time logging setup: configure root logging at INFO level with a
# timestamped format, and create this module's named logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def _embed(model, image, device):
    """Return the L2-normalised embedding of one preprocessed image tensor.

    A batch dimension of size 1 is prepended before the forward pass. The pass
    runs under ``torch.no_grad()`` because the features are only used for
    similarity scoring, never for back-propagation.
    """
    with torch.no_grad():
        features = model(image.unsqueeze(0).to(device))
        return F.normalize(features, dim=-1, p=2)


def i2v_subject(model, video_pair_list, device):
    """Score how well generated videos preserve the subject of their input image.

    For every ``(image_path, video_path)`` pair, the input image and each video
    frame are embedded with *model* and L2-normalised. Two per-frame cosine
    similarity series are collected, each clamped below at 0:

    * conformity: similarity of every frame to the input image;
    * consecutive: similarity of every frame to the previous frame.

    The per-video score is
    ``0.5 * max(conformity) + 0.5 * mean(consecutive) + 0.0 * min(consecutive)``.

    Args:
        model: Feature extractor called as ``model(batch)``. Its output is
            normalised along the last dimension.
        video_pair_list: Iterable of ``(image_path, video_path)`` pairs.
        device: Device that inputs are moved to before each forward pass.

    Returns:
        A tuple ``(mean_score, video_results)``.
        ``mean_score`` is the mean of the per-video scores; an empty input
        gives ``np.mean([])``, i.e. ``nan``.
        ``video_results`` is a list of dicts with keys ``'image_path'``,
        ``'video_path'`` and ``'video_results'`` (the per-video score).

    Raises:
        ValueError: If a video yields fewer than two frames, because the
            consecutive-frame statistics are then undefined.
    """
    max_weight = 0.5
    mean_weight = 0.5
    min_weight = 0.0  # the min term is still computed but contributes nothing

    # 224 is passed to both preprocessing factories (presumably the DINO input size).
    image_transform = dino_transform_Image(224)
    frames_transform = dino_transform(224)

    video_results = []
    sim_list = []
    for image_path, video_path in tqdm(video_pair_list):
        # Context manager releases the file handle once the transform has
        # turned the image into a tensor (the original leaked one per video).
        with Image.open(image_path) as pil_image:
            input_image = image_transform(pil_image)
        input_features = _embed(model, input_image, device)

        frames = frames_transform(load_video(video_path))
        if len(frames) < 2:
            # Fail with a clear message instead of numpy's opaque error on
            # an empty reduction further below.
            raise ValueError(
                f"i2v_subject needs at least 2 frames per video, "
                f"got {len(frames)} from {video_path!r}"
            )

        conformity_scores = []
        consec_scores = []
        prev_features = None
        for frame in frames:
            frame_features = _embed(model, frame, device)
            if prev_features is not None:
                consec_scores.append(
                    max(0.0, F.cosine_similarity(prev_features, frame_features).item())
                )
            conformity_scores.append(
                max(0.0, F.cosine_similarity(input_features, frame_features).item())
            )
            prev_features = frame_features

        video_score = (
            max_weight * np.max(conformity_scores)
            + mean_weight * np.mean(consec_scores)
            + min_weight * np.min(consec_scores)
        )

        sim_list.append(video_score)
        video_results.append({'image_path': image_path, 'video_path': video_path, 'video_results': video_score})
    return np.mean(sim_list), video_results


def compute_i2v_subject(json_dir, device, submodules_list):
    """Load the DINO model via ``torch.hub`` and compute the ``i2v_subject`` score.

    Args:
        json_dir: Forwarded to ``load_i2v_dimension_info`` to obtain the
            image/video pairs.
        device: Device the model is moved to.
        submodules_list: Mapping unpacked as keyword arguments into
            ``torch.hub.load``. It must also contain a ``'resolution'`` key,
            which is forwarded to ``load_i2v_dimension_info``.

    Returns:
        ``(mean_score, video_results)`` as returned by ``i2v_subject``.

    Raises:
        KeyError: If ``submodules_list`` has no ``'resolution'`` key. This is
            checked before the model is loaded.
    """
    # Read the required key first so a misconfigured dict fails fast,
    # before the potentially slow model download/load.
    resolution = submodules_list['resolution']
    # NOTE(review): 'resolution' is also forwarded to torch.hub.load, which passes
    # unknown kwargs on to the hub entrypoint; this relies on the entrypoint
    # tolerating it. Confirm before changing the model.
    # eval() switches the model to inference mode so train-only behavior
    # (e.g. dropout) is disabled while scoring.
    dino_model = torch.hub.load(**submodules_list).to(device).eval()
    logger.info("Initialize DINO success")
    video_pair_list, _ = load_i2v_dimension_info(json_dir, dimension='i2v_subject', lang='en', resolution=resolution)
    return i2v_subject(dino_model, video_pair_list, device)