# TEOChat/videollava/eval/geochat_bench.py
from eval_geochat_referring import get_single_image_results, convert_geochat_string
from collections import defaultdict
import numpy as np
import json
import ast
import re
import cv2
from shapely import wkt, Polygon, box
from infer_utils import create_mask
from matplotlib.path import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
import math
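# Side length (in pixels) assumed for the images of each GeoChat-Bench subset.
# The values below are taken as given by this script and presumed to describe
# square images; they are only used to rescale predicted coordinates.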
DIMENSIONS = {'FAST': 600,
              'SIOR': 800,
              'SOTA': 1024}


def calc_iou_individual_rotated(pred_box, gt_box, img_size=None):
    """Calculate IoU of a single predicted and ground truth box.

    Args:
        pred_box (list of floats): location of predicted object as
            [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
        gt_box (list of floats): location of ground truth object as
            [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]

    Returns:
        float: value of the IoU for the two boxes.
    """
    pred_box = np.array(pred_box).reshape(4, 2)
    gt_box = np.array(gt_box).reshape(4, 2)
    pred_polygon = Polygon(pred_box)
    gt_polygon = Polygon(gt_box)
    intersection = pred_polygon.intersection(gt_polygon).area
    union = pred_polygon.union(gt_polygon).area
    if union == 0:
        # Degenerate (zero-area) boxes cannot overlap meaningfully
        return 0.0
    iou = intersection / union
    return iou
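# Illustrative sanity check (not part of the original script): two axis-aligned
# unit squares offset by half a unit have intersection 0.5 and union 1.5, so
# the IoU should come out to 1/3.
#
#   calc_iou_individual_rotated([[0, 0], [1, 0], [1, 1], [0, 1]],
#                               [[0.5, 0], [1.5, 0], [1.5, 1], [0.5, 1]])
#   # -> 0.3333...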


def get_single_image_results_rotated(gt_boxes, pred_boxes, iou_thr, img_size=None):
    """Calculates number of true_pos, false_pos, false_neg from a single batch of boxes.

    Args:
        gt_boxes (list of list of floats): list of locations of ground truth
            objects as [[x1, y1], [x2, y2], ...]
        pred_boxes (list of list of floats): list of locations of predicted
            objects as [[x1, y1], [x2, y2], ...]
        iou_thr (float): value of IoU to consider as threshold for a
            true prediction.

    Returns:
        dict: true positives (int), false positives (int), false negatives (int)
    """
    all_pred_indices = range(len(pred_boxes))
    all_gt_indices = range(len(gt_boxes))
    if len(all_pred_indices) == 0:
        tp = 0
        fp = 0
        fn = len(gt_boxes)
        return {'true_pos': tp, 'false_pos': fp, 'false_neg': fn}
    if len(all_gt_indices) == 0:
        tp = 0
        fp = len(pred_boxes)
        fn = 0
        return {'true_pos': tp, 'false_pos': fp, 'false_neg': fn}

    # Collect every (ground truth, prediction) pair whose IoU clears the threshold
    gt_idx_thr = []
    pred_idx_thr = []
    ious = []
    for ipb, pred_box in enumerate(pred_boxes):
        for igb, gt_box in enumerate(gt_boxes):
            iou = calc_iou_individual_rotated(pred_box, gt_box, img_size)
            if iou > iou_thr:
                gt_idx_thr.append(igb)
                pred_idx_thr.append(ipb)
                ious.append(iou)

    # Greedily match pairs in descending IoU order, using each box at most once
    args_desc = np.argsort(ious)[::-1]
    if len(args_desc) == 0:
        # No matches
        tp = 0
        fp = len(pred_boxes)
        fn = len(gt_boxes)
    else:
        gt_match_idx = []
        pred_match_idx = []
        for idx in args_desc:
            gt_idx = gt_idx_thr[idx]
            pr_idx = pred_idx_thr[idx]
            # If the boxes are unmatched, add them to matches
            if (gt_idx not in gt_match_idx) and (pr_idx not in pred_match_idx):
                gt_match_idx.append(gt_idx)
                pred_match_idx.append(pr_idx)
        tp = len(gt_match_idx)
        fp = len(pred_boxes) - len(pred_match_idx)
        fn = len(gt_boxes) - len(gt_match_idx)
    return {'true_pos': tp, 'false_pos': fp, 'false_neg': fn}
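# Illustrative example (not part of the original script): with one ground truth
# box and two predictions that both clear the 0.5 threshold, only the
# highest-IoU prediction gets matched, giving tp=1, fp=1, fn=0.
#
#   gt = [[[0, 0], [10, 0], [10, 10], [0, 10]]]
#   preds = [[[0, 0], [10, 0], [10, 10], [0, 10]],
#            [[1, 1], [10, 1], [10, 10], [1, 10]]]
#   get_single_image_results_rotated(gt, preds, iou_thr=0.5)
#   # -> {'true_pos': 1, 'false_pos': 1, 'false_neg': 0}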


def accuracy0_5(answer_path, dataset, aux_dataset="scripts/geochat_bench_dict.json"):
    # answer_path is either a path to the answers file (JSON lines) or an
    # already-loaded dict of results
    results = None
    if dataset != "geochat_xbd":
        if isinstance(answer_path, dict):
            results = answer_path
        else:
            results = []
            with open(answer_path) as json_data:
                for line in json_data:
                    results.append(json.loads(line))
    with open(aux_dataset) as json_data:
        aux_results = json.load(json_data)

    img_results = {}
    num_bboxes = 0
    if dataset != "geochat_xbd":
        print("Number of images in Geochat: ", len(aux_results))
        print("Number of images predicted: ", len(results))
    i = 0
    # Loop over results and get precision, recall overall
    for id, result in tqdm(aux_results.items()):
        if dataset == "geochat_xbd":
            pred = result['answer']
            img_size = DIMENSIONS[result['dataset']]
            pred = convert_geochat_string(pred, img_size)
            ground_truth = np.array(result['ground_truth'])
            num_bboxes += len(ground_truth)
            img_results[id] = get_single_image_results_rotated(ground_truth, pred, iou_thr=0.5)
        else:
            geochat_id = id.split(".")[0]
            img_size = DIMENSIONS[aux_results[geochat_id]['dataset']]
            ground_truth = np.array(result['ground_truth'])
            num_bboxes += len(ground_truth)
            parsed_predicted = results[i]['predicted']
            # Load the lists of predicted and ground truth bounding boxes for a single image
            try:
                predicted_boxes = ast.literal_eval("[" + parsed_predicted + "]")
            except Exception:
                match = re.search(r'\[\[.*\]\]', parsed_predicted)
                if match:
                    predicted_boxes = ast.literal_eval(match.group())
                else:
                    predicted_boxes = []
            # Coordinates below 1 are treated as normalized to [0, 1] and rescaled to [0, 100]
            predicted_boxes = [[coord * 100 if coord < 1 else coord for coord in box] for box in predicted_boxes]
            # Scale from the [0, 100] range to pixel coordinates
            predicted_boxes = [[coord * img_size / 100 for coord in box] for box in predicted_boxes]
            assert results[i]['ground_truth'] == result['ground_truth']
            # Convert the predicted bboxes from [xmin, ymin, xmax, ymax] to
            # [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
            pred_bboxes = []
            for bbox in predicted_boxes:
                x1, y1, x2, y2 = bbox
                pred_bboxes.append([[x1, y1], [x2, y1], [x2, y2], [x1, y2]])
            img_results[id] = get_single_image_results_rotated(ground_truth, pred_bboxes, iou_thr=0.5, img_size=img_size)
        i += 1
    acc = np.sum([res['true_pos'] for res in img_results.values()]) / num_bboxes
    print("accuracy@0.5: ", acc)
    return acc
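# Expected answers-file format, inferred from the parsing above (not a spec):
# one JSON object per line, each with a 'predicted' string of bracketed
# [xmin, ymin, xmax, ymax] boxes and a 'ground_truth' list of rotated boxes as
# corner points, e.g. (hypothetical values for illustration only):
#
#   {"predicted": "[12, 34, 56, 78], [20, 30, 40, 50]",
#    "ground_truth": [[[12, 34], [56, 34], [56, 78], [12, 78]]]}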


if __name__ == '__main__':
    print("Geochat bench")
    geochat_path = "scripts/geochat_bench_dict.json"
    answer_path = "scripts/geochat_bench_dict.json"
    acc_geochat = accuracy0_5(answer_path, dataset="geochat_xbd")
    print()

    print("Teochat bench")
    answer_path = "/deep/u/idormoy/aicc-win24-geo-vlm/videollava/scripts/geovlm/eval/QFabric/answers/geochat-referring-checkpoint14000_prompt_strategy_interleave_chronological_prefix_True_load_8bit_True_load_4bit_False_delete_system_prompt_False_tmp_0_end.json"
    acc_teochat = accuracy0_5(answer_path, dataset="geochat")
    print()

    print("Teochat-T bench")
    answer_path = "/deep/u/idormoy/aicc-win24-geo-vlm/videollava/videollava/eval/video/geochat-bench-ckpt8000-FIXED_prompt_strategy_interleave_chronological_prefix_True_load_8bit_False_load_4bit_True_delete_system_prompt_False_tmp_0_end (1).json"
    acc_teochatT = accuracy0_5(answer_path, dataset="geochat")
    print()

    print("VideoLLaVA bench")
    answer_path = "/deep/u/idormoy/aicc-win24-geo-vlm/videollava/videollava/eval/video/geochat-referring-Video-LLaVA-7B_prompt_strategy_interleave_chronological_prefix_True_load_8bit_False_load_4bit_True_delete_system_prompt_False_tmp_0_end (1).json"
    acc_videollava = accuracy0_5(answer_path, dataset="geochat")
    print()

    print("Overall accuracies")
    print("Geochat: ", acc_geochat)
    print("Teochat: ", acc_teochat)
    print("Teochat-T: ", acc_teochatT)
    print("VideoLLaVA: ", acc_videollava)