ychenhq's picture
Upload folder using huggingface_hub
04fbff5 verified
raw
history blame
2.42 kB
import os
import json
import torch
import numpy as np
from tqdm import tqdm
from vbench.utils import load_video, load_dimension_info, tag2text_transform
from vbench.third_party.tag2Text.tag2text import tag2text_caption
import logging
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def get_caption(model, image_arrays):
caption, tag_predict = model.generate(image_arrays, tag_input = None, return_tag_predict = True)
return caption
def check_generate(key_info, predictions):
cur_cnt = 0
key = key_info['scene']
for pred in predictions:
q_flag = [q in pred for q in key.split(' ')]
if len(q_flag) == sum(q_flag):
cur_cnt +=1
return cur_cnt
def scene(model, video_dict, device):
success_frame_count, frame_count = 0,0
video_results = []
transform = tag2text_transform(384)
for info in tqdm(video_dict):
if 'auxiliary_info' not in info:
raise "Auxiliary info is not in json, please check your json."
scene_info = info['auxiliary_info']['scene']
for video_path in info['video_list']:
video_array = load_video(video_path, num_frames=16, return_tensor=False, width=384, height=384)
video_tensor_list = []
for i in video_array:
video_tensor_list.append(transform(i).to(device).unsqueeze(0))
video_tensor = torch.cat(video_tensor_list)
cur_video_pred = get_caption(model, video_tensor)
cur_success_frame_count = check_generate(scene_info, cur_video_pred)
cur_success_frame_rate = cur_success_frame_count/len(cur_video_pred)
success_frame_count += cur_success_frame_count
frame_count += len(cur_video_pred)
video_results.append({'video_path': video_path, 'video_results': cur_success_frame_rate})
success_rate = success_frame_count / frame_count
return success_rate, video_results
def compute_scene(json_dir, device, submodules_dict):
model = tag2text_caption(**submodules_dict)
model.eval()
model = model.to(device)
logger.info("Initialize caption model success")
_, prompt_dict_ls = load_dimension_info(json_dir, dimension='scene', lang='en')
all_results, video_results = scene(model, prompt_dict_ls, device)
return all_results, video_results