# vbench/spatial_relationship.py
import os
import json
import torch
import numpy as np
from tqdm import tqdm
from vbench.utils import load_video, load_dimension_info
from vbench.third_party.grit_model import DenseCaptioning
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def get_position_score(locality, obj1, obj2, iou_threshold=0.1):
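    """Score how well two boxes satisfy a spatial relation.

    Returns 1.0 when the dominant displacement axis matches the relation
    ('left'/'right' -> horizontal, 'top'/'bottom' -> vertical) and the boxes
    barely overlap (IoU < iou_threshold); decays as iou_threshold / IoU when
    they overlap more; returns 0.0 otherwise. Note that only the axis is
    checked, not the direction, so 'on the left of' and 'on the right of'
    score identically.
    """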
# input obj1 and obj2 should be [x0,y0,x1,y1]
# Calculate centers of bounding boxes
box1 = {
'x_min': obj1[0],
'y_min': obj1[1],
'x_max': obj1[2],
'y_max': obj1[3],
'width': obj1[2] - obj1[0],
'height': obj1[3] - obj1[1]
}
box2 = {
'x_min': obj2[0],
'y_min': obj2[1],
'x_max': obj2[2],
'y_max': obj2[3],
'width': obj2[2] - obj2[0],
'height': obj2[3] - obj2[1]
}
# Get the object center
box1_center = ((box1['x_min'] + box1['x_max']) / 2, (box1['y_min'] + box1['y_max']) / 2)
box2_center = ((box2['x_min'] + box2['x_max']) / 2, (box2['y_min'] + box2['y_max']) / 2)
# Calculate horizontal and vertical distances
x_distance = box2_center[0] - box1_center[0]
y_distance = box2_center[1] - box1_center[1]
# Calculate IoU
x_overlap = max(0, min(box1['x_max'], box2['x_max']) - max(box1['x_min'], box2['x_min']))
y_overlap = max(0, min(box1['y_max'], box2['y_max']) - max(box1['y_min'], box2['y_min']))
intersection = x_overlap * y_overlap
box1_area = (box1['x_max'] - box1['x_min']) * (box1['y_max'] - box1['y_min'])
box2_area = (box2['x_max'] - box2['x_min']) * (box2['y_max'] - box2['y_min'])
union = box1_area + box2_area - intersection
    iou = intersection / union if union > 0 else 0.0  # guard against zero-area boxes
    score = 0
    # `locality` is the full relation phrase, e.g. 'on the left of';
    # tuple membership avoids the accidental substring matching of
    # `locality in 'on the right of'`.
    if locality in ('on the right of', 'on the left of'):
        if abs(x_distance) > abs(y_distance) and iou < iou_threshold:
            score = 1
        elif abs(x_distance) > abs(y_distance) and iou >= iou_threshold:
            score = iou_threshold / iou
    elif locality in ('on the bottom of', 'on the top of'):
        if abs(y_distance) > abs(x_distance) and iou < iou_threshold:
            score = 1
        elif abs(y_distance) > abs(x_distance) and iou >= iou_threshold:
            score = iou_threshold / iou
return score

def get_dect_from_grit(model, image_arrays):
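    """Run GRiT dense captioning on each frame.

    Returns a list with one entry per frame; each entry is a list of
    [object_name, bounding_box] pairs, where bounding_box is [x0, y0, x1, y1].
    """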
pred = []
    if not isinstance(image_arrays, list):
        image_arrays = image_arrays.numpy()  # torch tensor -> numpy frames
with torch.no_grad():
for frame in image_arrays:
ret = model.run_caption_tensor(frame)
pred_cur = []
            if len(ret[0]) > 0:
                for info in ret[0]:
                    pred_cur.append([info[0], info[1]])
pred.append(pred_cur)
return pred

def check_generate(key_info, predictions):
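    """Score each frame against the expected spatial relationship.

    All detections whose name matches object_a or object_b are pooled, every
    pair of matched boxes is scored with get_position_score, and the best pair
    score is kept per frame (0 when fewer than two matches are found).
    """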
key_a = key_info['object_a']
key_b = key_info['object_b']
relation = key_info['relationship']
    frame_score = []
for frame_pred in predictions:
# filter the target object
frame_obj_locats = []
cur_score = [0]
for item in frame_pred:
if (key_a == item[0]) or (key_b == item[0]):
frame_obj_locats.append(item[1])
        for c_obj1 in range(len(frame_obj_locats) - 1):
            for c_obj2 in range(c_obj1 + 1, len(frame_obj_locats)):
score_obj1_obj2 = get_position_score(relation, frame_obj_locats[c_obj1], frame_obj_locats[c_obj2])
cur_score.append(score_obj1_obj2)
frame_score.append(max(cur_score))
return frame_score

def spatial_relationship(model, video_dict, device):
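    """Evaluate spatial-relationship consistency for a list of prompts.

    For every video, 16 frames are sampled, objects are detected with GRiT,
    and per-frame scores are averaged into a per-video result; the overall
    score is the mean over all frames of all videos.
    """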
video_results = []
frame_score_overall = []
for info in tqdm(video_dict):
        if 'auxiliary_info' not in info:
            raise KeyError("Auxiliary info is missing from the JSON; please check your input file.")
object_info = info['auxiliary_info']['spatial_relationship']
for video_path in info['video_list']:
video_tensor = load_video(video_path, num_frames=16)
            cur_video_pred = get_dect_from_grit(model, video_tensor.permute(0, 2, 3, 1))  # (T, C, H, W) -> (T, H, W, C)
cur_video_frame_score = check_generate(object_info, cur_video_pred)
cur_success_frame_rate = np.mean(cur_video_frame_score)
frame_score_overall.extend(cur_video_frame_score)
            video_results.append({'video_path': video_path, 'video_results': cur_success_frame_rate, 'frame_results': cur_video_frame_score})
success_rate = np.mean(frame_score_overall)
return success_rate, video_results

def compute_spatial_relationship(json_dir, device, submodules_dict):
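    """Entry point for the VBench 'spatial_relationship' dimension.

    Builds the GRiT detector from submodules_dict, loads the prompts for the
    dimension, and returns (overall_score, per_video_results).
    """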
dense_caption_model = DenseCaptioning(device)
dense_caption_model.initialize_model_det(**submodules_dict)
logger.info("Initialize detection model success")
_, prompt_dict_ls = load_dimension_info(json_dir, dimension='spatial_relationship', lang='en')
all_results, video_results = spatial_relationship(dense_caption_model, prompt_dict_ls, device)
return all_results, video_results
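

if __name__ == "__main__":
    # Minimal usage sketch; this block is an illustrative addition, not part
    # of the VBench pipeline. The contents of `submodules_dict` are forwarded
    # verbatim to DenseCaptioning.initialize_model_det(), so the placeholder
    # below must be filled with that method's expected keyword arguments
    # (e.g. paths to the GRiT config and checkpoint);
    # "vbench_full_info.json" is a hypothetical path to the prompt-info JSON.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    submodules_dict = {}  # placeholder: GRiT detector kwargs go here
    overall_score, per_video = compute_spatial_relationship(
        "vbench_full_info.json", device, submodules_dict
    )
    logger.info(f"spatial_relationship overall score: {overall_score}")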