|
|
|
|
|
from eval_geochat_referring import convert_geochat_string

import ast
import json
import re

import numpy as np
from shapely import Polygon
from tqdm import tqdm
|
|
|
|
|
|
|
DIMENSIONS = {'FAST': 600, |
|
'SIOR': 800, |
|
'SOTA': 1024} |
|
|
|
def calc_iou_individual_rotated(pred_box, gt_box, img_size=None):
    """Calculate the IoU of a single predicted and ground truth rotated box.

    Args:
        pred_box (list of floats): location of the predicted object as
            [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
        gt_box (list of floats): location of the ground truth object as
            [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
        img_size (int, optional): unused; kept for signature compatibility.

    Returns:
        float: value of the IoU for the two boxes.
    """
    pred_box = np.array(pred_box).reshape(4, 2)
    gt_box = np.array(gt_box).reshape(4, 2)

    pred_polygon = Polygon(pred_box)
    gt_polygon = Polygon(gt_box)

    intersection = pred_polygon.intersection(gt_polygon).area
    union = pred_polygon.union(gt_polygon).area
    if union == 0:
        # Both polygons are degenerate (zero area); treat as no overlap.
        return 0.0

    return intersection / union
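
# Quick sanity check (illustrative only): two axis-aligned unit squares offset by
# 0.5 along x overlap in an area of 0.5 and cover 1.5 in total, so
# calc_iou_individual_rotated([[0, 0], [1, 0], [1, 1], [0, 1]],
#                             [[0.5, 0], [1.5, 0], [1.5, 1], [0.5, 1]])
# returns 1/3.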
|
|
|
|
|
def get_single_image_results_rotated(gt_boxes, pred_boxes, iou_thr, img_size=None):
    """Calculate the number of true positives, false positives, and false negatives
    for a single image.

    Args:
        gt_boxes (list): ground truth boxes, each given as corner points
            [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
        pred_boxes (list): predicted boxes in the same corner-point format
        iou_thr (float): IoU threshold above which a prediction can be matched
            to a ground truth box.
        img_size (int, optional): forwarded to calc_iou_individual_rotated (unused).

    Returns:
        dict: {'true_pos': int, 'false_pos': int, 'false_neg': int}
    """
|
|
|
all_pred_indices = range(len(pred_boxes)) |
|
all_gt_indices = range(len(gt_boxes)) |
|
if len(all_pred_indices) == 0: |
|
tp = 0 |
|
fp = 0 |
|
fn = len(gt_boxes) |
|
return {'true_pos': tp, 'false_pos': fp, 'false_neg': fn} |
|
if len(all_gt_indices) == 0: |
|
tp = 0 |
|
fp = len(pred_boxes) |
|
fn = 0 |
|
return {'true_pos': tp, 'false_pos': fp, 'false_neg': fn} |
|
|
|
gt_idx_thr = [] |
|
pred_idx_thr = [] |
|
ious = [] |
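    # Collect every (prediction, ground truth) pair whose IoU clears the threshold.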
|
for ipb, pred_box in enumerate(pred_boxes): |
|
for igb, gt_box in enumerate(gt_boxes): |
|
iou = calc_iou_individual_rotated(pred_box, gt_box, img_size) |
|
if iou > iou_thr: |
|
gt_idx_thr.append(igb) |
|
pred_idx_thr.append(ipb) |
|
ious.append(iou) |
|
|
|
args_desc = np.argsort(ious)[::-1] |
|
if len(args_desc) == 0: |
|
|
|
tp = 0 |
|
fp = len(pred_boxes) |
|
fn = len(gt_boxes) |
|
else: |
|
gt_match_idx = [] |
|
pred_match_idx = [] |
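        # Greedily assign matches in order of decreasing IoU so that each ground
        # truth box and each prediction is matched at most once.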
|
for idx in args_desc: |
|
gt_idx = gt_idx_thr[idx] |
|
pr_idx = pred_idx_thr[idx] |
|
|
|
if (gt_idx not in gt_match_idx) and (pr_idx not in pred_match_idx): |
|
gt_match_idx.append(gt_idx) |
|
pred_match_idx.append(pr_idx) |
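        # Matched pairs are true positives; unmatched predictions are false
        # positives and unmatched ground truth boxes are false negatives.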
|
tp = len(gt_match_idx) |
|
fp = len(pred_boxes) - len(pred_match_idx) |
|
fn = len(gt_boxes) - len(gt_match_idx) |
|
|
|
return {'true_pos': tp, 'false_pos': fp, 'false_neg': fn} |
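
# Illustrative example: with a single matching pair above the threshold,
# get_single_image_results_rotated([gt_box], [gt_box], iou_thr=0.5)
# returns {'true_pos': 1, 'false_pos': 0, 'false_neg': 0}.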
|
|
|
|
|
def accuracy0_5(answer_path, dataset, aux_dataset="scripts/geochat_bench_dict.json"): |
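    """Compute [email protected] for a referring expression benchmark.

    Args:
        answer_path: path to a JSONL answer file (one JSON object per line with
            'predicted' and 'ground_truth'), or an already-loaded collection of
            results. Ignored when dataset == "geochat_xbd".
        dataset: "geochat_xbd" scores the 'answer' fields stored in aux_dataset;
            any other value scores answer_path against aux_dataset.
        aux_dataset: path to the benchmark dict holding ground truths, source
            dataset names and, for GeoChat, the model answers.

    Returns:
        float: fraction of ground truth boxes matched at IoU >= 0.5.
    """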
|
|
|
results = None |
|
if dataset != "geochat_xbd": |
|
|
|
        if isinstance(answer_path, dict):
|
results = answer_path |
|
else: |
|
results = [] |
|
with open(answer_path) as json_data: |
|
for line in json_data: |
|
results.append(json.loads(line)) |
|
|
|
with open(aux_dataset) as json_data: |
|
aux_results = json.load(json_data) |
|
|
|
img_results = {} |
|
num_bboxes = 0 |
|
|
|
if dataset != "geochat_xbd": |
|
print("Number of images in Geochat: ", len(aux_results)) |
|
print("Number of images predicted: ", len(results)) |
|
|
|
i = 0 |
|
|
|
for id, result in tqdm(aux_results.items()): |
|
|
|
if dataset == "geochat_xbd": |
|
pred = result['answer'] |
|
|
|
img_size = DIMENSIONS[result['dataset']] |
|
pred = convert_geochat_string(pred, img_size) |
|
|
|
ground_truth = result['ground_truth'] |
|
ground_truth = np.array(ground_truth) |
|
num_bboxes += len(ground_truth) |
|
|
|
img_results[id] = get_single_image_results_rotated(ground_truth, pred, iou_thr=0.5) |
|
|
|
else: |
|
|
|
geochat_id = id.split(".")[0] |
|
|
|
img_size = DIMENSIONS[aux_results[geochat_id]['dataset']] |
|
ground_truth = result['ground_truth'] |
|
ground_truth = np.array(ground_truth) |
|
num_bboxes += len(ground_truth) |
|
|
|
parsed_predicted = results[i]['predicted'] |
|
|
|
try: |
|
predicted_boxes = ast.literal_eval("[" + parsed_predicted + "]") |
|
            except Exception:
|
match = re.search(r'\[\[.*\]\]', parsed_predicted) |
|
if match: |
|
predicted_boxes = ast.literal_eval(match.group()) |
|
else: |
|
predicted_boxes = [] |
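            # Coordinates below 1 are assumed to be normalized to [0, 1] and are
            # rescaled to the 0-100 range used by GeoChat-style boxes.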
|
|
|
predicted_boxes = [[coord * 100 if coord < 1 else coord for coord in box] for box in predicted_boxes] |
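            # Map 0-100 coordinates to pixel coordinates for this image size.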
|
|
|
|
|
predicted_boxes = [[coord * img_size / 100 for coord in box] for box in predicted_boxes] |
|
|
|
assert results[i]['ground_truth'] == result['ground_truth'] |
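            # Predictions are axis-aligned [x1, y1, x2, y2]; expand them to the
            # four-corner format expected by the rotated IoU matcher.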
|
|
|
|
|
pred_bboxes = [] |
|
for bbox in predicted_boxes: |
|
x1, y1, x2, y2 = bbox |
|
pred_bboxes.append([[x1, y1], [x2, y1], [x2, y2], [x1, y2]]) |
|
|
|
img_results[id] = get_single_image_results_rotated(ground_truth, pred_bboxes, iou_thr=0.5, img_size=img_size) |
|
|
|
i+=1 |
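    # Micro-averaged [email protected]: matched ground truth boxes over all ground truth boxes.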
|
|
|
|
|
acc = np.sum([res['true_pos'] for res in img_results.values()]) / num_bboxes |
|
print("[email protected]: ", acc) |
|
return acc |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
print("Geochat bench") |
|
    answer_path = "scripts/geochat_bench_dict.json"
|
acc_geochat = accuracy0_5(answer_path, dataset="geochat_xbd") |
|
print() |
|
|
|
|
|
print("Teochat bench") |
|
answer_path = "/deep/u/idormoy/aicc-win24-geo-vlm/videollava/scripts/geovlm/eval/QFabric/answers/geochat-referring-checkpoint14000_prompt_strategy_interleave_chronological_prefix_True_load_8bit_True_load_4bit_False_delete_system_prompt_False_tmp_0_end.json" |
|
acc_teochat = accuracy0_5(answer_path, dataset="geochat") |
|
print() |
|
|
|
|
|
print("Teochat-T bench") |
|
answer_path = "/deep/u/idormoy/aicc-win24-geo-vlm/videollava/videollava/eval/video/geochat-bench-ckpt8000-FIXED_prompt_strategy_interleave_chronological_prefix_True_load_8bit_False_load_4bit_True_delete_system_prompt_False_tmp_0_end (1).json" |
|
acc_teochatT = accuracy0_5(answer_path, dataset="geochat") |
|
print() |
|
|
|
|
|
|
|
print("VideoLLaVA bench") |
|
answer_path = "/deep/u/idormoy/aicc-win24-geo-vlm/videollava/videollava/eval/video/geochat-referring-Video-LLaVA-7B_prompt_strategy_interleave_chronological_prefix_True_load_8bit_False_load_4bit_True_delete_system_prompt_False_tmp_0_end (1).json" |
|
acc_videollava = accuracy0_5(answer_path, dataset="geochat") |
|
print() |
|
|
|
|
|
|
|
print("Overall accuracies") |
|
print("Geochat: ", acc_geochat) |
|
print("Teochat: ", acc_teochat) |
|
print("Teochat-T: ", acc_teochatT) |
|
print("VideoLLaVA: ", acc_videollava) |
|
|
|
|