|
from inspect import signature |
|
|
|
import torch |
|
|
|
from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms |
|
|
|
|
|
class BBoxTestMixin(object):
    """Mixin class for test time augmentation of bboxes.

    The host class is expected to provide ``forward``, ``get_bboxes``,
    either ``_get_bboxes`` or ``_get_bboxes_single``, ``test_cfg`` and
    ``num_classes``; all of them are read by :meth:`aug_test_bboxes`.
    """

    def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
        """Merge augmented detection bboxes and scores.

        Every per-augmentation bbox tensor is passed through
        ``bbox_mapping_back`` using the meta information of the first image
        of that augmentation, then all tensors are concatenated along dim 0.

        Args:
            aug_bboxes (list[Tensor]): shape (n, 4*#class)
            aug_scores (list[Tensor] or None): shape (n, #class)
            img_metas (list[list[dict]]): one inner list per augmentation.
                Only the first dict of each inner list is read; it must
                provide the keys ``img_shape``, ``scale_factor``, ``flip``
                and ``flip_direction``.

        Returns:
            Tensor or tuple: the merged bboxes alone when ``aug_scores`` is
            None, otherwise the tuple ``(bboxes, scores)``.
        """
        recovered_bboxes = []
        for bboxes, img_info in zip(aug_bboxes, img_metas):
            # Only the first image of the batch is consulted.
            meta = img_info[0]
            recovered_bboxes.append(
                bbox_mapping_back(bboxes, meta['img_shape'],
                                  meta['scale_factor'], meta['flip'],
                                  meta['flip_direction']))
        bboxes = torch.cat(recovered_bboxes, dim=0)
        if aug_scores is None:
            return bboxes
        return bboxes, torch.cat(aug_scores, dim=0)

    def aug_test_bboxes(self, feats, img_metas, rescale=False):
        """Test det bboxes with test time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): If True, the boxes merged by
                :meth:`merge_aug_bboxes` are returned unchanged. If False,
                their first four columns are multiplied by the
                ``scale_factor`` of the first image of the first
                augmentation. Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class

        Raises:
            AssertionError: if ``get_bboxes`` or the per-image helper
                (``_get_bboxes`` when present, else ``_get_bboxes_single``)
                has no ``with_nms`` parameter.
        """
        # Both the public entry point and the per-image helper must accept
        # ``with_nms``.
        get_single = (
            self._get_bboxes
            if hasattr(self, '_get_bboxes') else self._get_bboxes_single)
        supported = (
            'with_nms' in signature(self.get_bboxes).parameters
            and 'with_nms' in signature(get_single).parameters)
        if not supported:
            # Raised explicitly (not via ``assert``) so the check is not
            # stripped under ``python -O``; the exception type and message
            # are the ones callers already see.
            raise AssertionError(
                f'{self.__class__.__name__}'
                ' does not support test-time augmentation')

        aug_bboxes = []
        aug_scores = []
        aug_factors = []
        for x, img_meta in zip(feats, img_metas):
            outs = self.forward(x)
            # NOTE(review): the two trailing ``False`` values are presumably
            # ``rescale`` and ``with_nms`` of ``get_bboxes`` - confirm
            # against the head's signature.
            bbox_inputs = outs + (img_meta, self.test_cfg, False, False)
            # Only the first result returned by ``get_bboxes`` is used.
            bbox_outputs = self.get_bboxes(*bbox_inputs)[0]
            aug_bboxes.append(bbox_outputs[0])
            aug_scores.append(bbox_outputs[1])
            # An optional third element is collected and later handed to
            # NMS as ``score_factors``.
            if len(bbox_outputs) >= 3:
                aug_factors.append(bbox_outputs[2])

        merged_bboxes, merged_scores = self.merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas)
        merged_factors = torch.cat(aug_factors, dim=0) if aug_factors else None
        det_bboxes, det_labels = multiclass_nms(
            merged_bboxes,
            merged_scores,
            self.test_cfg.score_thr,
            self.test_cfg.nms,
            self.test_cfg.max_per_img,
            score_factors=merged_factors)

        if rescale:
            _det_bboxes = det_bboxes
        else:
            # Work on a copy so ``det_bboxes`` itself is left untouched.
            _det_bboxes = det_bboxes.clone()
            _det_bboxes[:, :4] *= det_bboxes.new_tensor(
                img_metas[0][0]['scale_factor'])
        bbox_results = bbox2result(_det_bboxes, det_labels, self.num_classes)
        return bbox_results
|
|