from typing import List

import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d, ConvTranspose2d, ShapeSpec, cat, get_norm
from detectron2.layers.wrappers import move_device_like
from detectron2.structures import Instances
from detectron2.utils.events import get_event_storage
from detectron2.utils.registry import Registry

__all__ = [
    "BaseMaskRCNNHead",
    "MaskRCNNConvUpsampleHead",
    "build_mask_head",
    "ROI_MASK_HEAD_REGISTRY",
]


ROI_MASK_HEAD_REGISTRY = Registry("ROI_MASK_HEAD")
ROI_MASK_HEAD_REGISTRY.__doc__ = """
Registry for mask heads, which predict instance masks given
per-region features.

The registered object will be called with `obj(cfg, input_shape)`.
"""


@torch.jit.unused
def mask_rcnn_loss(pred_mask_logits: torch.Tensor, instances: List[Instances], vis_period: int = 0):
    """
    Compute the mask prediction loss defined in the Mask R-CNN paper.

    Args:
        pred_mask_logits (Tensor): A tensor of shape (B, C, Hmask, Wmask) or (B, 1, Hmask, Wmask)
            for class-specific or class-agnostic, where B is the total number of predicted masks
            in all images, C is the number of foreground classes, and Hmask, Wmask are the height
            and width of the mask predictions. The values are logits.
        instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. These instances are in 1:1 correspondence with the pred_mask_logits.
            The ground-truth labels (class, box, mask, ...) associated with each instance are
            stored in fields.
        vis_period (int): the period (in steps) to dump visualization.

    Returns:
        mask_loss (Tensor): A scalar tensor containing the loss.
    """
    cls_agnostic_mask = pred_mask_logits.size(1) == 1
    total_num_masks = pred_mask_logits.size(0)
    mask_side_len = pred_mask_logits.size(2)
    assert pred_mask_logits.size(2) == pred_mask_logits.size(3), "Mask prediction must be square!"

    gt_classes = []
    gt_masks = []
    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue
        if not cls_agnostic_mask:
            gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64)
            gt_classes.append(gt_classes_per_image)

        gt_masks_per_image = instances_per_image.gt_masks.crop_and_resize(
            instances_per_image.proposal_boxes.tensor, mask_side_len
        ).to(device=pred_mask_logits.device)
        # gt_masks_per_image: (num_instances_in_image, mask_side_len, mask_side_len)
        gt_masks.append(gt_masks_per_image)

    if len(gt_masks) == 0:
        return pred_mask_logits.sum() * 0

    gt_masks = cat(gt_masks, dim=0)

    if cls_agnostic_mask:
        pred_mask_logits = pred_mask_logits[:, 0]
    else:
        # Gather, for each predicted mask, the channel of its ground-truth class.
        indices = torch.arange(total_num_masks)
        gt_classes = cat(gt_classes, dim=0)
        pred_mask_logits = pred_mask_logits[indices, gt_classes]

    if gt_masks.dtype == torch.bool:
        gt_masks_bool = gt_masks
    else:
        # gt_masks may be float (e.g. from a soft rasterization); binarize for the metrics below.
        gt_masks_bool = gt_masks > 0.5
    gt_masks = gt_masks.to(dtype=torch.float32)

    # Log the training accuracy, using a 0.0 logit (i.e. 0.5 probability) threshold.
    mask_incorrect = (pred_mask_logits > 0.0) != gt_masks_bool
    mask_accuracy = 1 - (mask_incorrect.sum().item() / max(mask_incorrect.numel(), 1.0))
    num_positive = gt_masks_bool.sum().item()
    false_positive = (mask_incorrect & ~gt_masks_bool).sum().item() / max(
        gt_masks_bool.numel() - num_positive, 1.0
    )
    false_negative = (mask_incorrect & gt_masks_bool).sum().item() / max(num_positive, 1.0)

    storage = get_event_storage()
    storage.put_scalar("mask_rcnn/accuracy", mask_accuracy)
    storage.put_scalar("mask_rcnn/false_positive", false_positive)
    storage.put_scalar("mask_rcnn/false_negative", false_negative)
    if vis_period > 0 and storage.iter % vis_period == 0:
        pred_masks = pred_mask_logits.sigmoid()
        vis_masks = torch.cat([pred_masks, gt_masks], axis=2)
        name = "Left: mask prediction; Right: mask GT"
        for idx, vis_mask in enumerate(vis_masks):
            vis_mask = torch.stack([vis_mask] * 3, axis=0)
            storage.put_image(name + f" ({idx})", vis_mask)

    mask_loss = F.binary_cross_entropy_with_logits(pred_mask_logits, gt_masks, reduction="mean")
    return mask_loss
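
# Usage sketch (illustrative; the shapes below are assumptions for demonstration, not a contract
# beyond what the docstring states). With 8 foreground proposals across the batch, 80 classes and
# 28x28 mask predictions, a call would look like:
#
#   pred_mask_logits = torch.randn(8, 80, 28, 28)    # (B, C, Hmask, Wmask), raw logits
#   loss = mask_rcnn_loss(pred_mask_logits, proposals, vis_period=0)
#
# where `proposals` is a list[Instances] whose per-image entries carry the "gt_classes",
# "gt_masks" and "proposal_boxes" fields used above.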


def mask_rcnn_inference(pred_mask_logits: torch.Tensor, pred_instances: List[Instances]):
    """
    Convert pred_mask_logits to estimated foreground probability masks while also
    extracting only the masks for the predicted classes in pred_instances. For each
    predicted box, the mask of the same class is attached to the instance by adding a
    new "pred_masks" field to pred_instances.

    Args:
        pred_mask_logits (Tensor): A tensor of shape (B, C, Hmask, Wmask) or (B, 1, Hmask, Wmask)
            for class-specific or class-agnostic, where B is the total number of predicted masks
            in all images, C is the number of foreground classes, and Hmask, Wmask are the height
            and width of the mask predictions. The values are logits.
        pred_instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. Each Instances must have field "pred_classes".

    Returns:
        None. pred_instances will contain an extra "pred_masks" field storing a mask of size
            (Hmask, Wmask) for the predicted class. Note that the masks are returned as soft
            (non-quantized) masks at the resolution predicted by the network; post-processing
            steps, such as resizing the predicted masks to the original image resolution and/or
            binarizing them, are left to the caller.
    """
    cls_agnostic_mask = pred_mask_logits.size(1) == 1

    if cls_agnostic_mask:
        mask_probs_pred = pred_mask_logits.sigmoid()
    else:
        # Select, for each predicted box, the mask channel of its predicted class.
        num_masks = pred_mask_logits.shape[0]
        class_pred = cat([i.pred_classes for i in pred_instances])
        device = (
            class_pred.device
            if torch.jit.is_scripting()
            else ("cpu" if torch.jit.is_tracing() else class_pred.device)
        )
        indices = move_device_like(torch.arange(num_masks, device=device), class_pred)
        mask_probs_pred = pred_mask_logits[indices, class_pred][:, None].sigmoid()
    # mask_probs_pred has shape (B, 1, Hmask, Wmask)

    num_boxes_per_image = [len(i) for i in pred_instances]
    mask_probs_pred = mask_probs_pred.split(num_boxes_per_image, dim=0)

    for prob, instances in zip(mask_probs_pred, pred_instances):
        instances.pred_masks = prob
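
# Usage sketch (illustrative post-processing by the caller; not performed here). After this call
# each Instances has a "pred_masks" field of shape (num_instances, 1, Hmask, Wmask) holding
# probabilities in [0, 1]. The target size (H, W) below is an assumed image resolution:
#
#   probs = instances.pred_masks                                        # (N, 1, Hmask, Wmask)
#   full_res = F.interpolate(probs, size=(H, W), mode="bilinear", align_corners=False)
#   binary_masks = full_res[:, 0] > 0.5                                 # (N, H, W), boolean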


class BaseMaskRCNNHead(nn.Module):
    """
    Implement the basic Mask R-CNN losses and inference logic described in :paper:`Mask R-CNN`
    """

    @configurable
    def __init__(self, *, loss_weight: float = 1.0, vis_period: int = 0):
        """
        NOTE: this interface is experimental.

        Args:
            loss_weight (float): multiplier of the loss
            vis_period (int): visualization period
        """
        super().__init__()
        self.vis_period = vis_period
        self.loss_weight = loss_weight

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {"vis_period": cfg.VIS_PERIOD}

    def forward(self, x, instances: List[Instances]):
        """
        Args:
            x: input region feature(s) provided by :class:`ROIHeads`.
            instances (list[Instances]): contains the boxes & labels corresponding
                to the input features.
                Exact format is up to its caller to decide.
                Typically, this is the foreground instances in training, with
                "proposal_boxes" field and other gt annotations.
                In inference, it contains boxes that are already predicted.

        Returns:
            A dict of losses in training. The predicted "instances" in inference.
        """
        x = self.layers(x)
        if self.training:
            return {"loss_mask": mask_rcnn_loss(x, instances, self.vis_period) * self.loss_weight}
        else:
            mask_rcnn_inference(x, instances)
            return instances

    def layers(self, x):
        """
        Neural network layers that make predictions from input features.
        """
        raise NotImplementedError
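
# Sketch of a minimal subclass (illustrative; `TinyMaskHead` and its single 1x1 conv are
# hypothetical, not part of detectron2). A subclass only has to implement `layers`; the shared
# `forward` above then applies `mask_rcnn_loss` in training and `mask_rcnn_inference` at test
# time. Making it buildable from a config additionally requires the `@configurable` /
# `from_config` pattern used by `MaskRCNNConvUpsampleHead` below.
#
#   class TinyMaskHead(BaseMaskRCNNHead):
#       def __init__(self, in_channels: int, num_classes: int, **kwargs):
#           super().__init__(**kwargs)
#           self.predictor = Conv2d(in_channels, num_classes, kernel_size=1)
#
#       def layers(self, x):
#           return self.predictor(x)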


@ROI_MASK_HEAD_REGISTRY.register()
class MaskRCNNConvUpsampleHead(BaseMaskRCNNHead, nn.Sequential):
    """
    A mask head with several conv layers, plus an upsample layer (with `ConvTranspose2d`).
    Predictions are made with a final 1x1 conv layer.
    """

    @configurable
    def __init__(self, input_shape: ShapeSpec, *, num_classes, conv_dims, conv_norm="", **kwargs):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature
            num_classes (int): the number of foreground classes (i.e. background is not
                included). 1 if using class agnostic prediction.
            conv_dims (list[int]): a list of N>0 integers representing the output dimensions
                of N-1 conv layers and the last upsample layer.
            conv_norm (str or callable): normalization for the conv layers.
                See :func:`detectron2.layers.get_norm` for supported types.
        """
        super().__init__(**kwargs)
        assert len(conv_dims) >= 1, "conv_dims have to be non-empty!"

        self.conv_norm_relus = []

        cur_channels = input_shape.channels
        for k, conv_dim in enumerate(conv_dims[:-1]):
            conv = Conv2d(
                cur_channels,
                conv_dim,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=not conv_norm,
                norm=get_norm(conv_norm, conv_dim),
                activation=nn.ReLU(),
            )
            self.add_module("mask_fcn{}".format(k + 1), conv)
            self.conv_norm_relus.append(conv)
            cur_channels = conv_dim

        # The last entry of conv_dims is used by the 2x upsampling ConvTranspose2d.
        self.deconv = ConvTranspose2d(
            cur_channels, conv_dims[-1], kernel_size=2, stride=2, padding=0
        )
        self.add_module("deconv_relu", nn.ReLU())
        cur_channels = conv_dims[-1]

        self.predictor = Conv2d(cur_channels, num_classes, kernel_size=1, stride=1, padding=0)

        for layer in self.conv_norm_relus + [self.deconv]:
            weight_init.c2_msra_fill(layer)
        # Use normal distribution initialization for the mask prediction layer.
        nn.init.normal_(self.predictor.weight, std=0.001)
        if self.predictor.bias is not None:
            nn.init.constant_(self.predictor.bias, 0)

    @classmethod
    def from_config(cls, cfg, input_shape):
        ret = super().from_config(cfg, input_shape)
        conv_dim = cfg.MODEL.ROI_MASK_HEAD.CONV_DIM
        num_conv = cfg.MODEL.ROI_MASK_HEAD.NUM_CONV
        ret.update(
            # num_conv 3x3 convs, plus one extra dim for the ConvTranspose2d upsample layer.
            conv_dims=[conv_dim] * (num_conv + 1),
            conv_norm=cfg.MODEL.ROI_MASK_HEAD.NORM,
            input_shape=input_shape,
        )
        if cfg.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK:
            ret["num_classes"] = 1
        else:
            ret["num_classes"] = cfg.MODEL.ROI_HEADS.NUM_CLASSES
        return ret

    def layers(self, x):
        for layer in self:
            x = layer(x)
        return x
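
# Worked configuration example (illustrative numbers; the values are assumptions, not defaults
# asserted here). With cfg.MODEL.ROI_MASK_HEAD.NUM_CONV = 4 and CONV_DIM = 256, from_config
# yields conv_dims = [256] * 5, i.e. four 3x3 convs ("mask_fcn1".."mask_fcn4"), one 2x2 stride-2
# ConvTranspose2d that doubles the spatial size, and a final 1x1 predictor with one channel per
# foreground class (or a single channel when CLS_AGNOSTIC_MASK is enabled).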


def build_mask_head(cfg, input_shape):
    """
    Build a mask head defined by `cfg.MODEL.ROI_MASK_HEAD.NAME`.
    """
    name = cfg.MODEL.ROI_MASK_HEAD.NAME
    return ROI_MASK_HEAD_REGISTRY.get(name)(cfg, input_shape)
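
# Usage sketch (illustrative): building a head directly from a config. The 14x14 region feature
# shape is an assumption (it depends on the ROI pooler settings), not something this function
# checks.
#
#   from detectron2.config import get_cfg
#
#   cfg = get_cfg()
#   cfg.MODEL.ROI_MASK_HEAD.NAME = "MaskRCNNConvUpsampleHead"
#   head = build_mask_head(cfg, ShapeSpec(channels=256, height=14, width=14))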