from typing import List
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d, ConvTranspose2d, cat, interpolate
from detectron2.structures import Instances, heatmaps_to_keypoints
from detectron2.utils.events import get_event_storage
from detectron2.utils.registry import Registry

_TOTAL_SKIPPED = 0

__all__ = [
    "ROI_KEYPOINT_HEAD_REGISTRY",
    "build_keypoint_head",
    "BaseKeypointRCNNHead",
    "KRCNNConvDeconvUpsampleHead",
]


ROI_KEYPOINT_HEAD_REGISTRY = Registry("ROI_KEYPOINT_HEAD")
ROI_KEYPOINT_HEAD_REGISTRY.__doc__ = """
Registry for keypoint heads, which make keypoint predictions from per-region features.

The registered object will be called with `obj(cfg, input_shape)`.
"""


def build_keypoint_head(cfg, input_shape):
    """
    Build a keypoint head from `cfg.MODEL.ROI_KEYPOINT_HEAD.NAME`.
    """
    name = cfg.MODEL.ROI_KEYPOINT_HEAD.NAME
    return ROI_KEYPOINT_HEAD_REGISTRY.get(name)(cfg, input_shape)
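
# Illustrative usage (not part of this module): a custom head would be registered and then
# selected through the config, e.g.
#
#   @ROI_KEYPOINT_HEAD_REGISTRY.register()
#   class MyKeypointHead(BaseKeypointRCNNHead):  # hypothetical name
#       def layers(self, x):
#           ...
#
#   cfg.MODEL.ROI_KEYPOINT_HEAD.NAME = "MyKeypointHead"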


def keypoint_rcnn_loss(pred_keypoint_logits, instances, normalizer):
    """
    Arguments:
        pred_keypoint_logits (Tensor): A tensor of shape (N, K, S, S) where N is the total number
            of instances in the batch, K is the number of keypoints, and S is the side length
            of the keypoint heatmap. The values are spatial logits.
        instances (list[Instances]): A list of M Instances, where M is the batch size.
            These instances are predictions from the model
            that are in 1:1 correspondence with pred_keypoint_logits.
            Each Instances should contain a `gt_keypoints` field containing a
            `structures.Keypoints` instance.
        normalizer (float): Normalize the loss by this amount.
            If not specified, we normalize by the number of visible keypoints in the minibatch.

    Returns a scalar tensor containing the loss.
    """
    heatmaps = []
    valid = []

    keypoint_side_len = pred_keypoint_logits.shape[2]
    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue
        keypoints = instances_per_image.gt_keypoints
        heatmaps_per_image, valid_per_image = keypoints.to_heatmap(
            instances_per_image.proposal_boxes.tensor, keypoint_side_len
        )
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))

    if len(heatmaps):
        keypoint_targets = cat(heatmaps, dim=0)
        valid = cat(valid, dim=0).to(dtype=torch.uint8)
        valid = torch.nonzero(valid).squeeze(1)
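
    # If the batch is empty or no keypoint in the batch is visible, there is nothing to
    # supervise: record the skipped batch in the event storage and return a zero loss
    # derived from the predictions so the output is still a valid tensor.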
    if len(heatmaps) == 0 or valid.numel() == 0:
        global _TOTAL_SKIPPED
        _TOTAL_SKIPPED += 1
        storage = get_event_storage()
        storage.put_scalar("kpts_num_skipped_batches", _TOTAL_SKIPPED, smoothing_hint=False)
        return pred_keypoint_logits.sum() * 0

    N, K, H, W = pred_keypoint_logits.shape
    pred_keypoint_logits = pred_keypoint_logits.view(N * K, H * W)
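
    # Each keypoint's S x S heatmap is treated as one (S*S)-way classification problem:
    # the targets produced by `to_heatmap` are the flattened indices of the ground-truth
    # bins, so a plain cross-entropy over spatial positions applies.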
    keypoint_loss = F.cross_entropy(
        pred_keypoint_logits[valid], keypoint_targets[valid], reduction="sum"
    )

    if normalizer is None:
        normalizer = valid.numel()
    keypoint_loss /= normalizer

    return keypoint_loss


def keypoint_rcnn_inference(pred_keypoint_logits: torch.Tensor, pred_instances: List[Instances]):
    """
    Post process each predicted keypoint heatmap in `pred_keypoint_logits` into (x, y, score)
    and add it to the `pred_instances` as a `pred_keypoints` field.

    Args:
        pred_keypoint_logits (Tensor): A tensor of shape (R, K, S, S) where R is the total number
            of instances in the batch, K is the number of keypoints, and S is the side length of
            the keypoint heatmap. The values are spatial logits.
        pred_instances (list[Instances]): A list of N Instances, where N is the number of images.

    Returns:
        None. Each element in pred_instances will contain extra "pred_keypoints" and
        "pred_keypoint_heatmaps" fields. "pred_keypoints" is a tensor of shape
        (#instance, K, 3) where the last dimension corresponds to (x, y, score).
        The scores are larger than 0. "pred_keypoint_heatmaps" contains the raw
        keypoint logits as passed to this function.
    """
    bboxes_flat = cat([b.pred_boxes.tensor for b in pred_instances], dim=0)

    pred_keypoint_logits = pred_keypoint_logits.detach()
    keypoint_results = heatmaps_to_keypoints(pred_keypoint_logits, bboxes_flat.detach())
    num_instances_per_image = [len(i) for i in pred_instances]
    keypoint_results = keypoint_results[:, :, [0, 1, 3]].split(num_instances_per_image, dim=0)
    heatmap_results = pred_keypoint_logits.split(num_instances_per_image, dim=0)

    for keypoint_results_per_image, heatmap_results_per_image, instances_per_image in zip(
        keypoint_results, heatmap_results, pred_instances
    ):
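        # keypoint_results_per_image has shape (num instances) x K x (x, y, score)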
        instances_per_image.pred_keypoints = keypoint_results_per_image
        instances_per_image.pred_keypoint_heatmaps = heatmap_results_per_image


class BaseKeypointRCNNHead(nn.Module):
    """
    Implement the basic Keypoint R-CNN losses and inference logic described in
    Sec. 5 of :paper:`Mask R-CNN`.
    """

    @configurable
    def __init__(self, *, num_keypoints, loss_weight=1.0, loss_normalizer=1.0):
        """
        NOTE: this interface is experimental.

        Args:
            num_keypoints (int): number of keypoints to predict
            loss_weight (float): weight to multiply on the keypoint loss
            loss_normalizer (float or str):
                If float, divide the loss by `loss_normalizer * #images`.
                If 'visible', the loss is normalized by the total number of
                visible keypoints across images.
        """
        super().__init__()
        self.num_keypoints = num_keypoints
        self.loss_weight = loss_weight
        assert loss_normalizer == "visible" or isinstance(loss_normalizer, float), loss_normalizer
        self.loss_normalizer = loss_normalizer

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = {
            "loss_weight": cfg.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT,
            "num_keypoints": cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS,
        }
        normalize_by_visible = (
            cfg.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS
        )
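        # When not normalizing by visible keypoints, the normalizer is the expected number of
        # keypoint targets per image. As a rough illustration (assuming typical COCO-style
        # settings of 17 keypoints, 512 proposals per image, and a 0.25 positive fraction),
        # this would be 17 * 512 * 0.25 = 2176.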
        if not normalize_by_visible:
            batch_size_per_image = cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE
            positive_sample_fraction = cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION
            ret["loss_normalizer"] = (
                ret["num_keypoints"] * batch_size_per_image * positive_sample_fraction
            )
        else:
            ret["loss_normalizer"] = "visible"
        return ret

    def forward(self, x, instances: List[Instances]):
        """
        Args:
            x: input 4D region feature(s) provided by :class:`ROIHeads`.
            instances (list[Instances]): contains the boxes & labels corresponding
                to the input features.
                Exact format is up to its caller to decide.
                Typically, this is the foreground instances in training, with
                "proposal_boxes" field and other gt annotations.
                In inference, it contains boxes that are already predicted.

        Returns:
            A dict of losses if in training. The predicted "instances" if in inference.
        """
        x = self.layers(x)
        if self.training:
            num_images = len(instances)
            normalizer = (
                None if self.loss_normalizer == "visible" else num_images * self.loss_normalizer
            )
            return {
                "loss_keypoint": keypoint_rcnn_loss(x, instances, normalizer=normalizer)
                * self.loss_weight
            }
        else:
            keypoint_rcnn_inference(x, instances)
            return instances

    def layers(self, x):
        """
        Neural network layers that make predictions from regional input features.
        """
        raise NotImplementedError


@ROI_KEYPOINT_HEAD_REGISTRY.register()
class KRCNNConvDeconvUpsampleHead(BaseKeypointRCNNHead, nn.Sequential):
    """
    A standard keypoint head containing a series of 3x3 convs, followed by
    a transpose convolution and bilinear interpolation for upsampling.
    It is described in Sec. 5 of :paper:`Mask R-CNN`.
    """

    @configurable
    def __init__(self, input_shape, *, num_keypoints, conv_dims, **kwargs):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature
            conv_dims: an iterable of output channel counts for each conv in the head,
                e.g. (512, 512, 512) for three convs outputting 512 channels.
        """
        super().__init__(num_keypoints=num_keypoints, **kwargs)

        up_scale = 2.0
        in_channels = input_shape.channels

        for idx, layer_channels in enumerate(conv_dims, 1):
            module = Conv2d(in_channels, layer_channels, 3, stride=1, padding=1)
            self.add_module("conv_fcn{}".format(idx), module)
            self.add_module("conv_fcn_relu{}".format(idx), nn.ReLU())
            in_channels = layer_channels

        deconv_kernel = 4
        self.score_lowres = ConvTranspose2d(
            in_channels, num_keypoints, deconv_kernel, stride=2, padding=deconv_kernel // 2 - 1
        )
        self.up_scale = up_scale
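
        # Zero-initialize biases and use Kaiming (He) normal initialization for all
        # conv/deconv weights.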
        for name, param in self.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = super().from_config(cfg, input_shape)
        ret["input_shape"] = input_shape
        ret["conv_dims"] = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS
        return ret

    def layers(self, x):
        for layer in self:
            x = layer(x)
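        # The stride-2 deconv doubles the heatmap resolution and the bilinear interpolation
        # below doubles it again, so the output heatmaps are 4x the side length of the input
        # RoI features (e.g., a 14x14 pooled feature yields a 56x56 heatmap).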
        x = interpolate(x, scale_factor=self.up_scale, mode="bilinear", align_corners=False)
        return x