File size: 4,873 Bytes
3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 3f1124e 88c0383 f1a5f04 3f1124e 88c0383 f1a5f04 3f1124e 88c0383 3f1124e 88c0383 3f1124e f1a5f04 88c0383 f1a5f04 88c0383 |
|
import numpy as np
def cxywh2xywh(cx, cy, w, h):
    """Convert center-based boxes (cx, cy, w, h) to corner-based (x, y, w, h).

    Args:
        cx, cy: box center coordinates (scalars or numpy arrays)
        w, h: box width and height
    Returns:
        tuple (x, y, w, h) where (x, y) is the center shifted back by half
        the width and half the height; w and h are passed through unchanged.
    """
    half_w, half_h = w / 2, h / 2
    return cx - half_w, cy - half_h, w, h
def cxywh2ltrb(cx, cy, w, h):
    """Convert center-based boxes (cx, cy, w, h) to edge coordinates.

    Args:
        cx, cy: box center coordinates (scalars or numpy arrays)
        w, h: box width and height
    Returns:
        tuple (left, top, right, bottom), in that order.
    """
    half_w = w / 2
    half_h = h / 2
    return cx - half_w, cy - half_h, cx + half_w, cy + half_h
def iou(ba, bb):
    """Compute Intersection-over-Union between two sets of boxes.

    Args:
        ba (tuple): (left, top, right, bottom, area) of the first box(es)
        bb (tuple): (left, top, right, bottom, area) of the second box(es)
    Returns:
        IoU values after numpy broadcasting of the two inputs, with all
        length-1 axes removed by ``squeeze()`` (so a single pair yields a
        0-d result).
    """
    left_a, top_a, right_a, bottom_a, area_a = ba
    left_b, top_b, right_b, bottom_b, area_b = bb
    # Extent of the overlap rectangle; clamped at 0 when the boxes are disjoint.
    overlap_w = np.maximum(0, np.minimum(right_a, right_b) - np.maximum(left_a, left_b))
    overlap_h = np.maximum(0, np.minimum(bottom_a, bottom_b) - np.maximum(top_a, top_b))
    overlap = overlap_w * overlap_h
    union = area_a + area_b - overlap
    return (overlap / union).squeeze()
def nms(cx, cy, w, h, s, iou_thresh=0.3):
    """Greedy bounding-box Non-Maximum Suppression.

    Boxes are visited from highest to lowest score. Each visited box is kept
    and every remaining box whose IoU with it exceeds ``iou_thresh`` is
    discarded.

    Args:
        cx, cy, w, h: box centers and sizes (1-D array-likes of equal length)
        s: per-box scores (1-D array-like)
        iou_thresh (float, optional): IoU threshold. Defaults to 0.3.
    Returns:
        res: indexes of the selected boxes, ordered by descending score
    """
    # Accept plain sequences as well as arrays; a no-op for ndarray inputs.
    cx, cy, w, h, s = (np.asarray(v) for v in (cx, cy, w, h, s))
    left, top, right, bottom = cxywh2ltrb(cx, cy, w, h)
    areas = w * h
    res = []
    order = np.argsort(s, axis=-1)[::-1]
    while order.shape[0] > 0:
        best, rest = order[0], order[1:]
        res.append(best)
        # iou() squeezes its result, so a single remaining candidate comes
        # back 0-d. np.where/nonzero on 0-d input is deprecated and rejected
        # by recent NumPy, so restore the 1-D shape before masking.
        overlaps = np.atleast_1d(
            iou(
                (left[best], top[best], right[best], bottom[best], areas[best]),
                (left[rest], top[rest], right[rest], bottom[rest], areas[rest]),
            )
        )
        order = rest[overlaps <= iou_thresh]
    return res
