haotiz's picture
initial commit
708dec4
import torch
import torch.nn as nn
import torch.nn.functional as F
from .deform_conv import DeformConv2d
def add_conv(in_ch, out_ch, ksize, stride, leaky=True):
    """
    Build a conv2d -> batchnorm -> activation block.

    Args:
        in_ch (int): number of input channels of the convolution layer.
        out_ch (int): number of output channels of the convolution layer.
        ksize (int): kernel size of the convolution layer.
        stride (int): stride of the convolution layer.
        leaky (bool): if True the activation is ``LeakyReLU(0.1)`` registered
            as 'leaky'; otherwise an in-place ``ReLU6`` registered as 'relu6'.
    Returns:
        stage (Sequential): the block, whose children are named 'conv',
            'batch_norm' and the activation name, in that order.
    """
    # The convolution has no bias: the batch norm that follows has its own.
    conv = nn.Conv2d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=ksize,
        stride=stride,
        padding=(ksize - 1) // 2,
        bias=False,
    )
    norm = nn.BatchNorm2d(out_ch)
    if leaky:
        act_name, act = 'leaky', nn.LeakyReLU(0.1)
    else:
        act_name, act = 'relu6', nn.ReLU6(inplace=True)

    stage = nn.Sequential()
    for layer_name, layer in (('conv', conv), ('batch_norm', norm), (act_name, act)):
        stage.add_module(layer_name, layer)
    return stage
class upsample(nn.Module):
    """
    Module form of ``F.interpolate`` with the resize settings fixed at
    construction time.

    Args:
        size: target output size forwarded to ``F.interpolate``.
        scale_factor: spatial multiplier forwarded to ``F.interpolate``.
        mode (str): interpolation mode, 'nearest' by default.
        align_corners: forwarded unchanged to ``F.interpolate``.
    """
    __constants__ = ['size', 'scale_factor', 'mode', 'align_corners', 'name']

    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
        super().__init__()
        # The class name is kept on the instance as ``name``.
        self.name = type(self).__name__
        self.size, self.scale_factor = size, scale_factor
        self.mode, self.align_corners = mode, align_corners

    def forward(self, input):
        """Resize ``input`` using the stored settings."""
        return F.interpolate(
            input,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )

    def extra_repr(self):
        """Describe the scale factor if one is set, else the size, then the mode."""
        if self.scale_factor is None:
            head = 'size=' + str(self.size)
        else:
            head = 'scale_factor=' + str(self.scale_factor)
        return head + ', mode=' + self.mode
