|
from abc import abstractmethod |
|
|
|
import torch |
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init |
|
from mmcv.runner import force_fp32 |
|
|
|
from mmdet.core import multi_apply |
|
from ..builder import HEADS, build_loss |
|
from .base_dense_head import BaseDenseHead |
|
from .dense_test_mixins import BBoxTestMixin |
|
|
|
|
|
@HEADS.register_module()
class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
    """Anchor-free head (FCOS, Fovea, RepPoints, etc.).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        feat_channels (int): Number of hidden channels used by the conv
            towers and the predictor convs.
        stacked_convs (int): Number of stacking convs of the head.
        strides (tuple): Downsample factor of each feature map.
        dcn_on_last_conv (bool): If true, use dcn in the last layer of
            towers. Default: False.
        conv_bias (bool | str): If specified as `auto`, it will be decided by
            the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
            None, otherwise False. Default: "auto".
        loss_cls (dict): Config of classification loss.
        loss_bbox (dict): Config of localization loss.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        train_cfg (dict): Training config of the head.
        test_cfg (dict): Testing config of the head.
    """

    # Stored by nn.Module into state-dict metadata as 'version';
    # checkpoints lacking it are treated as legacy in _load_from_state_dict.
    _version = 1

    def __init__(self,
                 num_classes,
                 in_channels,
                 feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 dcn_on_last_conv=False,
                 conv_bias='auto',
                 loss_cls=dict(
                     type='FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox=dict(type='IoULoss', loss_weight=1.0),
                 conv_cfg=None,
                 norm_cfg=None,
                 train_cfg=None,
                 test_cfg=None):
        super(AnchorFreeHead, self).__init__()
        self.num_classes = num_classes
        # One output channel per foreground class (no background channel).
        self.cls_out_channels = num_classes
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.dcn_on_last_conv = dcn_on_last_conv
        assert conv_bias == 'auto' or isinstance(conv_bias, bool)
        self.conv_bias = conv_bias

        self.loss_cls = build_loss(loss_cls)
        self.loss_bbox = build_loss(loss_bbox)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        # NOTE(review): presumably consulted by mmcv's force_fp32 wrapper on
        # loss/get_bboxes below; confirm against the mmcv fp16 utilities.
        self.fp16_enabled = False

        self._init_layers()

    def _init_layers(self):
        """Initialize layers of the head."""
        self._init_cls_convs()
        self._init_reg_convs()
        self._init_predictor()

    def _build_tower_convs(self):
        """Build one fresh tower of ``stacked_convs`` 3x3 ConvModules.

        The first conv maps ``in_channels`` to ``feat_channels``; the rest
        keep ``feat_channels``. When ``dcn_on_last_conv`` is set, the last
        conv of the tower uses a ``DCNv2`` conv config instead of
        ``self.conv_cfg``.

        Returns:
            nn.ModuleList: The newly created conv layers.
        """
        convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            if self.dcn_on_last_conv and i == self.stacked_convs - 1:
                conv_cfg = dict(type='DCNv2')
            else:
                conv_cfg = self.conv_cfg
            convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=self.conv_bias))
        return convs

    def _init_cls_convs(self):
        """Initialize classification conv layers of the head."""
        self.cls_convs = self._build_tower_convs()

    def _init_reg_convs(self):
        """Initialize bbox regression conv layers of the head."""
        self.reg_convs = self._build_tower_convs()

    def _init_predictor(self):
        """Initialize predictor layers of the head."""
        self.conv_cls = nn.Conv2d(
            self.feat_channels, self.cls_out_channels, 3, padding=1)
        # 4 regression outputs per location.
        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)

    def init_weights(self):
        """Initialize weights of the head.

        Plain ``nn.Conv2d`` convs in both towers get normal(std=0.01) init;
        non-``nn.Conv2d`` convs (e.g. a DCN last layer) are left as built.
        The classification predictor's bias is derived from a 0.01 prior
        probability via mmcv's ``bias_init_with_prob``.
        """
        for m in [*self.cls_convs, *self.reg_convs]:
            if isinstance(m.conv, nn.Conv2d):
                normal_init(m.conv, std=0.01)
        bias_cls = bias_init_with_prob(0.01)
        normal_init(self.conv_cls, std=0.01, bias=bias_cls)
        normal_init(self.conv_reg, std=0.01)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Hack some keys of the model state dict so that can load checkpoints
        of previous version.

        When ``local_metadata`` carries no ``version`` entry, every key under
        ``prefix`` is inspected: if the first sub-module name after the prefix
        ends with ``cls``, ``reg`` or ``centerness`` (e.g. ``fcos_cls``), it
        is renamed to ``conv_cls``, ``conv_reg`` or ``conv_centerness``
        respectively. All other keys are left untouched. ``state_dict`` is
        modified in place, then loading is delegated to the parent class.
        """
        version = local_metadata.get('version', None)
        if version is None:
            # Collect renames first so state_dict is not mutated while its
            # keys are being iterated.
            renames = []
            for ori_key in state_dict.keys():
                if not ori_key.startswith(prefix):
                    continue
                # Inspect the component right after ``prefix`` (not a fixed
                # index of the full key) so heads nested under multi-level
                # prefixes, or loaded with an empty prefix, are handled.
                parts = ori_key[len(prefix):].split('.')
                module_name = parts[0]
                if module_name.endswith('cls'):
                    conv_name = 'conv_cls'
                elif module_name.endswith('reg'):
                    conv_name = 'conv_reg'
                elif module_name.endswith('centerness'):
                    conv_name = 'conv_centerness'
                else:
                    # Tower convs, scales, etc. keep their names.
                    continue
                parts[0] = conv_name
                renames.append((ori_key, prefix + '.'.join(parts)))
            for ori_key, new_key in renames:
                state_dict[new_key] = state_dict.pop(ori_key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

    def forward(self, feats):
        """Forward features from the upstream network.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually contain classification scores and bbox predictions.
                cls_scores (list[Tensor]): Box scores for each scale level,
                    each is a 4D-tensor, the channel number is
                    num_points * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for each scale
                    level, each is a 4D-tensor, the channel number is
                    num_points * 4.
        """
        # forward_single also returns the tower features; keep only the
        # first two outputs (scores and box predictions).
        return multi_apply(self.forward_single, feats)[:2]

    def forward_single(self, x):
        """Forward features of a single scale level.

        Args:
            x (Tensor): FPN feature maps of the specified stride.

        Returns:
            tuple: Scores for each class, bbox predictions, features
                after classification and regression conv layers, some
                models needs these features like FCOS.
        """
        cls_feat = x
        reg_feat = x

        for cls_layer in self.cls_convs:
            cls_feat = cls_layer(cls_feat)
        cls_score = self.conv_cls(cls_feat)

        for reg_layer in self.reg_convs:
            reg_feat = reg_layer(reg_feat)
        bbox_pred = self.conv_reg(reg_feat)
        return cls_score, bbox_pred, cls_feat, reg_feat

    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute loss of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each is a 4D-tensor, the channel number is
                num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each is a 4D-tensor, the channel number is
                num_points * 4.
            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): class indices corresponding to each box
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

        Raises:
            NotImplementedError: Always; subclasses must implement this.
        """
        raise NotImplementedError

    @abstractmethod
    @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
    def get_bboxes(self,
                   cls_scores,
                   bbox_preds,
                   img_metas,
                   cfg=None,
                   rescale=None):
        """Transform network output for a batch into bbox predictions.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level
                Has shape (N, num_points * num_classes, H, W)
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level with shape (N, num_points * 4, H, W)
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            cfg (mmcv.Config): Test / postprocessing configuration,
                if None, test_cfg would be used
            rescale (bool): If True, return boxes in original image space

        Raises:
            NotImplementedError: Always; subclasses must implement this.
        """
        raise NotImplementedError

    @abstractmethod
    def get_targets(self, points, gt_bboxes_list, gt_labels_list):
        """Compute regression, classification and centerness targets for points
        in multiple images.

        Args:
            points (list[Tensor]): Points of each fpn level, each has shape
                (num_points, 2).
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
                each has shape (num_gt, 4).
            gt_labels_list (list[Tensor]): Ground truth labels of each box,
                each has shape (num_gt,).

        Raises:
            NotImplementedError: Always; subclasses must implement this.
        """
        raise NotImplementedError

    def _get_points_single(self,
                           featmap_size,
                           stride,
                           dtype,
                           device,
                           flatten=False):
        """Get points of a single scale level.

        Args:
            featmap_size (tuple): Feature map size as (h, w).
            stride (int): Stride of this level. Unused in this base
                implementation; kept so subclasses can rely on it.
            dtype (torch.dtype): Type of the returned coordinates.
            device (torch.device): Device of the returned coordinates.
            flatten (bool): Whether to flatten both grids to 1D.

        Returns:
            tuple[Tensor, Tensor]: ``(y, x)`` integer-valued row / column
                index grids of shape (h, w), or (h * w,) when flattened.
        """
        h, w = featmap_size
        x_range = torch.arange(w, dtype=dtype, device=device)
        y_range = torch.arange(h, dtype=dtype, device=device)
        # Default (matrix, 'ij') indexing: both grids have shape (h, w).
        y, x = torch.meshgrid(y_range, x_range)
        if flatten:
            y = y.flatten()
            x = x.flatten()
        return y, x

    def get_points(self, featmap_sizes, dtype, device, flatten=False):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes, each
                given as (h, w).
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.
            flatten (bool): Whether to flatten the points of each level.
                Default: False.

        Returns:
            list: One entry per feature level, as produced by
                ``_get_points_single`` with the matching entry of
                ``self.strides``.
        """
        return [
            self._get_points_single(featmap_size, self.strides[i], dtype,
                                    device, flatten)
            for i, featmap_size in enumerate(featmap_sizes)
        ]

    def aug_test(self, feats, img_metas, rescale=False):
        """Test function with test time augmentation.

        Args:
            feats (list[Tensor]): the outer list indicates test-time
                augmentations and inner Tensor should have a shape NxCxHxW,
                which contains features for all images in the batch.
            img_metas (list[list[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch. each dict has image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list[ndarray]: bbox results of each class
        """
        return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
|
|