import os
import json

import torch
import numpy as np
from tqdm import tqdm
from vbench.utils import load_video, load_dimension_info
from vbench.third_party.grit_model import DenseCaptioning

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def get_position_score(locality, obj1, obj2, iou_threshold=0.1):
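    """Score how well two detected boxes satisfy a spatial relationship.

    obj1 and obj2 are [x0, y0, x1, y1] boxes. The score is 1 when the
    displacement between the box centers is mainly along the axis implied by
    `locality` (horizontal for left/right, vertical for top/bottom) and the
    boxes barely overlap (IoU < iou_threshold); it decays to
    iou_threshold / iou when the overlap is larger, and is 0 otherwise.
    """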
    # obj1 and obj2 are [x0, y0, x1, y1] bounding boxes
    # Unpack both boxes into named coordinate fields
    box1 = {
        'x_min': obj1[0],
        'y_min': obj1[1],
        'x_max': obj1[2],
        'y_max': obj1[3],
        'width': obj1[2] - obj1[0],
        'height': obj1[3] - obj1[1]
    }

    box2 = {
        'x_min': obj2[0],
        'y_min': obj2[1],
        'x_max': obj2[2],
        'y_max': obj2[3],
        'width': obj2[2] - obj2[0],
        'height': obj2[3] - obj2[1]
    }
    
    # Get the object center
    box1_center = ((box1['x_min'] + box1['x_max']) / 2, (box1['y_min'] + box1['y_max']) / 2)
    box2_center = ((box2['x_min'] + box2['x_max']) / 2, (box2['y_min'] + box2['y_max']) / 2)

    # Calculate horizontal and vertical distances
    x_distance = box2_center[0] - box1_center[0]
    y_distance = box2_center[1] - box1_center[1]

    # Calculate IoU
    x_overlap = max(0, min(box1['x_max'], box2['x_max']) - max(box1['x_min'], box2['x_min']))
    y_overlap = max(0, min(box1['y_max'], box2['y_max']) - max(box1['y_min'], box2['y_min']))
    intersection = x_overlap * y_overlap
    box1_area = (box1['x_max'] - box1['x_min']) * (box1['y_max'] - box1['y_min'])
    box2_area = (box2['x_max'] - box2['x_min']) * (box2['y_max'] - box2['y_min'])
    union = box1_area + box2_area - intersection
    # Guard against degenerate (zero-area) boxes when computing IoU
    iou = intersection / union if union > 0 else 0.0

    # Maximum object width and height (currently unused by the scoring below)
    max_width = max(box1['width'], box2['width'])
    max_height = max(box1['height'], box2['height'])

    score = 0
    # Note: substring membership is used here, so shorter phrases such as
    # 'left of' also match 'on the left of'.
    if locality in 'on the right of' or locality in 'on the left of':
        # Horizontal relation: centers should be separated mainly along x
        if abs(x_distance) > abs(y_distance) and iou < iou_threshold:
            score = 1
        elif abs(x_distance) > abs(y_distance) and iou >= iou_threshold:
            # Penalize pairs that overlap more than the threshold
            score = iou_threshold / iou
        else:
            score = 0
    elif locality in 'on the bottom of' or locality in 'on the top of':
        # Vertical relation: centers should be separated mainly along y
        if abs(y_distance) > abs(x_distance) and iou < iou_threshold:
            score = 1
        elif abs(y_distance) > abs(x_distance) and iou >= iou_threshold:
            score = iou_threshold / iou
        else:
            score = 0
    return score

def get_dect_from_grit(model, image_arrays):
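    """Run the GRiT dense-captioning detector on every frame.

    Returns, for each frame, a list of [label, bounding_box] pairs for the
    detected objects (empty if nothing was detected in that frame).
    """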
    pred = []
    if not isinstance(image_arrays, list):
        image_arrays = image_arrays.numpy()
    with torch.no_grad():
        for frame in image_arrays:
            ret = model.run_caption_tensor(frame)
            pred_cur = []
            if len(ret[0]) > 0:
                for info in ret[0]:
                    # info[0] is the detected object label, info[1] its bounding box
                    pred_cur.append([info[0], info[1]])
            pred.append(pred_cur)
    return pred

def check_generate(key_info, predictions):
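    """Score each frame's detections against the expected spatial relationship.

    key_info provides 'object_a', 'object_b' and 'relationship'; predictions is
    the per-frame output of get_dect_from_grit. For every frame, all pairs of
    boxes whose label matches either target object are scored with
    get_position_score, and the best pair score is kept (0 if no pair exists).
    """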
    key_a = key_info['object_a']
    key_b = key_info['object_b']
    relation = key_info['relationship']
    frame_score = []
    for frame_pred in predictions:
        # Collect boxes of detections whose label matches either target object
        frame_obj_locats = []
        cur_score = [0]
        for item in frame_pred:
            if (key_a == item[0]) or (key_b == item[0]):
                frame_obj_locats.append(item[1])
        # Score every pair of candidate boxes against the target relationship
        for c_obj1 in range(len(frame_obj_locats) - 1):
            for c_obj2 in range(c_obj1 + 1, len(frame_obj_locats)):
                score_obj1_obj2 = get_position_score(relation, frame_obj_locats[c_obj1], frame_obj_locats[c_obj2])
                cur_score.append(score_obj1_obj2)
        frame_score.append(max(cur_score))
    return frame_score

def spatial_relationship(model, video_dict, device):
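    """Evaluate the spatial_relationship dimension for all videos in video_dict.

    Every entry must provide auxiliary_info['spatial_relationship'] describing
    the two target objects and their expected relationship. Returns the mean
    per-frame score over all videos and the per-video results.
    """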
    video_results = []
    frame_score_overall = []
    for info in tqdm(video_dict):
        if 'auxiliary_info' not in info:
            raise "Auxiliary info is not in json, please check your json."
        object_info = info['auxiliary_info']['spatial_relationship']
        for video_path in info['video_list']:
            video_tensor = load_video(video_path, num_frames=16)
            cur_video_pred = get_dect_from_grit(model, video_tensor.permute(0, 2, 3, 1))
            cur_video_frame_score = check_generate(object_info, cur_video_pred)
            cur_success_frame_rate = np.mean(cur_video_frame_score)
            frame_score_overall.extend(cur_video_frame_score)
            video_results.append({'video_path': video_path, 'video_results': cur_success_frame_rate, 'frame_results': cur_video_frame_score})
    success_rate = np.mean(frame_score_overall)
    return success_rate, video_results
        

def compute_spatial_relationship(json_dir, device, submodules_dict):
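    """Entry point for the spatial_relationship dimension.

    Initializes the GRiT-based detection model from submodules_dict, loads the
    prompt and video information for this dimension from json_dir, and returns
    the overall score together with per-video results.
    """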
    dense_caption_model = DenseCaptioning(device)
    dense_caption_model.initialize_model_det(**submodules_dict)
    logger.info("Initialize detection model success")
    _, prompt_dict_ls = load_dimension_info(json_dir, dimension='spatial_relationship', lang='en')
    all_results, video_results = spatial_relationship(dense_caption_model, prompt_dict_ls, device)
    return all_results, video_results
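
# Example usage sketch (the json path and the contents of `submodules_dict`
# below are hypothetical; the exact detector configuration depends on your
# VBench setup):
#
#     import torch
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     submodules = {...}  # GRiT detector config / checkpoint paths
#     overall_score, per_video = compute_spatial_relationship(
#         "path/to/prompt_info.json", device, submodules)
#     logger.info("spatial_relationship score: %.4f", overall_score)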