class SPPLayer(nn.Module):
    """
    Concatenate the input with three stride-1 max-pooled copies of itself.

    The pooling windows are 5, 9 and 13 wide, each padded by half the window
    so the spatial size is unchanged; the output therefore has four times
    the input's channels (input first, then the pooled maps by window size).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        pooled = [F.max_pool2d(x, k, stride=1, padding=k // 2) for k in (5, 9, 13)]
        return torch.cat([x] + pooled, dim=1)
class DropBlock(nn.Module):
    """
    Training-time regulariser that zeroes square blocks of a feature map.

    Seed positions are drawn per element from Bernoulli(gamma); a stride-1
    max pool with a ``block_size`` window grows every seed into a block, and
    the surviving activations are rescaled by ``numel / kept`` so the mean
    magnitude is preserved.

    Args:
        block_size (int): side length of the dropped square.
        keep_prob (float): target fraction of activations to keep; 1 turns
            the layer into an identity.

    Notes:
        ``gamma`` is computed on the first training forward pass and then
        reused; call :meth:`reset` to clear it.
    """

    def __init__(self, block_size=7, keep_prob=0.9):
        super(DropBlock, self).__init__()
        self.reset(block_size, keep_prob)

    def reset(self, block_size, keep_prob):
        """Change the block size / keep probability and clear the cached gamma."""
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.gamma = None
        self.kernel_size = (block_size, block_size)
        self.stride = (1, 1)
        self.padding = (block_size // 2, block_size // 2)

    def calculate_gamma(self, x):
        """
        Return the per-element seed probability for an input shaped like ``x``.

        Only the last dimension of ``x`` is read, so the feature map is
        treated as square (assumption inherited from the formula -- confirm
        for non-square inputs).
        """
        width = x.shape[-1]
        return (1 - self.keep_prob) * width ** 2 / \
            (self.block_size ** 2 * (width - self.block_size + 1) ** 2)

    def forward(self, x):
        """Apply DropBlock in training mode; return ``x`` untouched otherwise."""
        # set keep_prob=1 to turn off dropblock
        if not self.training or self.keep_prob == 1:
            return x
        if self.gamma is None:
            self.gamma = self.calculate_gamma(x)

        # TODO: not fully support for FP16 now -- half inputs are processed in
        # float32 and cast back. The dtype check also covers CPU half tensors,
        # which the old string comparison on x.type() missed.
        is_half = x.dtype == torch.float16
        if is_half:
            x = x.float()

        seeds = torch.bernoulli(torch.ones_like(x) * self.gamma)
        mask = 1 - F.max_pool2d(seeds, self.kernel_size, self.stride, self.padding)
        # If every position is dropped the kept count is 0; clamp it so the
        # rescale stays finite and the output is all zeros instead of NaN.
        kept = mask.sum().clamp(min=1)
        out = mask * x * (mask.numel() / kept)
        return out.half() if is_half else out
class resblock(nn.Module):
    """
    Stack of residual units, each a 1x1 convolution block that halves the
    channels followed by a 3x3 convolution block that restores them.

    Args:
        ch (int): number of input and output channels.
        nblocks (int): number of residual units.
        shortcut (bool): if True, each unit's output is added to its input;
            otherwise the unit's output replaces it.
    """

    def __init__(self, ch, nblocks=1, shortcut=True):
        super().__init__()
        self.shortcut = shortcut
        self.module_list = nn.ModuleList()
        half = ch // 2
        for _ in range(nblocks):
            unit = nn.ModuleList([
                add_conv(ch, half, 1, 1),
                add_conv(half, ch, 3, 1),
            ])
            self.module_list.append(unit)

    def forward(self, x):
        for unit in self.module_list:
            branch = x
            for layer in unit:
                branch = layer(branch)
            if self.shortcut:
                x = x + branch
            else:
                x = branch
        return x
class RFBblock(nn.Module):
    """
    Four parallel convolution branches whose outputs are concatenated.

    Every branch starts with a 1x1 convolution down to ``in_ch // 4``
    channels. Branches 1-3 then add further size-preserving convolutions
    (3x3; 3x3 + 3x3 dilated by 2; 5x5 + 3x3 dilated by 3).

    Args:
        in_ch (int): number of input channels.
        residual (bool): if True, the input is added in place to the
            concatenated output.
    """

    def __init__(self, in_ch, residual=False):
        super().__init__()
        inter_c = in_ch // 4

        def squeeze():
            # 1x1 channel reduction that opens every branch.
            return nn.Conv2d(in_channels=in_ch, out_channels=inter_c,
                             kernel_size=1, stride=1, padding=0)

        def spatial(ksize, dilation=1):
            # Padding is chosen so the spatial size is unchanged.
            return nn.Conv2d(in_channels=inter_c, out_channels=inter_c,
                             kernel_size=ksize, stride=1, dilation=dilation,
                             padding=dilation * (ksize - 1) // 2)

        self.branch_0 = nn.Sequential(squeeze())
        self.branch_1 = nn.Sequential(squeeze(), spatial(3))
        self.branch_2 = nn.Sequential(squeeze(), spatial(3), spatial(3, dilation=2))
        self.branch_3 = nn.Sequential(squeeze(), spatial(5), spatial(3, dilation=3))
        self.residual = residual

    def forward(self, x):
        branches = (self.branch_0, self.branch_1, self.branch_2, self.branch_3)
        out = torch.cat([branch(x) for branch in branches], 1)
        if self.residual:
            out += x
        return out
class FeatureAdaption(nn.Module):
    """
    Deformable 3x3 convolution whose sampling offsets are predicted from a
    box width/height map.

    Args:
        in_ch (int): input channels of the deformable convolution.
        out_ch (int): output channels of the deformable convolution.
        n_anchors (int): number of anchors; used as the group count of the
            offset convolution and as ``deformable_groups``.
        rfb (bool): if True, an ``RFBblock(out_ch)`` is run on the input
            before the deformable convolution.
        sep (bool): if truthy, no layers are built and ``forward`` returns
            its input unchanged.

    NOTE(review): the RFB block is sized with ``out_ch`` but applied to the
    input features, so it appears to assume ``in_ch == out_ch`` -- confirm
    against the callers.
    """

    def __init__(self, in_ch, out_ch, n_anchors, rfb=False, sep=False):
        super().__init__()
        self.sep = bool(sep)
        if self.sep:
            # Pass-through mode: nothing else is constructed.
            return
        # 2 width/height channels per anchor in, 2 * 9 offsets per anchor out.
        self.conv_offset = nn.Conv2d(
            in_channels=2 * n_anchors,
            out_channels=2 * 9 * n_anchors,
            groups=n_anchors,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.dconv = DeformConv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=3,
            stride=1,
            padding=1,
            deformable_groups=n_anchors,
        )
        self.rfb = RFBblock(out_ch) if rfb else None

    def forward(self, input, wh_pred):
        # The RFB block is added behind FeatureAdaption
        # For mobilenet, we currently don't support rfb and FeatureAdaption
        if self.sep:
            return input
        feat = input if self.rfb is None else self.rfb(input)
        # The prediction is detached, so no gradient reaches it through the offsets.
        offset = self.conv_offset(wh_pred.detach())
        return self.dconv(feat, offset)
class ASFFmobile(nn.Module):
    """
    Fuse three feature levels with per-pixel softmax weights (ReLU6 variant).

    The three inputs are brought to the resolution and channel count of the
    chosen ``level`` (512 / 256 / 128 channels for levels 0 / 1 / 2),
    blended with learned weights and expanded by a final 3x3 convolution
    block. All convolution blocks use ``leaky=False``.

    Args:
        level (int): which level's scale the output is produced at (0, 1 or 2).
        rfb (bool): if True the weight branches use 8 channels instead of 16.
        vis (bool): if True ``forward`` also returns the fusion weights and
            the channel-summed fused map.
    """

    def __init__(self, level, rfb=False, vis=False):
        super().__init__()
        self.level = level
        self.dim = [512, 256, 128]
        self.inter_dim = self.dim[self.level]
        width = self.inter_dim

        def block(c_in, c_out, ksize, stride):
            # Every convolution block in this variant ends in ReLU6.
            return add_conv(c_in, c_out, ksize, stride, leaky=False)

        if level == 0:
            self.stride_level_1 = block(256, width, 3, 2)
            self.stride_level_2 = block(128, width, 3, 2)
            self.expand = block(width, 1024, 3, 1)
        elif level == 1:
            self.compress_level_0 = block(512, width, 1, 1)
            self.stride_level_2 = block(128, width, 3, 2)
            self.expand = block(width, 512, 3, 1)
        elif level == 2:
            self.compress_level_0 = block(512, width, 1, 1)
            self.compress_level_1 = block(256, width, 1, 1)
            self.expand = block(width, 256, 3, 1)

        # when adding rfb, we use half number of channels to save memory
        compress_c = 8 if rfb else 16
        self.weight_level_0 = block(width, compress_c, 1, 1)
        self.weight_level_1 = block(width, compress_c, 1, 1)
        self.weight_level_2 = block(width, compress_c, 1, 1)
        self.weight_levels = nn.Conv2d(compress_c * 3, 3, kernel_size=1, stride=1, padding=0)
        self.vis = vis

    def forward(self, x_level_0, x_level_1, x_level_2):
        # Step 1: bring every input to this level's resolution and width.
        if self.level == 0:
            feat_0 = x_level_0
            feat_1 = self.stride_level_1(x_level_1)
            # Level 2 is reduced twice: a stride-2 max pool, then the stride-2 conv.
            feat_2 = self.stride_level_2(F.max_pool2d(x_level_2, 3, stride=2, padding=1))
        elif self.level == 1:
            feat_0 = F.interpolate(self.compress_level_0(x_level_0),
                                   scale_factor=2, mode='nearest')
            feat_1 = x_level_1
            feat_2 = self.stride_level_2(x_level_2)
        elif self.level == 2:
            feat_0 = F.interpolate(self.compress_level_0(x_level_0),
                                   scale_factor=4, mode='nearest')
            feat_1 = F.interpolate(self.compress_level_1(x_level_1),
                                   scale_factor=2, mode='nearest')
            feat_2 = x_level_2

        # Step 2: one weight map per level, normalised across levels.
        weight_feats = torch.cat((self.weight_level_0(feat_0),
                                  self.weight_level_1(feat_1),
                                  self.weight_level_2(feat_2)), 1)
        levels_weight = F.softmax(self.weight_levels(weight_feats), dim=1)

        # Step 3: weighted sum, then expansion.
        fused_out_reduced = (feat_0 * levels_weight[:, 0:1]
                             + feat_1 * levels_weight[:, 1:2]
                             + feat_2 * levels_weight[:, 2:])
        out = self.expand(fused_out_reduced)

        if not self.vis:
            return out
        return out, levels_weight, fused_out_reduced.sum(dim=1)
class ASFF(nn.Module):
    """
    Fuse three feature levels with per-pixel softmax weights.

    The three inputs are brought to the resolution and channel count of the
    chosen ``level`` (512 / 256 / 256 channels for levels 0 / 1 / 2),
    blended with learned weights and expanded by a final 3x3 convolution
    block. Convolution blocks use the default activation of ``add_conv``.

    Args:
        level (int): which level's scale the output is produced at (0, 1 or 2).
        rfb (bool): if True the weight branches use 8 channels instead of 16.
        vis (bool): if True ``forward`` also returns the fusion weights and
            the channel-summed fused map.
    """

    def __init__(self, level, rfb=False, vis=False):
        super().__init__()
        self.level = level
        self.dim = [512, 256, 256]
        self.inter_dim = self.dim[self.level]
        width = self.inter_dim

        if level == 0:
            self.stride_level_1 = add_conv(256, width, 3, 2)
            self.stride_level_2 = add_conv(256, width, 3, 2)
            self.expand = add_conv(width, 1024, 3, 1)
        elif level == 1:
            self.compress_level_0 = add_conv(512, width, 1, 1)
            self.stride_level_2 = add_conv(256, width, 3, 2)
            self.expand = add_conv(width, 512, 3, 1)
        elif level == 2:
            # Level 1 gets no compression layer here; it is only upsampled in forward.
            self.compress_level_0 = add_conv(512, width, 1, 1)
            self.expand = add_conv(width, 256, 3, 1)

        # when adding rfb, we use half number of channels to save memory
        compress_c = 8 if rfb else 16
        self.weight_level_0 = add_conv(width, compress_c, 1, 1)
        self.weight_level_1 = add_conv(width, compress_c, 1, 1)
        self.weight_level_2 = add_conv(width, compress_c, 1, 1)
        self.weight_levels = nn.Conv2d(compress_c * 3, 3, kernel_size=1, stride=1, padding=0)
        self.vis = vis

    def forward(self, x_level_0, x_level_1, x_level_2):
        # Step 1: bring every input to this level's resolution and width.
        if self.level == 0:
            feat_0 = x_level_0
            feat_1 = self.stride_level_1(x_level_1)
            # Level 2 is reduced twice: a stride-2 max pool, then the stride-2 conv.
            feat_2 = self.stride_level_2(F.max_pool2d(x_level_2, 3, stride=2, padding=1))
        elif self.level == 1:
            feat_0 = F.interpolate(self.compress_level_0(x_level_0),
                                   scale_factor=2, mode='nearest')
            feat_1 = x_level_1
            feat_2 = self.stride_level_2(x_level_2)
        elif self.level == 2:
            feat_0 = F.interpolate(self.compress_level_0(x_level_0),
                                   scale_factor=4, mode='nearest')
            feat_1 = F.interpolate(x_level_1, scale_factor=2, mode='nearest')
            feat_2 = x_level_2

        # Step 2: one weight map per level, normalised across levels.
        weight_feats = torch.cat((self.weight_level_0(feat_0),
                                  self.weight_level_1(feat_1),
                                  self.weight_level_2(feat_2)), 1)
        levels_weight = F.softmax(self.weight_levels(weight_feats), dim=1)

        # Step 3: weighted sum, then expansion.
        fused_out_reduced = (feat_0 * levels_weight[:, 0:1]
                             + feat_1 * levels_weight[:, 1:2]
                             + feat_2 * levels_weight[:, 2:])
        out = self.expand(fused_out_reduced)

        if not self.vis:
            return out
        return out, levels_weight, fused_out_reduced.sum(dim=1)
def make_divisible(v, divisor, min_value=None):
    """
    Round ``v`` to the nearest multiple of ``divisor``, with a lower bound.

    This function is taken from the original tf repo, where it ensures that
    all layers have a channel number that is divisible by 8:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    :param v: value to round.
    :param divisor: the result is a multiple of this (unless ``min_value`` wins).
    :param min_value: smallest value returned; defaults to ``divisor``.
    :return: the adjusted value.
    """
    lower_bound = divisor if min_value is None else min_value
    nearest = int(v + divisor / 2) // divisor * divisor
    new_v = max(lower_bound, nearest)
    # Make sure that round down does not go down by more than 10%.
    return new_v + divisor if new_v < 0.9 * v else new_v
class ConvBNReLU(nn.Sequential):
    """
    Bias-free convolution, batch norm and in-place ReLU6 as one Sequential.

    Args:
        in_planes (int): input channels.
        out_planes (int): output channels.
        kernel_size (int): convolution kernel size; padding is
            ``(kernel_size - 1) // 2``.
        stride (int): convolution stride.
        groups (int): convolution groups.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            (kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_planes)
        act = nn.ReLU6(inplace=True)
        super().__init__(conv, norm, act)
