import os
import json

import torch
import numpy as np
from tqdm import tqdm
from vbench.utils import load_video, load_dimension_info
from vbench.third_party.grit_model import DenseCaptioning

import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def get_position_score(locality, obj1, obj2, iou_threshold=0.1):
    # Boxes arrive as [x_min, y_min, x_max, y_max].
    box1 = {
        'x_min': obj1[0],
        'y_min': obj1[1],
        'x_max': obj1[2],
        'y_max': obj1[3],
        'width': obj1[2] - obj1[0],
        'height': obj1[3] - obj1[1]
    }
    box2 = {
        'x_min': obj2[0],
        'y_min': obj2[1],
        'x_max': obj2[2],
        'y_max': obj2[3],
        'width': obj2[2] - obj2[0],
        'height': obj2[3] - obj2[1]
    }

    # Center-to-center displacement between the two boxes.
    box1_center = ((box1['x_min'] + box1['x_max']) / 2, (box1['y_min'] + box1['y_max']) / 2)
    box2_center = ((box2['x_min'] + box2['x_max']) / 2, (box2['y_min'] + box2['y_max']) / 2)
    x_distance = box2_center[0] - box1_center[0]
    y_distance = box2_center[1] - box1_center[1]

    # Intersection over union; guard against degenerate zero-area boxes.
    x_overlap = max(0, min(box1['x_max'], box2['x_max']) - max(box1['x_min'], box2['x_min']))
    y_overlap = max(0, min(box1['y_max'], box2['y_max']) - max(box1['y_min'], box2['y_min']))
    intersection = x_overlap * y_overlap
    box1_area = (box1['x_max'] - box1['x_min']) * (box1['y_max'] - box1['y_min'])
    box2_area = (box2['x_max'] - box2['x_min']) * (box2['y_max'] - box2['y_min'])
    union = box1_area + box2_area - intersection
    iou = intersection / union if union > 0 else 0

    # A relation scores 1 when the displacement along its axis dominates and the
    # boxes barely overlap; heavy overlap decays the score as iou_threshold / iou.
    score = 0
    if locality in ('on the right of', 'on the left of'):
        if abs(x_distance) > abs(y_distance) and iou < iou_threshold:
            score = 1
        elif abs(x_distance) > abs(y_distance) and iou >= iou_threshold:
            score = iou_threshold / iou
    elif locality in ('on the top of', 'on the bottom of'):
        if abs(y_distance) > abs(x_distance) and iou < iou_threshold:
            score = 1
        elif abs(y_distance) > abs(x_distance) and iou >= iou_threshold:
            score = iou_threshold / iou
    return score
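
# Worked example (a sketch, assuming pixel-coordinate boxes): for
#   get_position_score('on the left of', [0, 0, 1, 1], [2, 0, 3, 1])
# the centers are (0.5, 0.5) and (2.5, 0.5), so the horizontal offset dominates
# and the boxes do not overlap (iou = 0 < 0.1), giving a score of 1. Note that
# 'left' and 'right' share one branch: detections are matched by label only, so
# the check scores axis dominance symmetrically rather than direction.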

def get_dect_from_grit(model, image_arrays):
    pred = []
    if not isinstance(image_arrays, list):
        image_arrays = image_arrays.numpy()
    with torch.no_grad():
        for frame in image_arrays:
            ret = model.run_caption_tensor(frame)
            # Keep (label, box) for every detection in the frame.
            pred_cur = []
            if len(ret[0]) > 0:
                for info in ret[0]:
                    pred_cur.append([info[0], info[1]])
            pred.append(pred_cur)
    return pred
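
# Output sketch (inferred from the loop above): one list per frame, where each
# entry pairs a predicted label with its box, e.g.
#   [[['cat', [x_min, y_min, x_max, y_max]], ['dog', [...]]], ...]
# Frames with no detections contribute an empty list.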

def check_generate(key_info, predictions):
    key_a = key_info['object_a']
    key_b = key_info['object_b']
    relation = key_info['relationship']
    frame_score = []
    for frame_pred in predictions:
        # Gather the boxes of detections whose label matches either target object.
        frame_obj_locats = []
        cur_score = [0]  # a frame with fewer than two matching objects scores 0
        for item in frame_pred:
            if (key_a == item[0]) or (key_b == item[0]):
                frame_obj_locats.append(item[1])
        # Score every candidate pair and keep the best one for this frame.
        for c_obj1 in range(len(frame_obj_locats) - 1):
            for c_obj2 in range(c_obj1 + 1, len(frame_obj_locats)):
                score_obj1_obj2 = get_position_score(relation, frame_obj_locats[c_obj1], frame_obj_locats[c_obj2])
                cur_score.append(score_obj1_obj2)
        frame_score.append(max(cur_score))
    return frame_score
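
# Example (a sketch): with key_info = {'object_a': 'cat', 'object_b': 'dog',
# 'relationship': 'on the left of'}, a frame whose detections include one 'cat'
# and one 'dog' box gets the best pairwise position score; frames without two
# matching detections fall back to the 0 placed in cur_score.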

def spatial_relationship(model, video_dict, device):
    video_results = []
    frame_score_overall = []
    for info in tqdm(video_dict):
        if 'auxiliary_info' not in info:
            raise ValueError("Auxiliary info is not in json, please check your json.")
        object_info = info['auxiliary_info']['spatial_relationship']
        for video_path in info['video_list']:
            video_tensor = load_video(video_path, num_frames=16)
            # The detector consumes HWC frames, so permute (T, C, H, W) -> (T, H, W, C).
            cur_video_pred = get_dect_from_grit(model, video_tensor.permute(0, 2, 3, 1))
            cur_video_frame_score = check_generate(object_info, cur_video_pred)
            cur_success_frame_rate = np.mean(cur_video_frame_score)
            frame_score_overall.extend(cur_video_frame_score)
            video_results.append({
                'video_path': video_path,
                'video_results': cur_success_frame_rate,
                'frame_results': cur_video_frame_score
            })
    success_rate = np.mean(frame_score_overall)
    return success_rate, video_results
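
# Shape of one `video_dict` entry, inferred from the accesses above (the exact
# layout is defined by the VBench prompt json files):
# {
#     'video_list': ['/path/to/generated_video.mp4', ...],
#     'auxiliary_info': {
#         'spatial_relationship': {
#             'object_a': 'cat',
#             'object_b': 'dog',
#             'relationship': 'on the left of'
#         }
#     }
# }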
|
|
def compute_spatial_relationship(json_dir, device, submodules_dict): |
|
dense_caption_model = DenseCaptioning(device) |
|
dense_caption_model.initialize_model_det(**submodules_dict) |
|
logger.info("Initialize detection model success") |
|
_, prompt_dict_ls = load_dimension_info(json_dir, dimension='spatial_relationship', lang='en') |
|
all_results, video_results = spatial_relationship(dense_caption_model, prompt_dict_ls, device) |
|
return all_results, video_results |
|
|
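
# Minimal usage sketch (assumptions: the VBench prompt json lives at the
# placeholder path below, and `submodules` carries whatever init arguments
# DenseCaptioning.initialize_model_det expects for your GRiT checkpoint).
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    submodules = {}  # placeholder: detector weights / config paths go here
    overall_score, per_video = compute_spatial_relationship('vbench_full_info.json', device, submodules)
    logger.info('spatial relationship score: %s', overall_score)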