|
import os |
|
import json |
|
|
|
import torch |
|
import numpy as np |
|
from tqdm import tqdm |
|
from vbench.utils import load_video, load_dimension_info, read_frames_decord_by_fps |
|
from vbench.third_party.grit_model import DenseCaptioning |
|
|
|
import logging |
|
logging.basicConfig(level = logging.INFO,format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
|
logger = logging.getLogger(__name__) |
|
|
|
def get_dect_from_grit(model, image_arrays): |
|
pred = [] |
|
if type(image_arrays) is not list and type(image_arrays) is not np.ndarray: |
|
image_arrays = image_arrays.numpy() |
|
with torch.no_grad(): |
|
for frame in image_arrays: |
|
ret = model.run_caption_tensor(frame) |
|
cur_pred = [] |
|
if len(ret[0])<1: |
|
cur_pred.append(['','']) |
|
else: |
|
for idx, cap_det in enumerate(ret[0]): |
|
cur_pred.append([cap_det[0], cap_det[2][0]]) |
|
pred.append(cur_pred) |
|
return pred |
|
|
|
def check_generate(color_key, object_key, predictions): |
|
cur_object_color, cur_object = 0, 0 |
|
for frame_pred in predictions: |
|
object_flag, color_flag = False, False |
|
for pred in frame_pred: |
|
if object_key == pred[1]: |
|
for color_query in ["white","red","pink","blue","silver","purple","orange","green","gray","yellow","black","grey"]: |
|
if color_query in pred[0]: |
|
object_flag =True |
|
if color_key in pred[0]: |
|
color_flag = True |
|
if color_flag: |
|
cur_object_color+=1 |
|
if object_flag: |
|
cur_object +=1 |
|
return cur_object, cur_object_color |
|
|
|
def color(model, video_dict, device): |
|
success_frame_count_all, video_count = 0, 0 |
|
video_results = [] |
|
for info in tqdm(video_dict): |
|
if 'auxiliary_info' not in info: |
|
raise "Auxiliary info is not in json, please check your json." |
|
|
|
color_info = info['auxiliary_info']['color'] |
|
object_info = info['prompt'] |
|
object_info = object_info.replace('a ','').replace('an ','').replace(color_info,'').strip() |
|
for video_path in info['video_list']: |
|
video_arrays = load_video(video_path, num_frames=16, return_tensor=False) |
|
cur_video_pred = get_dect_from_grit(model ,video_arrays) |
|
cur_object, cur_object_color = check_generate(color_info, object_info, cur_video_pred) |
|
if cur_object>0: |
|
cur_success_frame_rate = cur_object_color/cur_object |
|
success_frame_count_all += cur_success_frame_rate |
|
video_count += 1 |
|
video_results.append({'video_path': video_path, 'video_results': cur_success_frame_rate}) |
|
success_rate = success_frame_count_all / video_count |
|
return success_rate, video_results |
|
|
|
|
|
def compute_color(json_dir, device, submodules_dict): |
|
dense_caption_model = DenseCaptioning(device) |
|
dense_caption_model.initialize_model(**submodules_dict) |
|
logger.info("Initialize detection model success") |
|
_, prompt_dict_ls = load_dimension_info(json_dir, dimension='color', lang='en') |
|
all_results, video_results = color(dense_caption_model, prompt_dict_ls, device) |
|
return all_results, video_results |
|
|