import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
from mmcv.ops import DeformConv2d, CornerPool
from mmdet.core import (PointGenerator, build_assigner, build_sampler,
images_to_levels, multi_apply, multiclass_nms_rpd, unmap)
from ..builder import HEADS, build_loss, build_head
from .anchor_free_head import AnchorFreeHead
from mmdet.utils.instances import Instances
from mmdet.utils.common import compute_locations
from itertools import accumulate
class CornerPoolPack(nn.Module):
def __init__(self, dim, pool1, pool2, conv_cfg=None, norm_cfg=None, first_kernel_size=3, kernel_size=3, corner_dim=128):
super(CornerPoolPack, self).__init__()
self.p1_conv1 = ConvModule(
dim,
corner_dim,
first_kernel_size,
stride=1,
padding=(first_kernel_size - 1) // 2,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg)
self.p2_conv1 = ConvModule(
dim,
corner_dim,
first_kernel_size,
stride=1,
padding=(first_kernel_size - 1) // 2,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg)
self.p_conv1 = nn.Conv2d(corner_dim, dim, 3, padding=1, bias=False)
self.p_gn1 = nn.GroupNorm(num_groups=32, num_channels=dim)
self.conv1 = nn.Conv2d(dim, dim, 1, bias=False)
self.gn1 = nn.GroupNorm(num_groups=32, num_channels=dim)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = ConvModule(
dim,
dim,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg)
self.pool1 = pool1
self.pool2 = pool2
def forward(self, x):
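        """Fuse two directional corner pools with a 1x1 residual branch."""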
# pool 1
p1_conv1 = self.p1_conv1(x)
pool1 = self.pool1(p1_conv1)
# pool 2
p2_conv1 = self.p2_conv1(x)
pool2 = self.pool2(p2_conv1)
# pool 1 + pool 2
p_conv1 = self.p_conv1(pool1 + pool2)
p_gn1 = self.p_gn1(p_conv1)
conv1 = self.conv1(x)
gn1 = self.gn1(conv1)
relu1 = self.relu1(p_gn1 + gn1)
conv2 = self.conv2(relu1)
return conv2
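# Example (sketch): the subclasses below pair CornerPoolPack with the two
# mmcv CornerPool directions for each corner, e.g.
#   out = TLPool(256)(torch.rand(2, 256, 32, 32))  # -> (2, 256, 32, 32)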
class TLPool(CornerPoolPack):
def __init__(self, dim, conv_cfg=None, norm_cfg=None, first_kernel_size=3, kernel_size=3, corner_dim=128):
super(TLPool, self).__init__(dim, CornerPool('top'), CornerPool('left'), conv_cfg, norm_cfg, first_kernel_size, kernel_size, corner_dim)
class BRPool(CornerPoolPack):
def __init__(self, dim, conv_cfg=None, norm_cfg=None, first_kernel_size=3, kernel_size=3, corner_dim=128):
super(BRPool, self).__init__(dim, CornerPool('bottom'), CornerPool('right'), conv_cfg, norm_cfg, first_kernel_size, kernel_size, corner_dim)
@HEADS.register_module()
class RepPointsV2Head(AnchorFreeHead):
"""RepPointV2 head. With mask branch support
Args:
point_feat_channels (int): Number of channels of points features.
gradient_mul (float): The multiplier to gradients from
points refinement and recognition.
point_strides (Iterable): points strides.
point_base_scale (int): bbox scale for assigning labels.
loss_cls (dict): Config of classification loss.
loss_bbox_init (dict): Config of initial points loss.
loss_bbox_refine (dict): Config of points loss in refinement.
        use_grid_points (bool): If True, the RepPoints are represented as
            grid points on the bounding box.
center_init (bool): Whether to use center point assignment.
transform_method (str): The methods to transform RepPoints to bbox.
""" # noqa: W605
def __init__(self,
num_classes,
in_channels,
point_feat_channels=256,
shared_stacked_convs=1,
first_kernel_size=3,
kernel_size=1,
corner_dim=64,
num_points=9,
gradient_mul=0.1,
point_strides=[8, 16, 32, 64, 128],
point_base_scale=4,
corner_refine=True,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox_init=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
loss_bbox_refine=dict(
type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
loss_heatmap=dict(
type='GaussianFocalLoss',
alpha=2.0,
gamma=4.0,
loss_weight=0.25),
loss_offset=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
loss_sem=dict(type='SEPFocalLoss', use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=0.1),
use_grid_points=False,
center_init=True,
transform_method='moment',
moment_mul=0.01,
mask_head=None,
background_label=None,
**kwargs):
self.background_label = (
num_classes if background_label is None else background_label)
# background_label should be either 0 or num_classes
assert (self.background_label == 0
or self.background_label == num_classes)
self.num_points = num_points
self.point_feat_channels = point_feat_channels
self.shared_stacked_convs = shared_stacked_convs
self.use_grid_points = use_grid_points
self.center_init = center_init
self.first_kernel_size = first_kernel_size
self.kernel_size = kernel_size
self.corner_dim = corner_dim
self.corner_refine = corner_refine
# we use deformable conv to extract points features
self.dcn_kernel = int(np.sqrt(num_points))
self.dcn_pad = int((self.dcn_kernel - 1) / 2)
assert self.dcn_kernel * self.dcn_kernel == num_points, \
'The points number should be a square number.'
assert self.dcn_kernel % 2 == 1, \
'The points number should be an odd square number.'
dcn_base = np.arange(-self.dcn_pad,
self.dcn_pad + 1).astype(np.float64)
dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
(-1))
self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)
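        # With num_points=9 this is the 3x3 grid of (y, x) offsets
        # (-1,-1), (-1,0), ..., (1,1) flattened to 18 channels; predicted
        # points minus this base give the DeformConv2d offsets.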
self.controller_on = kwargs.pop('controller_on', 'cls')
self.coord_pos = kwargs.pop('coord_pos', 'grid')
if mask_head is not None and self.coord_pos in ['grid', 'center']:
mask_head.head_cfg.rel_num = 1
elif mask_head is not None and self.coord_pos in ['center-lt-rb']:
mask_head.head_cfg.rel_num = 3
super().__init__(num_classes, in_channels, loss_cls=loss_cls, **kwargs)
self.gradient_mul = gradient_mul
self.point_base_scale = point_base_scale
self.point_strides = point_strides
self.point_generators = [PointGenerator() for _ in self.point_strides]
if self.train_cfg:
self.init_assigner = build_assigner(self.train_cfg.init.assigner)
self.refine_assigner = build_assigner(self.train_cfg.refine.assigner)
self.hm_assigner = build_assigner(self.train_cfg.heatmap.assigner)
# use PseudoSampler when sampling is False
sampler_cfg = dict(type='PseudoSampler')
self.sampler = build_sampler(sampler_cfg, context=self)
self.transform_method = transform_method
if self.transform_method == 'moment':
self.moment_transfer = nn.Parameter(data=torch.zeros(2), requires_grad=True)
self.moment_mul = moment_mul
self.cls_out_channels = self.num_classes
self.loss_bbox_init = build_loss(loss_bbox_init)
self.loss_bbox_refine = build_loss(loss_bbox_refine)
self.loss_heatmap = build_loss(loss_heatmap)
self.loss_offset = build_loss(loss_offset)
self.loss_sem = build_loss(loss_sem)
# mask
self.mask_head = mask_head
if mask_head is not None:
self.mask_head = build_head(mask_head)
self.controller = nn.Conv2d(
point_feat_channels, self.mask_head.head.num_gen_params,
kernel_size=3, stride=1, padding=1
)
torch.nn.init.normal_(self.controller.weight, std=0.01)
torch.nn.init.constant_(self.controller.bias, 0)
def _init_layers(self):
"""Initialize layers of the head."""
self.relu = nn.ReLU(inplace=True)
self.cls_convs = nn.ModuleList()
self.reg_convs = nn.ModuleList()
self.shared_convs = nn.ModuleList()
for i in range(self.stacked_convs):
chn = self.in_channels if i == 0 else self.feat_channels
self.cls_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.reg_convs.append(
ConvModule(
chn,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
for i in range(self.shared_stacked_convs):
self.shared_convs.append(
ConvModule(
self.feat_channels,
self.feat_channels,
3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.hem_tl = TLPool(self.feat_channels, self.conv_cfg, self.norm_cfg, first_kernel_size=self.first_kernel_size, kernel_size=self.kernel_size, corner_dim=self.corner_dim)
self.hem_br = BRPool(self.feat_channels, self.conv_cfg, self.norm_cfg, first_kernel_size=self.first_kernel_size, kernel_size=self.kernel_size, corner_dim=self.corner_dim)
pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points
cls_in_channels = self.feat_channels + 6
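        # pts_out_dim: one (y, x) pair per point, e.g. 9 points -> 18
        # channels (grid mode regresses a 4-d box instead); the extra 6
        # input channels are the 2 corner heatmap scores plus 4 corner
        # offsets concatenated onto the features in forward_single.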
self.reppoints_cls_conv = DeformConv2d(cls_in_channels,
self.point_feat_channels,
self.dcn_kernel, 1, self.dcn_pad)
self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
self.cls_out_channels, 1, 1, 0)
self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
self.point_feat_channels, 3, 1, 1)
self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
pts_out_dim, 1, 1, 0)
pts_in_channels = self.feat_channels + 6
self.reppoints_pts_refine_conv = DeformConv2d(pts_in_channels,
self.point_feat_channels,
self.dcn_kernel, 1,
self.dcn_pad)
self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
pts_out_dim, 1, 1, 0)
self.reppoints_hem_tl_score_out = nn.Conv2d(self.feat_channels, 1, 3, 1, 1)
self.reppoints_hem_br_score_out = nn.Conv2d(self.feat_channels, 1, 3, 1, 1)
self.reppoints_hem_tl_offset_out = nn.Conv2d(self.feat_channels, 2, 3, 1, 1)
self.reppoints_hem_br_offset_out = nn.Conv2d(self.feat_channels, 2, 3, 1, 1)
self.reppoints_sem_out = nn.Conv2d(self.feat_channels, self.cls_out_channels, 1, 1, 0)
self.reppoints_sem_embedding = ConvModule(
self.feat_channels,
self.feat_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg)
def init_weights(self):
"""Initialize weights of the head."""
for m in self.cls_convs:
normal_init(m.conv, std=0.01)
for m in self.reg_convs:
normal_init(m.conv, std=0.01)
for m in self.shared_convs:
normal_init(m.conv, std=0.01)
bias_cls = bias_init_with_prob(0.01)
normal_init(self.reppoints_cls_conv, std=0.01)
normal_init(self.reppoints_cls_out, std=0.01, bias=bias_cls)
normal_init(self.reppoints_pts_init_conv, std=0.01)
normal_init(self.reppoints_pts_init_out, std=0.01)
normal_init(self.reppoints_pts_refine_conv, std=0.01)
normal_init(self.reppoints_pts_refine_out, std=0.01)
normal_init(self.reppoints_hem_tl_score_out, std=0.01, bias=bias_cls)
normal_init(self.reppoints_hem_tl_offset_out, std=0.01)
normal_init(self.reppoints_hem_br_score_out, std=0.01, bias=bias_cls)
normal_init(self.reppoints_hem_br_offset_out, std=0.01)
normal_init(self.reppoints_sem_out, std=0.01, bias=bias_cls)
def points2bbox(self, pts, y_first=True):
"""Converting the points set into bounding box.
:param pts: the input points sets (fields), each points
set (fields) is represented as 2n scalar.
:param y_first: if y_fisrt=True, the point set is represented as
[y1, x1, y2, x2 ... yn, xn], otherwise the point set is
represented as [x1, y1, x2, y2 ... xn, yn].
:return: each points set is converting to a bbox [x1, y1, x2, y2].
"""
pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1, ...]
pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0, ...]
if self.transform_method == 'minmax':
bbox_left = pts_x.min(dim=1, keepdim=True)[0]
bbox_right = pts_x.max(dim=1, keepdim=True)[0]
bbox_up = pts_y.min(dim=1, keepdim=True)[0]
bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
dim=1)
elif self.transform_method == 'partial_minmax':
pts_y = pts_y[:, :4, ...]
pts_x = pts_x[:, :4, ...]
bbox_left = pts_x.min(dim=1, keepdim=True)[0]
bbox_right = pts_x.max(dim=1, keepdim=True)[0]
bbox_up = pts_y.min(dim=1, keepdim=True)[0]
bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
dim=1)
elif self.transform_method == 'moment':
pts_y_mean = pts_y.mean(dim=1, keepdim=True)
pts_x_mean = pts_x.mean(dim=1, keepdim=True)
pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True)
pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True)
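            # A learnable global transfer (mu_w, mu_h) turns the point std
            # into box half-extents: half_w = std_x * exp(mu_w); moment_mul
            # keeps most of the gradient off moment_transfer for stability.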
moment_transfer = (self.moment_transfer * self.moment_mul) + (
self.moment_transfer.detach() * (1 - self.moment_mul))
moment_width_transfer = moment_transfer[0]
moment_height_transfer = moment_transfer[1]
half_width = pts_x_std * torch.exp(moment_width_transfer)
half_height = pts_y_std * torch.exp(moment_height_transfer)
bbox = torch.cat([
pts_x_mean - half_width, pts_y_mean - half_height,
pts_x_mean + half_width, pts_y_mean + half_height
],
dim=1)
elif self.transform_method == "exact_minmax":
pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
pts_reshape = pts_reshape[:, :2, ...]
pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1, ...]
pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0, ...]
bbox_left = pts_x[:, 0:1, ...]
bbox_right = pts_x[:, 1:2, ...]
bbox_up = pts_y[:, 0:1, ...]
bbox_bottom = pts_y[:, 1:2, ...]
bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom], dim=1)
else:
raise NotImplementedError
return bbox
def gen_grid_from_reg(self, reg, previous_boxes):
"""Base on the previous bboxes and regression values, we compute the
regressed bboxes and generate the grids on the bboxes.
:param reg: the regression value to previous bboxes.
:param previous_boxes: previous bboxes.
:return: generate grids on the regressed bboxes.
"""
b, _, h, w = reg.shape
bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2.
bwh = (previous_boxes[:, 2:, ...] -
previous_boxes[:, :2, ...]).clamp(min=1e-6)
grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp(
reg[:, 2:, ...])
grid_wh = bwh * torch.exp(reg[:, 2:, ...])
grid_left = grid_topleft[:, [0], ...]
grid_top = grid_topleft[:, [1], ...]
grid_width = grid_wh[:, [0], ...]
grid_height = grid_wh[:, [1], ...]
        interval = torch.linspace(0., 1., self.dcn_kernel).view(
            1, self.dcn_kernel, 1, 1).type_as(reg)
        grid_x = grid_left + grid_width * interval
        grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)
        grid_x = grid_x.view(b, -1, h, w)
        grid_y = grid_top + grid_height * interval
        grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1)
        grid_y = grid_y.view(b, -1, h, w)
grid_yx = torch.stack([grid_y, grid_x], dim=2)
grid_yx = grid_yx.view(b, -1, h, w)
regressed_bbox = torch.cat([
grid_left, grid_top, grid_left + grid_width, grid_top + grid_height
], 1)
return grid_yx, regressed_bbox
def compute_coord_pos(self, points):
# points: N, HW, C
x, y = points[..., 0::2], points[..., 1::2]
out_x, out_y = None, None
if self.coord_pos == 'center':
out_x, out_y = x.mean(-1), y.mean(-1)
        else:
            raise NotImplementedError(
                f'coord_pos={self.coord_pos!r} is not supported')
return torch.stack([out_x.flatten(), out_y.flatten()], dim=-1)
def forward(self, feats):
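        """Forward all FPN levels. With a mask head attached, additionally
        cache per-location instance metadata (fpn_levels, im_inds, decoded
        point-center locations and controller parameters) on
        self.pred_instances for the dynamic mask branch."""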
if not self.mask_head:
return multi_apply(self.forward_single, feats)
else:
            self.pred_instances = Instances((0, 0))  # just a dummy image shape
bbox_res = multi_apply(self.forward_single, feats, controller=self.controller) # cls_out, pts_out_init, pts_out_refine, controller_out
device = bbox_res[1][0].device
self.level_sum = tuple(accumulate([0]+[r[:,0].numel() for r in bbox_res[1]][:-1]))
self.image_sum = torch.tensor([r[0,0].numel() for r in bbox_res[1]])
self.pred_instances.fpn_levels = torch.cat([torch.tensor([i]*r[:,0].numel()).to(device) for i,r in enumerate(bbox_res[1])])
self.pred_instances.im_inds = torch.cat([torch.cat([torch.tensor([i]*img.numel()).to(device) for i,img in enumerate(r[:,0])]) for r in bbox_res[1]])
pts_grid = torch.cat([compute_locations(*r.shape[-2:], self.point_strides[level], device).reshape(-1,2).repeat(len(r),1) for level, r in enumerate(bbox_res[1])], dim=0)
pts_center_list = []
for i_lvl in range(len(self.point_strides)):
pts_shift = bbox_res[1][i_lvl]
pts_shift = pts_shift.permute(0, 2, 3, 1)
yx_pts_shift = pts_shift.reshape(*pts_shift.shape[:-1], 2 * self.num_points)
y_pts_shift = yx_pts_shift[..., 0::2]
x_pts_shift = yx_pts_shift[..., 1::2]
xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
pts = xy_pts_shift * self.point_strides[i_lvl] # N, H, W, np, 2
pts_center = pts.mean(-2)
pts_center_list.append(pts_center)
self.pred_instances.locations = torch.cat([l.flatten(0, -2) for l in pts_center_list]) + pts_grid
self.pred_instances.mask_head_params = torch.cat([r.permute(0,2,3,1).flatten(0,-2) for r in bbox_res[-1]])
return bbox_res[:-1]
def forward_single(self, x, controller=None):
""" Forward feature map of a single FPN level."""
dcn_base_offset = self.dcn_base_offset.type_as(x)
        # If we use center_init, the initial reppoints come from center
        # points. If we use bounding bbox representation, the initial
        # reppoints come from a regular grid placed on a pre-defined bbox.
if self.use_grid_points or not self.center_init:
scale = self.point_base_scale / 2
points_init = dcn_base_offset / dcn_base_offset.max() * scale
bbox_init = x.new_tensor([-scale, -scale, scale,
scale]).view(1, 4, 1, 1)
else:
points_init = 0
cls_feat = x
pts_feat = x
for cls_conv in self.cls_convs:
cls_feat = cls_conv(cls_feat)
for reg_conv in self.reg_convs:
pts_feat = reg_conv(pts_feat)
shared_feat = pts_feat
for shared_conv in self.shared_convs:
shared_feat = shared_conv(shared_feat)
sem_feat = shared_feat
hem_feat = shared_feat
sem_scores_out = self.reppoints_sem_out(sem_feat)
sem_feat = self.reppoints_sem_embedding(sem_feat)
cls_feat = cls_feat + sem_feat
pts_feat = pts_feat + sem_feat
hem_feat = hem_feat + sem_feat
# generate heatmap and offset
hem_tl_feat = self.hem_tl(hem_feat)
hem_br_feat = self.hem_br(hem_feat)
hem_tl_score_out = self.reppoints_hem_tl_score_out(hem_tl_feat)
hem_tl_offset_out = self.reppoints_hem_tl_offset_out(hem_tl_feat)
hem_br_score_out = self.reppoints_hem_br_score_out(hem_br_feat)
hem_br_offset_out = self.reppoints_hem_br_offset_out(hem_br_feat)
hem_score_out = torch.cat([hem_tl_score_out, hem_br_score_out], dim=1)
hem_offset_out = torch.cat([hem_tl_offset_out, hem_br_offset_out], dim=1)
# initialize reppoints
init_reg_tower = self.relu(self.reppoints_pts_init_conv(pts_feat))
pts_out_init = self.reppoints_pts_init_out(init_reg_tower)
        if self.controller_on == 'init' and controller is not None:
            controller_out = controller(init_reg_tower)
if self.use_grid_points:
pts_out_init, bbox_out_init = self.gen_grid_from_reg(pts_out_init, bbox_init.detach())
else:
pts_out_init = pts_out_init + points_init
# refine and classify reppoints
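        # Blend detached and live offsets so that only a gradient_mul
        # fraction (default 0.1) of the refine/cls gradient flows back
        # into the initial point predictions.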
pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach() + self.gradient_mul * pts_out_init
dcn_offset = pts_out_init_grad_mul - dcn_base_offset
hem_feat = torch.cat([hem_score_out, hem_offset_out], dim=1)
cls_feat = torch.cat([cls_feat, hem_feat], dim=1)
pts_feat = torch.cat([pts_feat, hem_feat], dim=1)
cls_tower = self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset))
cls_out = self.reppoints_cls_out(cls_tower)
if self.mask_head and controller is not None and self.controller_on == 'cls':
controller_out = controller(cls_tower)
refine_reg_tower = self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset))
pts_out_refine = self.reppoints_pts_refine_out(refine_reg_tower)
        if self.controller_on == 'refine' and controller is not None:
            controller_out = controller(refine_reg_tower)
if self.use_grid_points:
pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(pts_out_refine, bbox_out_init.detach())
else:
pts_out_refine = pts_out_refine + pts_out_init.detach()
if self.mask_head:
return cls_out, pts_out_init, pts_out_refine, hem_score_out, hem_offset_out, sem_scores_out, controller_out
else:
return cls_out, pts_out_init, pts_out_refine, hem_score_out, hem_offset_out, sem_scores_out
def get_points(self, featmap_sizes, img_metas):
"""Get points according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
img_metas (list[dict]): Image meta info.
Returns:
tuple: points of each image, valid flags of each image
"""
num_imgs = len(img_metas)
num_levels = len(featmap_sizes)
        # since feature map sizes of all images are the same, we only
        # compute points centers once
multi_level_points = []
for i in range(num_levels):
points = self.point_generators[i].grid_points(
featmap_sizes[i], self.point_strides[i])
multi_level_points.append(points)
points_list = [[point.clone() for point in multi_level_points] for _ in range(num_imgs)]
# for each image, we compute valid flags of multi level grids
valid_flag_list = []
for img_id, img_meta in enumerate(img_metas):
multi_level_flags = []
for i in range(num_levels):
point_stride = self.point_strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = img_meta['pad_shape'][:2]
valid_feat_h = min(int(np.ceil(h / point_stride)), feat_h)
valid_feat_w = min(int(np.ceil(w / point_stride)), feat_w)
flags = self.point_generators[i].valid_flags(
(feat_h, feat_w), (valid_feat_h, valid_feat_w))
multi_level_flags.append(flags)
valid_flag_list.append(multi_level_flags)
return points_list, valid_flag_list
def centers_to_bboxes(self, point_list):
"""Get bboxes according to center points. Only used in MaxIOUAssigner.
"""
bbox_list = []
for i_img, point in enumerate(point_list):
bbox = []
for i_lvl in range(len(self.point_strides)):
scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5
bbox_shift = torch.Tensor([-scale, -scale, scale, scale]).view(1, 4).type_as(point[0])
bbox_center = torch.cat([point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
bbox.append(bbox_center + bbox_shift)
bbox_list.append(bbox)
return bbox_list
def offset_to_pts(self, center_list, pred_list):
"""Change from point offset to point coordinate."""
pts_list = []
for i_lvl in range(len(self.point_strides)):
pts_lvl = []
for i_img in range(len(center_list)):
pts_center = center_list[i_img][i_lvl][:, :2].repeat(
1, self.num_points)
pts_shift = pred_list[i_lvl][i_img]
yx_pts_shift = pts_shift.permute(1, 2, 0).view(
-1, 2 * self.num_points)
y_pts_shift = yx_pts_shift[..., 0::2]
x_pts_shift = yx_pts_shift[..., 1::2]
xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center
pts_lvl.append(pts)
pts_lvl = torch.stack(pts_lvl, 0)
pts_list.append(pts_lvl)
return pts_list
def _point_target_single(self,
flat_proposals,
valid_flags,
num_level_proposals,
gt_bboxes,
gt_bboxes_ignore,
gt_labels,
gt_masks=None,
label_channels=1,
stage='init',
unmap_outputs=True):
inside_flags = valid_flags
if not inside_flags.any():
            return (None, ) * 7
# assign gt and sample proposals
proposals = flat_proposals[inside_flags, :]
num_level_proposals_inside = self.get_num_level_proposals_inside(num_level_proposals, inside_flags)
if stage == 'init':
assigner = self.init_assigner
assigner_type = self.train_cfg.init.assigner.type
pos_weight = self.train_cfg.init.pos_weight
else:
assigner = self.refine_assigner
assigner_type = self.train_cfg.refine.assigner.type
pos_weight = self.train_cfg.refine.pos_weight
if assigner_type != "ATSSAssignerV2":
assign_result = assigner.assign(proposals, gt_bboxes, gt_bboxes_ignore, gt_labels, gt_masks)
else:
assign_result = assigner.assign(proposals, num_level_proposals_inside, gt_bboxes, gt_bboxes_ignore, gt_labels, gt_masks)
sampling_result = self.sampler.sample(assign_result, proposals, gt_bboxes)
num_valid_proposals = proposals.shape[0]
bbox_gt = proposals.new_zeros([num_valid_proposals, 4])
bbox_weights = proposals.new_zeros([num_valid_proposals, 4])
labels = proposals.new_full((num_valid_proposals, ), self.background_label, dtype=torch.long)
label_weights = proposals.new_zeros(num_valid_proposals, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
pos_gt_bboxes = sampling_result.pos_gt_bboxes
bbox_gt[pos_inds, :] = pos_gt_bboxes
bbox_weights[pos_inds, :] = 1.0
if gt_labels is None:
labels[pos_inds] = 1
else:
labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
if pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = pos_weight
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
# map up to original set of proposals
if unmap_outputs:
num_total_proposals = flat_proposals.size(0)
labels = unmap(labels, num_total_proposals, inside_flags)
label_weights = unmap(label_weights, num_total_proposals, inside_flags)
bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
bbox_weights = unmap(bbox_weights, num_total_proposals, inside_flags)
gt_inds = unmap(
sampling_result.assign_result.gt_inds, num_total_proposals, inside_flags)
return labels, label_weights, bbox_gt, bbox_weights, pos_inds, neg_inds, gt_inds
def get_targets(self,
proposals_list,
valid_flag_list,
gt_bboxes_list,
img_metas,
gt_bboxes_ignore_list=None,
gt_labels_list=None,
stage='init',
label_channels=1,
unmap_outputs=True,
gt_masks=None):
"""Compute corresponding GT box and classification targets for
proposals.
Args:
proposals_list (list[list]): Multi level points/bboxes of each
image.
valid_flag_list (list[list]): Multi level valid flags of each
image.
gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
img_metas (list[dict]): Meta info of each image.
gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
ignored.
            gt_labels_list (list[Tensor]): Ground truth labels of each box.
stage (str): `init` or `refine`. Generate target for init stage or
refine stage
label_channels (int): Channel of label.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors.
Returns:
tuple:
- labels_list (list[Tensor]): Labels of each level.
- label_weights_list (list[Tensor]): Label weights of each level. # noqa: E501
- bbox_gt_list (list[Tensor]): Ground truth bbox of each level.
- proposal_list (list[Tensor]): Proposals(points/bboxes) of each level. # noqa: E501
- proposal_weights_list (list[Tensor]): Proposal weights of each level. # noqa: E501
- num_total_pos (int): Number of positive samples in all images. # noqa: E501
- num_total_neg (int): Number of negative samples in all images. # noqa: E501
"""
assert stage in ['init', 'refine']
num_imgs = len(img_metas)
assert len(proposals_list) == len(valid_flag_list) == num_imgs
# points number of multi levels
num_level_proposals = [points.size(0) for points in proposals_list[0]]
num_level_proposals_list = [num_level_proposals] * num_imgs
# concat all level points and flags to a single tensor
for i in range(num_imgs):
assert len(proposals_list[i]) == len(valid_flag_list[i])
proposals_list[i] = torch.cat(proposals_list[i])
valid_flag_list[i] = torch.cat(valid_flag_list[i])
# compute targets for each image
if gt_bboxes_ignore_list is None:
gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
(all_labels, all_label_weights, all_bbox_gt, all_bbox_weights,
pos_inds_list, neg_inds_list, gt_inds) = multi_apply(
self._point_target_single,
proposals_list,
valid_flag_list,
num_level_proposals_list,
gt_bboxes_list,
gt_bboxes_ignore_list,
gt_labels_list,
gt_masks if self.mask_head else [None] * len(proposals_list),
stage=stage,
label_channels=label_channels,
unmap_outputs=unmap_outputs)
# no valid points
if any([labels is None for labels in all_labels]):
return None
# sampled points of all images
num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
labels_list = images_to_levels(all_labels, num_level_proposals)
label_weights_list = images_to_levels(all_label_weights, num_level_proposals)
bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
bbox_weights_list = images_to_levels(all_bbox_weights,
num_level_proposals)
shift_gt_inds = [] # shift gt_inds by previous mask number
if self.mask_head:
for gt_ind,nsum in zip(gt_inds,self.masknum_sum):
gt_ind[gt_ind>0] += nsum
shift_gt_inds.append(gt_ind)
shift_gt_inds = images_to_levels(shift_gt_inds, num_level_proposals)
return (labels_list, label_weights_list, bbox_gt_list, bbox_weights_list,
num_total_pos, num_total_neg, shift_gt_inds)
def _hm_target_single(self,
flat_points,
inside_flags,
gt_bboxes,
gt_labels,
unmap_outputs=True):
# assign gt and sample points
if not inside_flags.any():
return (None, ) * 12
points = flat_points[inside_flags, :]
assigner = self.hm_assigner
gt_hm_tl, gt_offset_tl, pos_inds_tl, neg_inds_tl, \
gt_hm_br, gt_offset_br, pos_inds_br, neg_inds_br = \
assigner.assign(points, gt_bboxes, gt_labels)
num_valid_points = points.shape[0]
hm_tl_weights = points.new_zeros(num_valid_points, dtype=torch.float)
hm_br_weights = points.new_zeros(num_valid_points, dtype=torch.float)
offset_tl_weights = points.new_zeros([num_valid_points, 2], dtype=torch.float)
offset_br_weights = points.new_zeros([num_valid_points, 2], dtype=torch.float)
hm_tl_weights[pos_inds_tl] = 1.0
hm_tl_weights[neg_inds_tl] = 1.0
offset_tl_weights[pos_inds_tl, :] = 1.0
hm_br_weights[pos_inds_br] = 1.0
hm_br_weights[neg_inds_br] = 1.0
offset_br_weights[pos_inds_br, :] = 1.0
# map up to original set of grids
if unmap_outputs:
num_total_points = flat_points.shape[0]
gt_hm_tl = unmap(gt_hm_tl, num_total_points, inside_flags)
gt_offset_tl = unmap(gt_offset_tl, num_total_points, inside_flags)
hm_tl_weights = unmap(hm_tl_weights, num_total_points, inside_flags)
offset_tl_weights = unmap(offset_tl_weights, num_total_points, inside_flags)
gt_hm_br = unmap(gt_hm_br, num_total_points, inside_flags)
gt_offset_br = unmap(gt_offset_br, num_total_points, inside_flags)
hm_br_weights = unmap(hm_br_weights, num_total_points, inside_flags)
offset_br_weights = unmap(offset_br_weights, num_total_points, inside_flags)
return (gt_hm_tl, gt_offset_tl, hm_tl_weights, offset_tl_weights, pos_inds_tl, neg_inds_tl,
gt_hm_br, gt_offset_br, hm_br_weights, offset_br_weights, pos_inds_br, neg_inds_br)
def get_hm_targets(self,
proposals_list,
valid_flag_list,
gt_bboxes_list,
img_metas,
gt_labels_list=None,
unmap_outputs=True):
"""Compute refinement and classification targets for points.
Args:
            proposals_list (list[list]): Multi level points of each image.
            valid_flag_list (list[list]): Multi level valid flags of each
                image.
            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
            img_metas (list[dict]): Meta info of each image.
            gt_labels_list (list[Tensor], optional): Ground truth labels of
                each box.
            unmap_outputs (bool): Whether to map outputs back to the
                original set of points.
Returns:
tuple
"""
num_imgs = len(img_metas)
assert len(proposals_list) == len(valid_flag_list) == num_imgs
# points number of multi levels
num_level_proposals = [points.size(0) for points in proposals_list[0]]
# concat all level points and flags to a single tensor
for i in range(len(proposals_list)):
assert len(proposals_list[i]) == len(valid_flag_list[i])
proposals_list[i] = torch.cat(proposals_list[i])
valid_flag_list[i] = torch.cat(valid_flag_list[i])
if gt_labels_list is None:
gt_labels_list = [None for _ in range(num_imgs)]
(all_gt_hm_tl, all_gt_offset_tl, all_hm_tl_weights, all_offset_tl_weights, pos_inds_tl_list, neg_inds_tl_list,
all_gt_hm_br, all_gt_offset_br, all_hm_br_weights, all_offset_br_weights, pos_inds_br_list, neg_inds_br_list) = \
multi_apply(
self._hm_target_single,
proposals_list,
valid_flag_list,
gt_bboxes_list,
gt_labels_list,
unmap_outputs=unmap_outputs)
# no valid points
if any([gt_hm_tl is None for gt_hm_tl in all_gt_hm_tl]):
return None
# sampled points of all images
num_total_pos_tl = sum([max(inds.numel(), 1) for inds in pos_inds_tl_list])
num_total_neg_tl = sum([max(inds.numel(), 1) for inds in neg_inds_tl_list])
num_total_pos_br = sum([max(inds.numel(), 1) for inds in pos_inds_br_list])
num_total_neg_br = sum([max(inds.numel(), 1) for inds in neg_inds_br_list])
gt_hm_tl_list = images_to_levels(all_gt_hm_tl, num_level_proposals)
gt_offset_tl_list = images_to_levels(all_gt_offset_tl, num_level_proposals)
hm_tl_weight_list = images_to_levels(all_hm_tl_weights, num_level_proposals)
offset_tl_weight_list = images_to_levels(all_offset_tl_weights, num_level_proposals)
gt_hm_br_list = images_to_levels(all_gt_hm_br, num_level_proposals)
gt_offset_br_list = images_to_levels(all_gt_offset_br, num_level_proposals)
hm_br_weight_list = images_to_levels(all_hm_br_weights, num_level_proposals)
offset_br_weight_list = images_to_levels(all_offset_br_weights, num_level_proposals)
return (gt_hm_tl_list, gt_offset_tl_list, hm_tl_weight_list, offset_tl_weight_list,
gt_hm_br_list, gt_offset_br_list, hm_br_weight_list, offset_br_weight_list,
num_total_pos_tl, num_total_neg_tl, num_total_pos_br, num_total_neg_br)
def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, hm_score, hm_offset,
labels, label_weights,
bbox_gt_init, bbox_weights_init,
bbox_gt_refine, bbox_weights_refine,
gt_hm_tl, gt_offset_tl, gt_hm_tl_weight, gt_offset_tl_weight,
gt_hm_br, gt_offset_br, gt_hm_br_weight, gt_offset_br_weight,
stride,
num_total_samples_init, num_total_samples_refine,
num_total_samples_tl, num_total_samples_br):
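        """Single-level loss: classification, init/refine point-set box
        regression (normalized by point_base_scale * stride), and corner
        heatmap score/offset losses averaged over the TL and BR corners."""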
# classification loss
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
# points loss
bbox_gt_init = bbox_gt_init.reshape(-1, 4)
bbox_weights_init = bbox_weights_init.reshape(-1, 4)
bbox_pred_init = self.points2bbox(pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False)
bbox_gt_refine = bbox_gt_refine.reshape(-1, 4)
bbox_weights_refine = bbox_weights_refine.reshape(-1, 4)
bbox_pred_refine = self.points2bbox(pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False)
normalize_term = self.point_base_scale * stride
loss_pts_init = self.loss_bbox_init(
bbox_pred_init / normalize_term,
bbox_gt_init / normalize_term,
bbox_weights_init,
avg_factor=num_total_samples_init)
loss_pts_refine = self.loss_bbox_refine(
bbox_pred_refine / normalize_term,
bbox_gt_refine / normalize_term,
bbox_weights_refine,
avg_factor=num_total_samples_refine)
        # cls loss
        if hasattr(self.loss_cls, 'requires_box'):
            loss_cls = self.loss_cls(
                cls_score, labels, label_weights,
                bbox_pred_refine, bbox_gt_refine, bbox_weights_refine,
                background_label=self.background_label,
                avg_factor=num_total_samples_refine)
else:
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=num_total_samples_refine)
# heatmap cls loss
hm_score = hm_score.permute(0, 2, 3, 1).reshape(-1, 2)
hm_score_tl, hm_score_br = torch.chunk(hm_score, 2, dim=-1)
hm_score_tl = hm_score_tl.squeeze(1).sigmoid()
hm_score_br = hm_score_br.squeeze(1).sigmoid()
gt_hm_tl = gt_hm_tl.reshape(-1)
gt_hm_tl_weight = gt_hm_tl_weight.reshape(-1)
gt_hm_br = gt_hm_br.reshape(-1)
gt_hm_br_weight = gt_hm_br_weight.reshape(-1)
loss_heatmap = 0
loss_heatmap += self.loss_heatmap(
hm_score_tl, gt_hm_tl, gt_hm_tl_weight, avg_factor=num_total_samples_tl
)
loss_heatmap += self.loss_heatmap(
hm_score_br, gt_hm_br, gt_hm_br_weight, avg_factor=num_total_samples_br
)
loss_heatmap /= 2.0
# heatmap offset loss
hm_offset = hm_offset.permute(0, 2, 3, 1).reshape(-1, 4)
hm_offset_tl, hm_offset_br = torch.chunk(hm_offset, 2, dim=-1)
gt_offset_tl = gt_offset_tl.reshape(-1, 2)
gt_offset_tl_weight = gt_offset_tl_weight.reshape(-1, 2)
gt_offset_br = gt_offset_br.reshape(-1, 2)
gt_offset_br_weight = gt_offset_br_weight.reshape(-1, 2)
loss_offset = 0
loss_offset += self.loss_offset(
hm_offset_tl, gt_offset_tl, gt_offset_tl_weight,
avg_factor=num_total_samples_tl
)
loss_offset += self.loss_offset(
hm_offset_br, gt_offset_br, gt_offset_br_weight,
avg_factor=num_total_samples_br
)
loss_offset /= 2.0
return loss_cls, loss_pts_init, loss_pts_refine, loss_heatmap, loss_offset
def loss(self,
cls_scores,
pts_preds_init,
pts_preds_refine,
hm_scores,
hm_offsets,
sem_scores,
gt_bboxes,
gt_sem_map,
gt_sem_weights,
gt_labels,
img_metas,
gt_bboxes_ignore=None,
gt_masks=None):
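        """Compute losses over all levels: init-stage targets are assigned
        on center points (or boxes derived from them), refine-stage targets
        on boxes decoded from the init points, plus corner heatmap/offset
        and semantic segmentation map losses."""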
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
assert len(featmap_sizes) == len(self.point_generators)
label_channels = self.cls_out_channels
        # target for initial stage (point stride affects target assignment)
center_list, valid_flag_list = self.get_points(featmap_sizes, img_metas)
pts_coordinate_preds_init = self.offset_to_pts(center_list, pts_preds_init)
if self.train_cfg.init.assigner['type'] != 'MaxIoUAssigner':
# Assign target for center list
candidate_list = center_list
else:
# transform center list to bbox list and
# assign target for bbox list
bbox_list = self.centers_to_bboxes(center_list)
candidate_list = bbox_list
cls_reg_targets_init = self.get_targets(
candidate_list,
valid_flag_list,
gt_bboxes,
img_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
stage='init',
label_channels=label_channels,
gt_masks=gt_masks)
(*_, bbox_gt_list_init, bbox_weights_list_init,
num_total_pos_init, num_total_neg_init, gt_inds) = cls_reg_targets_init
        # target for heatmap in initial stage (point stride affects target
        # assignment)
proposal_list, valid_flag_list = self.get_points(featmap_sizes, img_metas)
heatmap_targets = self.get_hm_targets(
proposal_list,
valid_flag_list,
gt_bboxes,
img_metas,
gt_labels)
(gt_hm_tl_list, gt_offset_tl_list, gt_hm_tl_weight_list, gt_offset_tl_weight_list,
gt_hm_br_list, gt_offset_br_list, gt_hm_br_weight_list, gt_offset_br_weight_list,
num_total_pos_tl, num_total_neg_tl, num_total_pos_br, num_total_neg_br) = heatmap_targets
# target for refinement stage
center_list, valid_flag_list = self.get_points(featmap_sizes, img_metas) #[x,y,stride]
pts_coordinate_preds_refine = self.offset_to_pts(center_list, pts_preds_refine)
bbox_list = []
for i_img, center in enumerate(center_list):
bbox = []
for i_lvl in range(len(pts_preds_refine)):
bbox_preds_init = self.points2bbox(
pts_preds_init[i_lvl].detach())
bbox_shift = bbox_preds_init * self.point_strides[i_lvl]
bbox_center = torch.cat([center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1) # no stride
bbox.append(bbox_center + bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4))
bbox_list.append(bbox)
cls_reg_targets_refine = self.get_targets(
bbox_list,
valid_flag_list,
gt_bboxes,
img_metas,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
stage='refine',
label_channels=label_channels,
gt_masks=gt_masks)
(labels_list, label_weights_list,
bbox_gt_list_refine, bbox_weights_list_refine,
num_total_pos_refine, num_total_neg_refine, gt_inds) = cls_reg_targets_refine
if self.mask_head:
            def get_shape(boxes):
                return torch.stack(
                    (boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]),
                    dim=-1)
self.pred_instances.gt_inds = torch.cat([torch.flatten(gt_ind) for gt_ind in gt_inds])
# norm by gt box, size=(W,H)
self.pred_instances.boxsz = torch.cat([get_shape(gt_bbox.reshape(-1,4)) for gt_bbox in bbox_gt_list_refine])
# norm by pred box, size=(W,H)
# bbox_pred_refine = [self.points2bbox(pts.reshape(-1, 2 * self.num_points), y_first=False) for pts in pts_coordinate_preds_refine]
# self.pred_instances.boxsz = torch.cat([get_shape(bbox.reshape(-1,4)) for bbox in bbox_pred_refine])
# compute loss
losses_cls, losses_pts_init, losses_pts_refine, losses_heatmap, losses_offset = multi_apply(
self.loss_single,
cls_scores,
pts_coordinate_preds_init,
pts_coordinate_preds_refine,
hm_scores,
hm_offsets,
labels_list,
label_weights_list,
bbox_gt_list_init,
bbox_weights_list_init,
bbox_gt_list_refine,
bbox_weights_list_refine,
gt_hm_tl_list,
gt_offset_tl_list,
gt_hm_tl_weight_list,
gt_offset_tl_weight_list,
gt_hm_br_list,
gt_offset_br_list,
gt_hm_br_weight_list,
gt_offset_br_weight_list,
self.point_strides,
num_total_samples_init=num_total_pos_init,
num_total_samples_refine=num_total_pos_refine,
num_total_samples_tl=num_total_pos_tl,
num_total_samples_br=num_total_pos_br)
# sem loss
concat_sem_scores = []
concat_gt_sem_map = []
concat_gt_sem_weights = []
        for sem_score in sem_scores:
            gt_lvl_sem_map = F.interpolate(
                gt_sem_map, sem_score.shape[-2:]).reshape(-1)
            gt_lvl_sem_weight = F.interpolate(
                gt_sem_weights, sem_score.shape[-2:]).reshape(-1)
            concat_sem_scores.append(sem_score.reshape(-1))
            concat_gt_sem_map.append(gt_lvl_sem_map)
            concat_gt_sem_weights.append(gt_lvl_sem_weight)
        concat_sem_scores = torch.cat(concat_sem_scores)
        concat_gt_sem_map = torch.cat(concat_gt_sem_map)
        concat_gt_sem_weights = torch.cat(concat_gt_sem_weights)
        loss_sem = self.loss_sem(
            concat_sem_scores, concat_gt_sem_map, concat_gt_sem_weights,
            avg_factor=(concat_gt_sem_map > 0).sum())
loss_dict_all = {'loss_cls': losses_cls,
'loss_pts_init': losses_pts_init,
'loss_pts_refine': losses_pts_refine,
'loss_heatmap': losses_heatmap,
'loss_offset': losses_offset,
'loss_sem': loss_sem,
}
return loss_dict_all
def get_bboxes(self,
cls_scores,
pts_preds_init,
pts_preds_refine,
hm_scores,
hm_offsets,
sem_scores,
img_metas,
cfg=None,
rescale=False,
nms=True):
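        """Decode refined point sets into boxes and post-process each image,
        optionally snapping box corners to the predicted corner heatmaps."""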
assert len(cls_scores) == len(pts_preds_refine)
bbox_preds_refine = [self.points2bbox(pts_pred_refine) for pts_pred_refine in pts_preds_refine]
num_levels = len(cls_scores)
mlvl_points = [
self.point_generators[i].grid_points(cls_scores[i].size()[-2:],
self.point_strides[i])
for i in range(num_levels)
]
result_list = []
for img_id in range(len(img_metas)):
cls_score_list = [
cls_scores[i][img_id].detach() for i in range(num_levels)
]
bbox_pred_list = [
bbox_preds_refine[i][img_id].detach() for i in range(num_levels)
]
hm_scores_list = [
hm_scores[i][img_id].detach() for i in range(num_levels)
]
hm_offsets_list = [
hm_offsets[i][img_id].detach() for i in range(num_levels)
]
img_shape = img_metas[img_id]['img_shape']
scale_factor = img_metas[img_id]['scale_factor']
proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list, hm_scores_list, hm_offsets_list,
mlvl_points, img_shape,
scale_factor, cfg, rescale,
nms, imid_shift=self.image_sum*img_id if self.mask_head else 0)
result_list.append(proposals)
return result_list
def _get_bboxes_single(self,
cls_scores,
bbox_preds,
hm_scores,
hm_offsets,
mlvl_points,
img_shape,
scale_factor,
cfg,
rescale=False,
nms=True,
imid_shift=0):
def select(score_map, x, y, ks=2, i=0):
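            # Snap (x, y) to the strongest heatmap response within a ks x ks
            # window (via max_pool2d_with_indices) and return the pooled
            # corner location together with its heatmap score.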
H, W = score_map.shape[-2], score_map.shape[-1]
score_map = score_map.sigmoid()
score_map_original = score_map.clone()
score_map, indices = F.max_pool2d_with_indices(score_map.unsqueeze(0), kernel_size=ks, stride=1, padding=(ks - 1) // 2)
indices = indices.squeeze(0).squeeze(0)
if ks % 2 == 0:
round_func = torch.floor
else:
round_func = torch.round
x_round = round_func((x / self.point_strides[i]).clamp(min=0, max=score_map.shape[-1] - 1))
y_round = round_func((y / self.point_strides[i]).clamp(min=0, max=score_map.shape[-2] - 1))
select_indices = indices[y_round.to(torch.long), x_round.to(torch.long)]
new_x = select_indices % W
new_y = select_indices // W
score_map_squeeze = score_map_original.squeeze(0)
score = score_map_squeeze[new_y, new_x]
new_x, new_y = new_x.to(torch.float), new_y.to(torch.float)
return new_x, new_y, score
cfg = self.test_cfg if cfg is None else cfg
assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
mlvl_bboxes = []
mlvl_scores = []
inst_inds = [] if self.mask_head else None
for i_lvl, (cls_score, bbox_pred, points) in enumerate(zip(cls_scores, bbox_preds, mlvl_points)):
if self.mask_head:
inst_ind = torch.arange(cls_score.shape[-2]*cls_score.shape[-1], device=cls_score.device) + self.level_sum[i_lvl] + imid_shift[i_lvl]
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
scores = cls_score.permute(1, 2, 0).reshape(-1, self.cls_out_channels).sigmoid()
bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
nms_pre = cfg.get('nms_pre', -1)
if nms_pre > 0 and scores.shape[0] > nms_pre:
max_scores, _ = scores.max(dim=1)
_, topk_inds = max_scores.topk(nms_pre)
points = points[topk_inds, :]
bbox_pred = bbox_pred[topk_inds, :]
scores = scores[topk_inds, :]
if self.mask_head:
inst_ind = inst_ind[topk_inds]
bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1)
bboxes = bbox_pred * self.point_strides[i_lvl] + bbox_pos_center
x1 = bboxes[:, 0].clamp(min=0, max=img_shape[1])
y1 = bboxes[:, 1].clamp(min=0, max=img_shape[0])
x2 = bboxes[:, 2].clamp(min=0, max=img_shape[1])
y2 = bboxes[:, 3].clamp(min=0, max=img_shape[0])
if self.corner_refine:
if i_lvl > 0:
i = 0 if i_lvl in (1, 2) else 1
x1_new, y1_new, score1_new = select(hm_scores[i][0, ...], x1, y1, 2, i)
x2_new, y2_new, score2_new = select(hm_scores[i][1, ...], x2, y2, 2, i)
hm_offset = hm_offsets[i].permute(1, 2, 0)
point_stride = self.point_strides[i]
x1 = ((x1_new + hm_offset[y1_new.to(torch.long), x1_new.to(torch.long), 0]) * point_stride).clamp(min=0, max=img_shape[1])
y1 = ((y1_new + hm_offset[y1_new.to(torch.long), x1_new.to(torch.long), 1]) * point_stride).clamp(min=0, max=img_shape[0])
x2 = ((x2_new + hm_offset[y2_new.to(torch.long), x2_new.to(torch.long), 2]) * point_stride).clamp(min=0, max=img_shape[1])
y2 = ((y2_new + hm_offset[y2_new.to(torch.long), x2_new.to(torch.long), 3]) * point_stride).clamp(min=0, max=img_shape[0])
bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
mlvl_bboxes.append(bboxes)
mlvl_scores.append(scores)
if self.mask_head:
inst_inds.append(inst_ind)
if self.mask_head:
inst_inds = torch.cat(inst_inds)
mlvl_bboxes = torch.cat(mlvl_bboxes)
if rescale:
mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
mlvl_scores = torch.cat(mlvl_scores)
padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
if nms:
if self.mask_head:
det_bboxes, det_labels, inst_inds, _ = multiclass_nms_rpd(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img, inst_inds=inst_inds)
return det_bboxes, det_labels, inst_inds
else:
det_bboxes, det_labels, _ = multiclass_nms_rpd(mlvl_bboxes, mlvl_scores,
cfg.score_thr, cfg.nms,
cfg.max_per_img)
return det_bboxes, det_labels
else:
if self.mask_head:
return mlvl_bboxes, mlvl_scores, inst_inds
else:
return mlvl_bboxes, mlvl_scores
def get_num_level_proposals_inside(self, num_level_proposals, inside_flags):
split_inside_flags = torch.split(inside_flags, num_level_proposals)
num_level_proposals_inside = [
int(flags.sum()) for flags in split_inside_flags
]
return num_level_proposals_inside
def forward_train(self,
x,
img_metas,
gt_bboxes,
gt_labels=None,
gt_bboxes_ignore=None,
gt_sem_map=None,
gt_sem_weights=None,
gt_masks=None,
proposal_cfg=None,
**kwargs):
"""
        NOTE: for CondInst-style mask training only.
Args:
x (list[Tensor]): Features from FPN.
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes (Tensor): Ground truth bboxes of the image,
shape (num_gts, 4).
gt_labels (Tensor): Ground truth labels of each box,
shape (num_gts,).
gt_bboxes_ignore (Tensor): Ground truth bboxes to be
ignored, shape (num_ignored_gts, 4).
            gt_masks (None | list[Tensor]): true segmentation masks for each
                box, used if the architecture supports a segmentation task.
proposal_cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used
Returns:
tuple:
                losses (dict[str, Tensor]): A dictionary of loss components.
proposal_list (list[Tensor]): Proposals of each image.
"""
if gt_masks:
self.masknum_sum = tuple(accumulate([0]+[len(mask) for mask in gt_masks][:-1]))
outs = self(x)
if gt_labels is None:
loss_inputs = outs + (gt_bboxes, gt_sem_map, gt_sem_weights, img_metas)
else:
loss_inputs = outs + (gt_bboxes, gt_sem_map, gt_sem_weights, gt_labels, img_metas)
losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore, gt_masks=gt_masks)
if self.mask_head:
self.pred_instances = self.pred_instances[self.pred_instances.gt_inds > 0] # select only foreground labels
self.pred_instances.gt_inds = self.pred_instances.gt_inds - 1
mask_loss = self.mask_head(x, self.pred_instances, gt_masks, gt_labels)
losses.update(mask_loss)
if proposal_cfg is None:
return losses
else:
proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
return losses, proposal_list
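# A minimal config sketch for this head (hypothetical values; the train/test
# cfgs, assigners and optional mask_head come from the surrounding project):
#   bbox_head = dict(
#       type='RepPointsV2Head',
#       num_classes=80,
#       in_channels=256,
#       feat_channels=256,
#       point_feat_channels=256,
#       num_points=9,
#       gradient_mul=0.1,
#       point_strides=[8, 16, 32, 64, 128],
#       point_base_scale=4,
#       transform_method='moment')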