def filter_nonpos(boxes, agnostic_ratio=0.5, class_ratio=0.7):
    """Filter out insignificant boxes.

    A box is kept only when its logit is strictly greater than both
    ``class_ratio`` times the best logit among boxes with the same label and
    ``agnostic_ratio`` times the best logit among all boxes.

    Args:
        boxes (list of records): returned query to be filtered; each record
            is indexable with the label at position 5 and the logit at
            position 6
        agnostic_ratio (float, optional): fraction of the overall best logit
            a box must exceed. Defaults to 0.5.
        class_ratio (float, optional): fraction of its label's best logit a
            box must exceed. Defaults to 0.7.
    Returns:
        new list with the surviving records, in their original order; an
        empty list when ``boxes`` is empty
    """
    # max() of an empty sequence raises, so short-circuit on no input.
    if len(boxes) == 0:
        return []
    best_per_label = {}
    for box in boxes:
        label, logit = box[5], box[6]
        if label not in best_per_label or logit > best_per_label[label]:
            best_per_label[label] = logit
    best_overall = max(best_per_label.values())
    return [
        box
        for box in boxes
        if box[6] > class_ratio * best_per_label[box[5]]
        and box[6] > agnostic_ratio * best_overall
    ]
def _match_to_boxes(match, prompt_labels, is_selected):
    """Turn one column-oriented match record into a list of per-box rows."""
    return [
        list(row)
        for row in zip(
            match["box_id"],
            match["cx"],
            match["cy"],
            match["w"],
            match["h"],
            [prompt_labels[int(l)] for l in match["label"]],
            match["logit"],
            [is_selected] * len(match["box_id"]),
        )
    ]


def postprocess(
    matches,
    prompt_labels,
    img_matches=None,
    agnostic_ratio=0.4,
    class_ratio=0.7,
    iou_thresh=0.3,
):
    """Merge, filter and de-duplicate box matches per image.

    For every image in ``matches`` the box-level hits (flagged 1) are merged
    with the extra boxes from the corresponding ``img_matches`` entry
    (flagged 0, skipping box ids already present). The merged boxes are
    pruned with ``filter_nonpos`` and then with ``nms``.

    Args:
        matches (list of dict): box-level matches, one dict per image with
            keys ``img_id``, ``img_url``, ``img_w``, ``img_h``, ``box_id``,
            ``cx``, ``cy``, ``w``, ``h``, ``label`` and ``logit``;
            ``img_score`` is also read when ``img_matches`` is None
        prompt_labels (sequence): label values indexed by ``int(label)``
        img_matches (list of dict, optional): image-level matches with the
            same box columns plus ``img_score``. Defaults to None.
        agnostic_ratio (float, optional): forwarded to ``filter_nonpos``.
            Defaults to 0.4.
        class_ratio (float, optional): forwarded to ``filter_nonpos``.
            Defaults to 0.7.
        iou_thresh (float, optional): IoU threshold forwarded to ``nms``.
            Defaults to 0.3.
    Returns:
        boxes_w_img: list of tuples
            ``(img_id, img_url, img_w, img_h, img_score, boxes)``
        meta: ids of every merged box, collected before any filtering
    """
    meta = []
    boxes_w_img = []
    matches_by_img = {m["img_id"]: m for m in matches}
    img_matches_by_img = (
        {} if img_matches is None else {m["img_id"]: m for m in img_matches}
    )
    for img_id, m in matches_by_img.items():
        boxes = _match_to_boxes(m, prompt_labels, 1)
        img_m = img_matches_by_img.get(img_id)
        if img_m is not None:
            # and also those non-TopK hits and those non-topk are not
            # anticipating training
            seen_ids = {b[0] for b in boxes}
            boxes += [
                b
                for b in _match_to_boxes(img_m, prompt_labels, 0)
                if b[0] not in seen_ids
            ]
        # update record metadata after query
        meta.extend(b[0] for b in boxes)
        # remove some non-significant boxes
        if boxes:
            boxes = filter_nonpos(
                boxes, agnostic_ratio=agnostic_ratio, class_ratio=class_ratio
            )
        # doing non-maximum suppression; skipped when nothing survived,
        # because unpacking an empty zip into five columns would raise
        if boxes:
            cx, cy, w, h, s = (
                np.array(col) for col in zip(*[(*b[1:5], b[6]) for b in boxes])
            )
            boxes = [boxes[i] for i in nms(cx, cy, w, h, s, iou_thresh)]
        if img_matches is None:
            # No image-level query at all: fall back to the box-level score.
            img_score = m["img_score"]
        elif img_m is not None:
            img_score = img_m["img_score"]
        else:
            # Image-level query was given but has no entry for this image,
            # so there is no score to report for it.
            continue
        boxes_w_img.append(
            (m["img_id"], m["img_url"], m["img_w"], m["img_h"], img_score, boxes)
        )
    return boxes_w_img, meta