|
import mmcv |
|
import torch |
|
|
|
from mmdet.models.roi_heads.mask_heads import FCNMaskHead, MaskIoUHead |
|
from .utils import _dummy_bbox_sampling |
|
|
|
|
|
def test_mask_head_loss():
    """Test FCNMaskHead and MaskIoUHead losses with one ground-truth instance.

    A single proposal overlapping a single ground-truth box (class 2) is
    sampled, a random binary ground-truth mask is generated, and both heads
    are run on random RoI features for the positive samples. The mask loss
    must be strictly positive and the mask-IoU loss must be non-negative.
    """
    import numpy as np

    from mmdet.core import BitmapMasks

    mask_head = FCNMaskHead(
        num_convs=1,
        roi_feat_size=6,
        in_channels=8,
        conv_out_channels=8,
        num_classes=8)

    # One proposal and one ground-truth box.
    proposal_list = [
        torch.Tensor([[23.6667, 23.8757, 228.6326, 153.8874]]),
    ]
    gt_bboxes = [
        torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
    ]
    gt_labels = [torch.LongTensor([2])]
    sampling_results = _dummy_bbox_sampling(proposal_list, gt_bboxes,
                                            gt_labels)

    # Random binary mask for the single gt instance on a 160x240 image.
    # A dedicated seeded RNG keeps the test reproducible without touching
    # NumPy's global random state.
    rng = np.random.RandomState(0)
    dummy_mask = rng.randint(0, 2, (1, 160, 240), dtype=np.uint8)
    gt_masks = [BitmapMasks(dummy_mask, 160, 240)]

    # NOTE(review): mask_size=12 presumably matches the head's output
    # resolution for 6x6 RoI features; confirm against FCNMaskHead defaults.
    train_cfg = mmcv.Config(dict(mask_size=12, mask_thr_binary=0.5))

    # Mask targets and labels exist only for positive samples, so the fake
    # RoI features must be sized by the positives (not all sampled boxes)
    # for predictions, targets and labels to line up row-for-row.
    num_pos = sum(len(res.pos_bboxes) for res in sampling_results)
    dummy_feats = torch.rand(num_pos, 8, 6, 6)

    mask_pred = mask_head(dummy_feats)
    mask_targets = mask_head.get_targets(sampling_results, gt_masks,
                                         train_cfg)
    pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
    loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels)

    # `.sum()` works for both 0-d and 1-element loss tensors, whereas the
    # builtin `sum()` would fail when iterating a 0-d tensor.
    onegt_mask_loss = loss_mask['loss_mask'].sum()
    assert onegt_mask_loss.item() > 0, 'mask loss should be non-zero'

    # Mask-IoU head on top of the class-specific mask predictions.
    mask_iou_head = MaskIoUHead(
        num_convs=1,
        num_fcs=1,
        roi_feat_size=6,
        in_channels=8,
        conv_out_channels=8,
        fc_out_channels=8,
        num_classes=8)

    # Pick, for every positive row, the channel of its ground-truth class.
    row_inds = torch.arange(num_pos)
    pos_mask_pred = mask_pred[row_inds, pos_labels]
    mask_iou_pred = mask_iou_head(dummy_feats, pos_mask_pred)
    pos_mask_iou_pred = mask_iou_pred[row_inds, pos_labels]

    mask_iou_targets = mask_iou_head.get_targets(sampling_results, gt_masks,
                                                 pos_mask_pred, mask_targets,
                                                 train_cfg)
    loss_mask_iou = mask_iou_head.loss(pos_mask_iou_pred, mask_iou_targets)
    onegt_mask_iou_loss = loss_mask_iou['loss_mask_iou'].sum()
    assert onegt_mask_iou_loss.item() >= 0, \
        'mask IoU loss should be non-negative'
|
|