def add_sepconv(in_ch, out_ch, ksize, stride):
    """
    Build a depthwise-separable convolution block.

    A depthwise ``ksize`` convolution (groups equal to ``in_ch``) and a 1x1
    pointwise convolution, each followed by batch norm and in-place ReLU6.

    Args:
        in_ch (int): number of input channels.
        out_ch (int): number of output channels of the pointwise convolution.
        ksize (int): kernel size of the depthwise convolution.
        stride (int): stride of the depthwise convolution.
    Returns:
        stage (Sequential): children named 'sepconv', 'sepbn', 'seprelu6',
            'ptconv', 'ptbn', 'ptrelu6', in that order.
    """
    layers = [
        ('sepconv', nn.Conv2d(in_channels=in_ch, out_channels=in_ch,
                              kernel_size=ksize, stride=stride,
                              padding=(ksize - 1) // 2, groups=in_ch, bias=False)),
        ('sepbn', nn.BatchNorm2d(in_ch)),
        ('seprelu6', nn.ReLU6(inplace=True)),
        ('ptconv', nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=False)),
        ('ptbn', nn.BatchNorm2d(out_ch)),
        ('ptrelu6', nn.ReLU6(inplace=True)),
    ]
    stage = nn.Sequential()
    for layer_name, layer in layers:
        stage.add_module(layer_name, layer)
    return stage
class InvertedResidual(nn.Module):
    """
    Pointwise expansion, depthwise convolution and linear pointwise
    projection, with a skip connection when the shapes allow it.

    Args:
        inp (int): input channels.
        oup (int): output channels.
        stride (int): stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: multiplier giving the hidden width
            ``round(inp * expand_ratio)``; the expansion layer is omitted
            when it equals 1.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        # The skip connection needs matching resolution and channel count.
        self.use_res_connect = self.stride == 1 and inp == oup

        # pw (only when the block actually widens)
        expansion = [ConvBNReLU(inp, hidden_dim, kernel_size=1)] if expand_ratio != 1 else []
        # dw
        depthwise = ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim)
        # pw-linear: no activation after the projection
        projection = nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False)
        self.conv = nn.Sequential(*expansion, depthwise, projection, nn.BatchNorm2d(oup))

    def forward(self, x):
        if not self.use_res_connect:
            return self.conv(x)
        return x + self.conv(x)
class ressepblock(nn.Module):
    """
    Single residual unit: a 1x1 convolution block followed by a 3x3
    convolution block, both built with ``leaky=False``.

    Args:
        ch (int): number of input channels.
        out_ch (int): number of output channels of the 3x3 block.
        in_ch (int, optional): channels between the two blocks; defaults to
            ``ch // 2`` when None.
        shortcut (bool): if True the unit's output is added to its input,
            which needs the two tensors to have compatible shapes;
            otherwise the unit's output replaces it.
    """

    def __init__(self, ch, out_ch, in_ch=None, shortcut=True):
        super().__init__()
        self.shortcut = shortcut
        self.module_list = nn.ModuleList()
        # Identity test for None (PEP 8) instead of the former equality test.
        if in_ch is None:
            in_ch = ch // 2
        resblock_one = nn.ModuleList()
        resblock_one.append(add_conv(ch, in_ch, 1, 1, leaky=False))
        resblock_one.append(add_conv(in_ch, out_ch, 3, 1, leaky=False))
        self.module_list.append(resblock_one)

    def forward(self, x):
        for module in self.module_list:
            h = x
            for res in module:
                h = res(h)
            x = x + h if self.shortcut else h
        return x