import copy
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv import ConfigDict
from mmcv.cnn import normal_init
from mmcv.ops import nms

from ..builder import HEADS
from .guided_anchor_head import GuidedAnchorHead
from .rpn_test_mixin import RPNTestMixin


@HEADS.register_module()
class GARPNHead(RPNTestMixin, GuidedAnchorHead):
    """Guided-Anchor-based RPN head."""

    def __init__(self, in_channels, **kwargs):
        # The RPN is class-agnostic, so a single foreground class is passed
        # to GuidedAnchorHead.
        super(GARPNHead, self).__init__(1, in_channels, **kwargs)

    def _init_layers(self):
        """Initialize layers of the head."""
        self.rpn_conv = nn.Conv2d(
            self.in_channels, self.feat_channels, 3, padding=1)
        super(GARPNHead, self)._init_layers()

    def init_weights(self):
        """Initialize weights of the head."""
        normal_init(self.rpn_conv, std=0.01)
        super(GARPNHead, self).init_weights()

    def forward_single(self, x):
        """Forward feature of a single scale level."""
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        # The parent head predicts classification scores, bbox deltas,
        # anchor shapes and anchor locations from the shared RPN feature.
        (cls_score, bbox_pred, shape_pred,
         loc_pred) = super(GARPNHead, self).forward_single(x)
        return cls_score, bbox_pred, shape_pred, loc_pred

    def loss(self,
             cls_scores,
             bbox_preds,
             shape_preds,
             loc_preds,
             gt_bboxes,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head and rename them with RPN keys."""
        losses = super(GARPNHead, self).loss(
            cls_scores,
            bbox_preds,
            shape_preds,
            loc_preds,
            gt_bboxes,
            None,  # gt_labels are not needed by the class-agnostic RPN
            img_metas,
            gt_bboxes_ignore=gt_bboxes_ignore)
        return dict(
            loss_rpn_cls=losses['loss_cls'],
            loss_rpn_bbox=losses['loss_bbox'],
            loss_anchor_shape=losses['loss_shape'],
            loss_anchor_loc=losses['loss_loc'])

    def _get_bboxes_single(self,
                           cls_scores,
                           bbox_preds,
                           mlvl_anchors,
                           mlvl_masks,
                           img_shape,
                           scale_factor,
                           cfg,
                           rescale=False):
        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)

        # deprecate arguments warning
        if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
            warnings.warn(
                'In rpn_proposal or test_cfg, '
                'nms_thr has been moved into a dict named nms as '
                'iou_threshold and max_num has been renamed max_per_img; '
                'the original argument names and the original way of '
                'specifying the NMS iou_threshold are deprecated.')
        if 'nms' not in cfg:
            cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
        if 'max_num' in cfg:
            if 'max_per_img' in cfg:
                assert cfg.max_num == cfg.max_per_img, \
                    'You set max_num and max_per_img at the same time, ' \
                    f'but get {cfg.max_num} and {cfg.max_per_img} ' \
                    'respectively. Please delete max_num, which will be ' \
                    'deprecated.'
            else:
                cfg.max_per_img = cfg.max_num
        if 'nms_thr' in cfg:
            assert cfg.nms.iou_threshold == cfg.nms_thr, \
                'You set iou_threshold in nms and nms_thr at the same ' \
                f'time, but get {cfg.nms.iou_threshold} and {cfg.nms_thr} ' \
                'respectively. Please delete nms_thr, which will be ' \
                'deprecated.'
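
        # For illustration (an assumed legacy cfg, not a reference config):
        #     dict(nms_pre=2000, nms_post=1000, max_num=1000, nms_thr=0.7,
        #          min_bbox_size=0)
        # is normalised by the branches above into
        #     nms=dict(type='nms', iou_threshold=0.7) and max_per_img=1000,
        # with the remaining keys left untouched.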

        assert cfg.nms.get('type', 'nms') == 'nms', \
            'GARPNHead only supports naive nms.'

        mlvl_proposals = []
        for idx in range(len(cls_scores)):
            rpn_cls_score = cls_scores[idx]
            rpn_bbox_pred = bbox_preds[idx]
            anchors = mlvl_anchors[idx]
            mask = mlvl_masks[idx]
            assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
            # If no location is kept at this level, skip it.
            if mask.sum() == 0:
                continue
            rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
            if self.use_sigmoid_cls:
                rpn_cls_score = rpn_cls_score.reshape(-1)
                scores = rpn_cls_score.sigmoid()
            else:
                rpn_cls_score = rpn_cls_score.reshape(-1, 2)
                # Since mmdet v2.0, FG labels are [0, num_classes - 1] and
                # the BG label is num_classes, so drop the trailing
                # background column.
                scores = rpn_cls_score.softmax(dim=1)[:, :-1]
            # Keep scores and bbox_pred only at locations selected by the
            # mask; the anchors were already filtered in get_anchors().
            scores = scores[mask]
            rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(
                -1, 4)[mask, :]
            if scores.dim() == 0:
                rpn_bbox_pred = rpn_bbox_pred.unsqueeze(0)
                anchors = anchors.unsqueeze(0)
                scores = scores.unsqueeze(0)
            # Keep only the top-scoring predictions before NMS.
            if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
                _, topk_inds = scores.topk(cfg.nms_pre)
                rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
                anchors = anchors[topk_inds, :]
                scores = scores[topk_inds]
            # Decode the predicted deltas into proposals w.r.t. the anchors.
            proposals = self.bbox_coder.decode(
                anchors, rpn_bbox_pred, max_shape=img_shape)
            # Filter out boxes that are too small.
            if cfg.min_bbox_size > 0:
                w = proposals[:, 2] - proposals[:, 0]
                h = proposals[:, 3] - proposals[:, 1]
                valid_inds = torch.nonzero(
                    (w >= cfg.min_bbox_size) & (h >= cfg.min_bbox_size),
                    as_tuple=False).squeeze()
                proposals = proposals[valid_inds, :]
                scores = scores[valid_inds]
            # NMS within the current level.
            proposals, _ = nms(proposals, scores, cfg.nms.iou_threshold)
            proposals = proposals[:cfg.nms_post, :]
            mlvl_proposals.append(proposals)
        proposals = torch.cat(mlvl_proposals, 0)
        if cfg.get('nms_across_levels', False):
            # NMS across all levels; nms() returns detections with the score
            # appended as the last column.
            proposals, _ = nms(proposals[:, :4], proposals[:, -1],
                               cfg.nms.iou_threshold)
            proposals = proposals[:cfg.max_per_img, :]
        else:
            scores = proposals[:, 4]
            num = min(cfg.max_per_img, proposals.shape[0])
            _, topk_inds = scores.topk(num)
            proposals = proposals[topk_inds, :]
        return proposals
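

# Illustrative sketch of a ``test_cfg`` that carries every key consumed by
# ``GARPNHead._get_bboxes_single`` above. The key names are taken from the
# code; the values are assumptions for demonstration, not a reference GA-RPN
# configuration:
#
#     test_cfg = ConfigDict(
#         nms_across_levels=False,
#         nms_pre=2000,
#         nms_post=1000,
#         max_per_img=1000,
#         nms=dict(type='nms', iou_threshold=0.7),
#         min_bbox_size=0)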