import torch
import torch.nn.functional as F

from mmdet.core import bbox_overlaps
from ..builder import HEADS
from .retina_head import RetinaHead

EPS = 1e-12


@HEADS.register_module()
class FreeAnchorRetinaHead(RetinaHead):
"""FreeAnchor RetinaHead used in https://arxiv.org/abs/1909.02466. |
|
|
|
Args: |
|
num_classes (int): Number of categories excluding the background |
|
category. |
|
in_channels (int): Number of channels in the input feature map. |
|
stacked_convs (int): Number of conv layers in cls and reg tower. |
|
Default: 4. |
|
conv_cfg (dict): dictionary to construct and config conv layer. |
|
Default: None. |
|
norm_cfg (dict): dictionary to construct and config norm layer. |
|
Default: norm_cfg=dict(type='GN', num_groups=32, |
|
requires_grad=True). |
|
pre_anchor_topk (int): Number of boxes that be token in each bag. |
|
bbox_thr (float): The threshold of the saturated linear function. It is |
|
usually the same with the IoU threshold used in NMS. |
|
gamma (float): Gamma parameter in focal loss. |
|
alpha (float): Alpha parameter in focal loss. |
|
""" |

    def __init__(self,
                 num_classes,
                 in_channels,
                 stacked_convs=4,
                 conv_cfg=None,
                 norm_cfg=None,
                 pre_anchor_topk=50,
                 bbox_thr=0.6,
                 gamma=2.0,
                 alpha=0.5,
                 **kwargs):
        super(FreeAnchorRetinaHead,
              self).__init__(num_classes, in_channels, stacked_convs, conv_cfg,
                             norm_cfg, **kwargs)

        self.pre_anchor_topk = pre_anchor_topk
        self.bbox_thr = bbox_thr
        self.gamma = gamma
        self.alpha = alpha

    def loss(self,
             cls_scores,
             bbox_preds,
             gt_bboxes,
             gt_labels,
             img_metas,
             gt_bboxes_ignore=None):
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each with shape (N, num_anchors * 4, H, W).
            gt_bboxes (list[Tensor]): Ground-truth boxes of each image, in
                [tl_x, tl_y, br_x, br_y] format.
            gt_labels (list[Tensor]): Class indices corresponding to each box.
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes_ignore (None | list[Tensor]): Bounding boxes that can be
                ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == len(self.anchor_generator.base_anchors)

        anchor_list, _ = self.get_anchors(featmap_sizes, img_metas)
        anchors = [torch.cat(anchor) for anchor in anchor_list]
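
        # Flatten each level's predictions and concatenate them along the
        # anchor dimension so cls_scores, bbox_preds and anchors line up as
        # (N, num_total_anchors, ...) tensors.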
        cls_scores = [
            cls.permute(0, 2, 3,
                        1).reshape(cls.size(0), -1, self.cls_out_channels)
            for cls in cls_scores
        ]
        bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(bbox_pred.size(0), -1, 4)
            for bbox_pred in bbox_preds
        ]
        cls_scores = torch.cat(cls_scores, dim=1)
        bbox_preds = torch.cat(bbox_preds, dim=1)

        cls_prob = torch.sigmoid(cls_scores)
        box_prob = []
        num_pos = 0
        positive_losses = []
        for anchors_, gt_labels_, gt_bboxes_, cls_prob_, bbox_preds_ in zip(
                anchors, gt_labels, gt_bboxes, cls_prob, bbox_preds):
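
            # Estimate P{a_j -> b_i}: the probability that anchor a_j is
            # responsible for object b_i, derived from the IoU between the
            # decoded predictions and the ground-truth boxes. This matching
            # step does not need gradients.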
            with torch.no_grad():
                if len(gt_bboxes_) == 0:
                    image_box_prob = torch.zeros(
                        anchors_.size(0),
                        self.cls_out_channels).type_as(bbox_preds_)
                else:
                    # box_localization: a_{j}^{loc}, shape: [j, 4]
                    pred_boxes = self.bbox_coder.decode(anchors_, bbox_preds_)

                    # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
                    object_box_iou = bbox_overlaps(gt_bboxes_, pred_boxes)

                    # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j],
                    # a saturated linear function of the IoU between t1 and t2
                    t1 = self.bbox_thr
                    t2 = object_box_iou.max(
                        dim=1, keepdim=True).values.clamp(min=t1 + 1e-12)
                    object_box_prob = ((object_box_iou - t1) /
                                       (t2 - t1)).clamp(
                                           min=0, max=1)

                    # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
                    num_obj = gt_labels_.size(0)
                    indices = torch.stack([
                        torch.arange(num_obj).type_as(gt_labels_), gt_labels_
                    ],
                                          dim=0)
                    object_cls_box_prob = torch.sparse_coo_tensor(
                        indices, object_box_prob)

                    # The block below emulates
                    #   image_box_prob = torch.sparse.max(
                    #       object_cls_box_prob, dim=0).t()
                    # which PyTorch does not provide: for each (anchor, class)
                    # pair, take the maximum probability over all objects.

                    box_cls_prob = torch.sparse.sum(
                        object_cls_box_prob, dim=0).to_dense()

                    indices = torch.nonzero(box_cls_prob, as_tuple=False).t_()
                    if indices.numel() == 0:
                        image_box_prob = torch.zeros(
                            anchors_.size(0),
                            self.cls_out_channels).type_as(object_box_prob)
                    else:
                        nonzero_box_prob = torch.where(
                            (gt_labels_.unsqueeze(dim=-1) == indices[0]),
                            object_box_prob[:, indices[1]],
                            torch.tensor([
                                0
                            ]).type_as(object_box_prob)).max(dim=0).values

                        # image_box_prob: P{a_{j} \in A_{+}}, shape: [j, c]
                        image_box_prob = torch.sparse_coo_tensor(
                            indices.flip([0]),
                            nonzero_box_prob,
                            size=(anchors_.size(0),
                                  self.cls_out_channels)).to_dense()

                box_prob.append(image_box_prob)
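
            # Construct a bag of candidate anchors for each object: the
            # pre_anchor_topk anchors with the highest IoU to that object.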
            match_quality_matrix = bbox_overlaps(gt_bboxes_, anchors_)
            _, matched = torch.topk(
                match_quality_matrix,
                self.pre_anchor_topk,
                dim=1,
                sorted=False)
            del match_quality_matrix
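
            # matched_cls_prob: P_{ij}^{cls}, the classification probability
            # of each bagged anchor, gathered at its object's gt label.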
            matched_cls_prob = torch.gather(
                cls_prob_[matched], 2,
                gt_labels_.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
                                                 1)).squeeze(2)
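
            # matched_box_prob: P_{ij}^{loc}, the localization probability of
            # each bagged anchor, defined as exp(-loss_bbox) of its regression
            # against the object it was matched to.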
            matched_anchors = anchors_[matched]
            matched_object_targets = self.bbox_coder.encode(
                matched_anchors,
                gt_bboxes_.unsqueeze(dim=1).expand_as(matched_anchors))
            loss_bbox = self.loss_bbox(
                bbox_preds_[matched],
                matched_object_targets,
                reduction_override='none').sum(-1)
            matched_box_prob = torch.exp(-loss_bbox)

            num_pos += len(gt_bboxes_)
            positive_losses.append(
                self.positive_bag_loss(matched_cls_prob, matched_box_prob))

        positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos)

        # box_prob: P{a_{j} \in A_{+}}, shape: [n, j, c]
        box_prob = torch.stack(box_prob, dim=0)

        # negative_loss:
        # \sum_j FL((1 - P{a_j \in A_+}) * (1 - P_j^{bg})) / n||B||
        negative_loss = self.negative_bag_loss(cls_prob, box_prob).sum() / max(
            1, num_pos * self.pre_anchor_topk)

        # Avoid the absence of gradients in the regression branch when there
        # are no ground-truth boxes in the batch.
        if num_pos == 0:
            positive_loss = bbox_preds.sum() * 0

        losses = {
            'positive_bag_loss': positive_loss,
            'negative_bag_loss': negative_loss
        }
        return losses

    def positive_bag_loss(self, matched_cls_prob, matched_box_prob):
        """Compute positive bag loss.

        :math:`-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )`.

        :math:`P_{ij}^{cls}`: matched_cls_prob, classification probability of
        matched samples.

        :math:`P_{ij}^{loc}`: matched_box_prob, box probability of matched
        samples.

        Args:
            matched_cls_prob (Tensor): Classification probability of matched
                samples in shape (num_gt, pre_anchor_topk).
            matched_box_prob (Tensor): BBox probability of matched samples,
                in shape (num_gt, pre_anchor_topk).

        Returns:
            Tensor: Positive bag loss in shape (num_gt,).
        """
        # bag_prob = Mean-max(matched_prob)
        matched_prob = matched_cls_prob * matched_box_prob
        weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
        weight /= weight.sum(dim=1).unsqueeze(dim=-1)
        bag_prob = (weight * matched_prob).sum(dim=1)
        # positive_bag_loss = -self.alpha * log(bag_prob)
        return self.alpha * F.binary_cross_entropy(
            bag_prob, torch.ones_like(bag_prob), reduction='none')

    def negative_bag_loss(self, cls_prob, box_prob):
        r"""Compute negative bag loss.

        :math:`FL((1 - P_{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}))`.

        :math:`P_{a_{j} \in A_{+}}`: Box probability of matched samples.

        :math:`P_{j}^{bg}`: Classification probability of negative samples.

        Args:
            cls_prob (Tensor): Classification probability, in shape
                (num_img, num_anchors, num_classes).
            box_prob (Tensor): Box probability, in shape
                (num_img, num_anchors, num_classes).

        Returns:
            Tensor: Negative bag loss in shape
                (num_img, num_anchors, num_classes).
        """
        prob = cls_prob * (1 - box_prob)
        # Clamp to avoid log(0) in the BCE below when prob is exactly 0 or 1.
        prob = prob.clamp(min=EPS, max=1 - EPS)
        negative_bag_loss = prob**self.gamma * F.binary_cross_entropy(
            prob, torch.zeros_like(prob), reduction='none')
        return (1 - self.alpha) * negative_bag_loss