|
|
|
|
|
import logging
|
|
from typing import List, Optional, Tuple
|
|
import torch
|
|
from fvcore.nn import sigmoid_focal_loss_jit
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
from detectron2.layers import ShapeSpec, batched_nms
|
|
from detectron2.structures import Boxes, ImageList, Instances, pairwise_point_box_distance
|
|
from detectron2.utils.events import get_event_storage
|
|
|
|
from ..anchor_generator import DefaultAnchorGenerator
|
|
from ..backbone import Backbone
|
|
from ..box_regression import Box2BoxTransformLinear, _dense_box_regression_loss
|
|
from .dense_detector import DenseDetector
|
|
from .retinanet import RetinaNetHead
|
|
|
|
__all__ = ["FCOS"]
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FCOS(DenseDetector):
|
|
"""
|
|
Implement FCOS in :paper:`fcos`.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
backbone: Backbone,
|
|
head: nn.Module,
|
|
head_in_features: Optional[List[str]] = None,
|
|
box2box_transform=None,
|
|
num_classes,
|
|
center_sampling_radius: float = 1.5,
|
|
focal_loss_alpha=0.25,
|
|
focal_loss_gamma=2.0,
|
|
test_score_thresh=0.2,
|
|
test_topk_candidates=1000,
|
|
test_nms_thresh=0.6,
|
|
max_detections_per_image=100,
|
|
pixel_mean,
|
|
pixel_std,
|
|
):
|
|
"""
|
|
Args:
|
|
center_sampling_radius: radius of the "center" of a groundtruth box,
|
|
within which all anchor points are labeled positive.
|
|
Other arguments mean the same as in :class:`RetinaNet`.
|
|
"""
|
|
super().__init__(
|
|
backbone, head, head_in_features, pixel_mean=pixel_mean, pixel_std=pixel_std
|
|
)
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
|
|
|
feature_shapes = backbone.output_shape()
|
|
fpn_strides = [feature_shapes[k].stride for k in self.head_in_features]
|
|
self.anchor_generator = DefaultAnchorGenerator(
|
|
sizes=[[k] for k in fpn_strides], aspect_ratios=[1.0], strides=fpn_strides
|
|
)
|
|
|
|
|
|
|
|
if box2box_transform is None:
|
|
box2box_transform = Box2BoxTransformLinear(normalize_by_size=True)
|
|
self.box2box_transform = box2box_transform
|
|
|
|
self.center_sampling_radius = float(center_sampling_radius)
|
|
|
|
|
|
self.focal_loss_alpha = focal_loss_alpha
|
|
self.focal_loss_gamma = focal_loss_gamma
|
|
|
|
|
|
self.test_score_thresh = test_score_thresh
|
|
self.test_topk_candidates = test_topk_candidates
|
|
self.test_nms_thresh = test_nms_thresh
|
|
self.max_detections_per_image = max_detections_per_image
|
|
|
|
def forward_training(self, images, features, predictions, gt_instances):
|
|
|
|
pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions(
|
|
predictions, [self.num_classes, 4, 1]
|
|
)
|
|
anchors = self.anchor_generator(features)
|
|
gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances)
|
|
return self.losses(
|
|
anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness
|
|
)
|
|
|
|
@torch.no_grad()
|
|
def _match_anchors(self, gt_boxes: Boxes, anchors: List[Boxes]):
|
|
"""
|
|
Match ground-truth boxes to a set of multi-level anchors.
|
|
|
|
Args:
|
|
gt_boxes: Ground-truth boxes from instances of an image.
|
|
anchors: List of anchors for each feature map (of different scales).
|
|
|
|
Returns:
|
|
torch.Tensor
|
|
A tensor of shape `(M, R)`, given `M` ground-truth boxes and total
|
|
`R` anchor points from all feature levels, indicating the quality
|
|
of match between m-th box and r-th anchor. Higher value indicates
|
|
better match.
|
|
"""
|
|
|
|
|
|
num_anchors_per_level = [len(x) for x in anchors]
|
|
anchors = Boxes.cat(anchors)
|
|
anchor_centers = anchors.get_centers()
|
|
anchor_sizes = anchors.tensor[:, 2] - anchors.tensor[:, 0]
|
|
|
|
lower_bound = anchor_sizes * 4
|
|
lower_bound[: num_anchors_per_level[0]] = 0
|
|
upper_bound = anchor_sizes * 8
|
|
upper_bound[-num_anchors_per_level[-1] :] = float("inf")
|
|
|
|
gt_centers = gt_boxes.get_centers()
|
|
|
|
|
|
|
|
center_dists = (anchor_centers[None, :, :] - gt_centers[:, None, :]).abs_()
|
|
sampling_regions = self.center_sampling_radius * anchor_sizes[None, :]
|
|
|
|
match_quality_matrix = center_dists.max(dim=2).values < sampling_regions
|
|
|
|
pairwise_dist = pairwise_point_box_distance(anchor_centers, gt_boxes)
|
|
pairwise_dist = pairwise_dist.permute(1, 0, 2)
|
|
|
|
|
|
match_quality_matrix &= pairwise_dist.min(dim=2).values > 0
|
|
|
|
|
|
|
|
pairwise_dist = pairwise_dist.max(dim=2).values
|
|
match_quality_matrix &= (pairwise_dist > lower_bound[None, :]) & (
|
|
pairwise_dist < upper_bound[None, :]
|
|
)
|
|
|
|
gt_areas = gt_boxes.area()
|
|
|
|
match_quality_matrix = match_quality_matrix.to(torch.float32)
|
|
match_quality_matrix *= 1e8 - gt_areas[:, None]
|
|
return match_quality_matrix
|
|
|
|
@torch.no_grad()
|
|
def label_anchors(self, anchors: List[Boxes], gt_instances: List[Instances]):
|
|
"""
|
|
Same interface as :meth:`RetinaNet.label_anchors`, but implemented with FCOS
|
|
anchor matching rule.
|
|
|
|
Unlike RetinaNet, there are no ignored anchors.
|
|
"""
|
|
|
|
gt_labels, matched_gt_boxes = [], []
|
|
|
|
for inst in gt_instances:
|
|
if len(inst) > 0:
|
|
match_quality_matrix = self._match_anchors(inst.gt_boxes, anchors)
|
|
|
|
|
|
|
|
|
|
match_quality, matched_idxs = match_quality_matrix.max(dim=0)
|
|
matched_idxs[match_quality < 1e-5] = -1
|
|
|
|
matched_gt_boxes_i = inst.gt_boxes.tensor[matched_idxs.clip(min=0)]
|
|
gt_labels_i = inst.gt_classes[matched_idxs.clip(min=0)]
|
|
|
|
|
|
gt_labels_i[matched_idxs < 0] = self.num_classes
|
|
else:
|
|
matched_gt_boxes_i = torch.zeros_like(Boxes.cat(anchors).tensor)
|
|
gt_labels_i = torch.full(
|
|
(len(matched_gt_boxes_i),),
|
|
fill_value=self.num_classes,
|
|
dtype=torch.long,
|
|
device=matched_gt_boxes_i.device,
|
|
)
|
|
|
|
gt_labels.append(gt_labels_i)
|
|
matched_gt_boxes.append(matched_gt_boxes_i)
|
|
|
|
return gt_labels, matched_gt_boxes
|
|
|
|
def losses(
|
|
self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness
|
|
):
|
|
"""
|
|
This method is almost identical to :meth:`RetinaNet.losses`, with an extra
|
|
"loss_centerness" in the returned dict.
|
|
"""
|
|
num_images = len(gt_labels)
|
|
gt_labels = torch.stack(gt_labels)
|
|
|
|
pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
|
|
num_pos_anchors = pos_mask.sum().item()
|
|
get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images)
|
|
normalizer = self._ema_update("loss_normalizer", max(num_pos_anchors, 1), 300)
|
|
|
|
|
|
gt_labels_target = F.one_hot(gt_labels, num_classes=self.num_classes + 1)[
|
|
:, :, :-1
|
|
]
|
|
loss_cls = sigmoid_focal_loss_jit(
|
|
torch.cat(pred_logits, dim=1),
|
|
gt_labels_target.to(pred_logits[0].dtype),
|
|
alpha=self.focal_loss_alpha,
|
|
gamma=self.focal_loss_gamma,
|
|
reduction="sum",
|
|
)
|
|
|
|
loss_box_reg = _dense_box_regression_loss(
|
|
anchors,
|
|
self.box2box_transform,
|
|
pred_anchor_deltas,
|
|
gt_boxes,
|
|
pos_mask,
|
|
box_reg_loss_type="giou",
|
|
)
|
|
|
|
ctrness_targets = self.compute_ctrness_targets(anchors, gt_boxes)
|
|
pred_centerness = torch.cat(pred_centerness, dim=1).squeeze(dim=2)
|
|
ctrness_loss = F.binary_cross_entropy_with_logits(
|
|
pred_centerness[pos_mask], ctrness_targets[pos_mask], reduction="sum"
|
|
)
|
|
return {
|
|
"loss_fcos_cls": loss_cls / normalizer,
|
|
"loss_fcos_loc": loss_box_reg / normalizer,
|
|
"loss_fcos_ctr": ctrness_loss / normalizer,
|
|
}
|
|
|
|
def compute_ctrness_targets(self, anchors: List[Boxes], gt_boxes: List[torch.Tensor]):
|
|
anchors = Boxes.cat(anchors).tensor
|
|
reg_targets = [self.box2box_transform.get_deltas(anchors, m) for m in gt_boxes]
|
|
reg_targets = torch.stack(reg_targets, dim=0)
|
|
if len(reg_targets) == 0:
|
|
return reg_targets.new_zeros(len(reg_targets))
|
|
left_right = reg_targets[:, :, [0, 2]]
|
|
top_bottom = reg_targets[:, :, [1, 3]]
|
|
ctrness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
|
|
top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]
|
|
)
|
|
return torch.sqrt(ctrness)
|
|
|
|
def forward_inference(
|
|
self,
|
|
images: ImageList,
|
|
features: List[torch.Tensor],
|
|
predictions: List[List[torch.Tensor]],
|
|
):
|
|
pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions(
|
|
predictions, [self.num_classes, 4, 1]
|
|
)
|
|
anchors = self.anchor_generator(features)
|
|
|
|
results: List[Instances] = []
|
|
for img_idx, image_size in enumerate(images.image_sizes):
|
|
scores_per_image = [
|
|
|
|
|
|
torch.sqrt(x[img_idx].sigmoid_() * y[img_idx].sigmoid_())
|
|
for x, y in zip(pred_logits, pred_centerness)
|
|
]
|
|
deltas_per_image = [x[img_idx] for x in pred_anchor_deltas]
|
|
results_per_image = self.inference_single_image(
|
|
anchors, scores_per_image, deltas_per_image, image_size
|
|
)
|
|
results.append(results_per_image)
|
|
return results
|
|
|
|
def inference_single_image(
|
|
self,
|
|
anchors: List[Boxes],
|
|
box_cls: List[torch.Tensor],
|
|
box_delta: List[torch.Tensor],
|
|
image_size: Tuple[int, int],
|
|
):
|
|
"""
|
|
Identical to :meth:`RetinaNet.inference_single_image.
|
|
"""
|
|
pred = self._decode_multi_level_predictions(
|
|
anchors,
|
|
box_cls,
|
|
box_delta,
|
|
self.test_score_thresh,
|
|
self.test_topk_candidates,
|
|
image_size,
|
|
)
|
|
keep = batched_nms(
|
|
pred.pred_boxes.tensor, pred.scores, pred.pred_classes, self.test_nms_thresh
|
|
)
|
|
return pred[keep[: self.max_detections_per_image]]
|
|
|
|
|
|
class FCOSHead(RetinaNetHead):
|
|
"""
|
|
The head used in :paper:`fcos`. It adds an additional centerness
|
|
prediction branch on top of :class:`RetinaNetHead`.
|
|
"""
|
|
|
|
def __init__(self, *, input_shape: List[ShapeSpec], conv_dims: List[int], **kwargs):
|
|
super().__init__(input_shape=input_shape, conv_dims=conv_dims, num_anchors=1, **kwargs)
|
|
|
|
|
|
self._num_features = len(input_shape)
|
|
self.ctrness = nn.Conv2d(conv_dims[-1], 1, kernel_size=3, stride=1, padding=1)
|
|
torch.nn.init.normal_(self.ctrness.weight, std=0.01)
|
|
torch.nn.init.constant_(self.ctrness.bias, 0)
|
|
|
|
def forward(self, features):
|
|
assert len(features) == self._num_features
|
|
logits = []
|
|
bbox_reg = []
|
|
ctrness = []
|
|
for feature in features:
|
|
logits.append(self.cls_score(self.cls_subnet(feature)))
|
|
bbox_feature = self.bbox_subnet(feature)
|
|
bbox_reg.append(self.bbox_pred(bbox_feature))
|
|
ctrness.append(self.ctrness(bbox_feature))
|
|
return logits, bbox_reg, ctrness
|
|
|