# ychenhq's picture
# Upload folder using huggingface_hub
# 04fbff5 verified
# raw
# history blame
# 3.26 kB
import os
import json
import torch
import numpy as np
from tqdm import tqdm
from vbench.utils import load_video, load_dimension_info, read_frames_decord_by_fps
from vbench.third_party.grit_model import DenseCaptioning
import logging
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def get_dect_from_grit(model, image_arrays):
    """Run the GRiT dense-captioning model over a sequence of frames.

    Args:
        model: object exposing ``run_caption_tensor(frame)``, expected to
            return a tuple whose first element is a list of detections,
            each shaped like ``(caption, score, [label, ...])`` — TODO
            confirm against DenseCaptioning's actual return format.
        image_arrays: list or ``np.ndarray`` of frames, or a tensor-like
            object with a ``.numpy()`` method (e.g. a torch tensor).

    Returns:
        list: one entry per frame; each entry is a list of
        ``[caption, label]`` pairs, or ``[['', '']]`` when the model
        detected nothing in that frame (placeholder keeps the per-frame
        alignment for downstream counting).
    """
    # Anything that is neither a list nor an ndarray is assumed to be a
    # tensor; convert so plain iteration yields per-frame arrays.
    # isinstance (vs. the original `type(...) is`) also accepts subclasses.
    if not isinstance(image_arrays, (list, np.ndarray)):
        image_arrays = image_arrays.numpy()
    pred = []
    with torch.no_grad():  # inference only; no gradients needed
        for frame in image_arrays:
            ret = model.run_caption_tensor(frame)
            detections = ret[0]
            if detections:
                cur_pred = [[cap_det[0], cap_det[2][0]] for cap_det in detections]
            else:
                cur_pred = [['', '']]
            pred.append(cur_pred)
    return pred
# Color vocabulary used to decide whether a caption mentions any color at all.
# Hoisted to module level so it is not rebuilt for every detection of every frame.
_COLOR_QUERIES = ("white", "red", "pink", "blue", "silver", "purple",
                  "orange", "green", "gray", "yellow", "black", "grey")


def check_generate(color_key, object_key, predictions):
    """Count frames containing the target object and frames with the right color.

    Args:
        color_key (str): expected color word, matched as a substring of the caption.
        object_key (str): expected object label, matched exactly against the
            detection label.
        predictions: per-frame lists of ``[caption, label]`` pairs, as produced
            by ``get_dect_from_grit``.

    Returns:
        tuple: ``(frames_with_colored_object, frames_with_correct_color)``.
        The first counts frames where the matched object's caption names *any*
        known color (filters captions with no color description at all); the
        second counts frames where the caption mentions ``color_key``.
    """
    cur_object_color, cur_object = 0, 0
    for frame_pred in predictions:
        object_flag, color_flag = False, False
        for caption, label in frame_pred:
            if label != object_key:
                continue
            if any(color in caption for color in _COLOR_QUERIES):
                object_flag = True
            if color_key in caption:
                color_flag = True
        if color_flag:
            cur_object_color += 1
        if object_flag:
            cur_object += 1
    return cur_object, cur_object_color
def color(model, video_dict, device):
    """Score each video for whether the prompted object appears in the prompted color.

    Args:
        model: dense-captioning model, forwarded to ``get_dect_from_grit``.
        video_dict: iterable of prompt records; each must carry
            ``auxiliary_info['color']``, ``prompt`` and ``video_list``.
        device: unused here; kept for interface compatibility with the
            other dimension evaluators.

    Returns:
        tuple: ``(success_rate, video_results)``. ``success_rate`` averages
        the per-video fraction of correctly colored frames over all videos
        in which the object was detected at all (0.0 when none qualify);
        ``video_results`` holds one ``{'video_path', 'video_results'}`` dict
        per qualifying video.

    Raises:
        ValueError: when a record lacks ``auxiliary_info``.
    """
    success_frame_count_all, video_count = 0, 0
    video_results = []
    for info in tqdm(video_dict):
        if 'auxiliary_info' not in info:
            # The original `raise "<str>"` is itself a TypeError in Python 3
            # (exceptions must derive from BaseException); raise a proper
            # exception carrying the same message.
            raise ValueError("Auxiliary info is not in json, please check your json.")
        color_info = info['auxiliary_info']['color']
        # Strip articles and the color word so only the object noun remains,
        # e.g. "a red car" -> "car".
        object_info = info['prompt']
        object_info = object_info.replace('a ', '').replace('an ', '').replace(color_info, '').strip()
        for video_path in info['video_list']:
            video_arrays = load_video(video_path, num_frames=16, return_tensor=False)
            cur_video_pred = get_dect_from_grit(model, video_arrays)
            cur_object, cur_object_color = check_generate(color_info, object_info, cur_video_pred)
            if cur_object > 0:
                cur_success_frame_rate = cur_object_color / cur_object
                success_frame_count_all += cur_success_frame_rate
                video_count += 1
                video_results.append({'video_path': video_path, 'video_results': cur_success_frame_rate})
    # Guard the original ZeroDivisionError when no video contained a
    # detectable (color-described) object.
    success_rate = success_frame_count_all / video_count if video_count else 0.0
    return success_rate, video_results
def compute_color(json_dir, device, submodules_dict):
    """Entry point for the VBench 'color' dimension.

    Builds the GRiT dense-captioning model, loads the prompt/video pairs for
    the color dimension, and delegates scoring to ``color``.

    Args:
        json_dir: path to the VBench info json.
        device: device on which to run the captioning model.
        submodules_dict: kwargs forwarded to ``DenseCaptioning.initialize_model``.

    Returns:
        tuple: ``(overall_success_rate, per_video_results)`` as produced
        by ``color``.
    """
    captioner = DenseCaptioning(device)
    captioner.initialize_model(**submodules_dict)
    logger.info("Initialize detection model success")
    _, prompt_dict_ls = load_dimension_info(json_dir, dimension='color', lang='en')
    return color(captioner, prompt_dict_ls, device)