|
import torch
import torch.nn as nn
|
|
class ConvRep5(nn.Module):
    """Re-parameterizable 5x5 conv block: ten parallel branches at train
    time that slim() folds into a single 5x5 convolution for inference."""

    def __init__(self, in_channels, out_channels, rep_scale=4):
        super(ConvRep5, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Five conv branches (5x5, 1x1, 3x3, 3x1, 1x3), each paired with a
        # conv+BatchNorm twin, all widened by rep_scale.
        self.conv = nn.Conv2d(in_channels, out_channels * rep_scale, 5, 1, 2)
        self.conv_bn = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * rep_scale, 5, 1, 2),
            nn.BatchNorm2d(out_channels * rep_scale)
        )
        self.conv1 = nn.Conv2d(in_channels, out_channels * rep_scale, 1)
        self.conv1_bn = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * rep_scale, 1),
            nn.BatchNorm2d(out_channels * rep_scale)
        )
        self.conv2 = nn.Conv2d(in_channels, out_channels * rep_scale, 3, 1, 1)
        self.conv2_bn = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * rep_scale, 3, 1, 1),
            nn.BatchNorm2d(out_channels * rep_scale)
        )
        self.conv_crossh = nn.Conv2d(in_channels, out_channels * rep_scale, (3, 1), 1, (1, 0))
        self.conv_crossh_bn = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * rep_scale, (3, 1), 1, (1, 0)),
            nn.BatchNorm2d(out_channels * rep_scale)
        )
        self.conv_crossv = nn.Conv2d(in_channels, out_channels * rep_scale, (1, 3), 1, (0, 1))
        self.conv_crossv_bn = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * rep_scale, (1, 3), 1, (0, 1)),
            nn.BatchNorm2d(out_channels * rep_scale)
        )
        # 1x1 projection back to out_channels; 10 = number of branches.
        self.conv_out = nn.Conv2d(out_channels * rep_scale * 10, out_channels, 1)
|
    def forward(self, inp):
        # Concatenate all ten branch outputs along the channel axis,
        # then project back down with the 1x1 conv.
        x = torch.cat(
            [self.conv(inp),
             self.conv1(inp),
             self.conv2(inp),
             self.conv_crossh(inp),
             self.conv_crossv(inp),
             self.conv_bn(inp),
             self.conv1_bn(inp),
             self.conv2_bn(inp),
             self.conv_crossh_bn(inp),
             self.conv_crossv_bn(inp)],
            1
        )
        return self.conv_out(x)
|
    def slim(self):
        """Fold all ten branches plus the 1x1 projection into the weight and
        bias of a single equivalent 5x5 convolution."""
        def fold_bn(branch, weight):
            # BN(conv(x)) = gamma * (conv(x) - mean) / sqrt(var + eps) + beta,
            # so scale the conv weight by gamma / sqrt(var + eps) and fold
            # the remaining affine terms into the bias.
            conv, bn = branch[0], branch[1]
            k = 1 / (bn.running_var + bn.eps) ** .5
            w = weight * (k * bn.weight).reshape(-1, 1, 1, 1)
            bias = (conv.bias * k - bn.running_mean * k) * bn.weight + bn.bias
            return w, bias

        # Zero-pad every smaller kernel to 5x5 so all branches share one
        # shape. F.pad takes (left, right, top, bottom).
        conv_weight = self.conv.weight
        conv_bias = self.conv.bias
        conv1_weight = nn.functional.pad(self.conv1.weight, (2, 2, 2, 2))
        conv1_bias = self.conv1.bias
        conv2_weight = nn.functional.pad(self.conv2.weight, (1, 1, 1, 1))
        conv2_bias = self.conv2.bias
        conv_crossv_weight = nn.functional.pad(self.conv_crossv.weight, (1, 1, 2, 2))
        conv_crossv_bias = self.conv_crossv.bias
        conv_crossh_weight = nn.functional.pad(self.conv_crossh.weight, (2, 2, 1, 1))
        conv_crossh_bias = self.conv_crossh.bias

        conv_bn_weight, conv_bn_bias = fold_bn(self.conv_bn, self.conv_bn[0].weight)
        conv1_bn_weight, conv1_bn_bias = fold_bn(
            self.conv1_bn, nn.functional.pad(self.conv1_bn[0].weight, (2, 2, 2, 2)))
        conv2_bn_weight, conv2_bn_bias = fold_bn(
            self.conv2_bn, nn.functional.pad(self.conv2_bn[0].weight, (1, 1, 1, 1)))
        conv_crossv_bn_weight, conv_crossv_bn_bias = fold_bn(
            self.conv_crossv_bn, nn.functional.pad(self.conv_crossv_bn[0].weight, (1, 1, 2, 2)))
        conv_crossh_bn_weight, conv_crossh_bn_bias = fold_bn(
            self.conv_crossh_bn, nn.functional.pad(self.conv_crossh_bn[0].weight, (2, 2, 1, 1)))

        # Stack the branch kernels along the output-channel axis in the same
        # order as the concatenation in forward(), then apply the 1x1
        # projection as a matrix product over that axis.
        weight = torch.cat(
            [conv_weight, conv1_weight, conv2_weight,
             conv_crossh_weight, conv_crossv_weight,
             conv_bn_weight, conv1_bn_weight, conv2_bn_weight,
             conv_crossh_bn_weight, conv_crossv_bn_weight],
            0
        )
        weight_compress = self.conv_out.weight.squeeze()
        weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
        bias = torch.cat(
            [conv_bias, conv1_bias, conv2_bias,
             conv_crossh_bias, conv_crossv_bias,
             conv_bn_bias, conv1_bn_bias, conv2_bn_bias,
             conv_crossh_bn_bias, conv_crossv_bn_bias],
            0
        )
        bias = torch.matmul(weight_compress, bias)
        if self.conv_out.bias is not None:
            bias = bias + self.conv_out.bias
        return weight, bias
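
# Hedged usage sketch (not part of the original file): in eval mode, slim()
# should yield a single 5x5 conv equivalent to the whole ConvRep5 block,
# since the BatchNorm folding uses the running statistics.
def _check_convrep5_slim():
    torch.manual_seed(0)
    block = ConvRep5(3, 8).eval()
    weight, bias = block.slim()
    fused = nn.Conv2d(3, 8, 5, 1, 2)
    fused.weight.data.copy_(weight)
    fused.bias.data.copy_(bias)
    x = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        assert torch.allclose(block(x), fused(x), atol=1e-5)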
|
|
|
|
|
class ConvRep3(nn.Module):
    """Re-parameterizable 3x3 conv block: eight parallel branches at train
    time that slim() folds into a single 3x3 convolution for inference."""

    def __init__(self, in_channels, out_channels, rep_scale=4):
        super(ConvRep3, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Four conv branches (3x3, 1x1, 3x1, 1x3), each paired with a
        # conv+BatchNorm twin, all widened by rep_scale.
        self.conv = nn.Conv2d(in_channels, out_channels * rep_scale, 3, 1, 1)
        self.conv_bn = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * rep_scale, 3, 1, 1),
            nn.BatchNorm2d(out_channels * rep_scale)
        )
        self.conv1 = nn.Conv2d(in_channels, out_channels * rep_scale, 1)
        self.conv1_bn = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * rep_scale, 1),
            nn.BatchNorm2d(out_channels * rep_scale)
        )
        self.conv_crossh = nn.Conv2d(in_channels, out_channels * rep_scale, (3, 1), 1, (1, 0))
        self.conv_crossh_bn = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * rep_scale, (3, 1), 1, (1, 0)),
            nn.BatchNorm2d(out_channels * rep_scale)
        )
        self.conv_crossv = nn.Conv2d(in_channels, out_channels * rep_scale, (1, 3), 1, (0, 1))
        self.conv_crossv_bn = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * rep_scale, (1, 3), 1, (0, 1)),
            nn.BatchNorm2d(out_channels * rep_scale)
        )
        # 1x1 projection back to out_channels; 8 = number of branches.
        self.conv_out = nn.Conv2d(out_channels * rep_scale * 8, out_channels, 1)
|
    def forward(self, inp):
        # Concatenate all eight branch outputs, then project down with 1x1.
        x = torch.cat(
            [self.conv(inp),
             self.conv1(inp),
             self.conv_crossh(inp),
             self.conv_crossv(inp),
             self.conv_bn(inp),
             self.conv1_bn(inp),
             self.conv_crossh_bn(inp),
             self.conv_crossv_bn(inp)],
            1
        )
        return self.conv_out(x)
|
    def slim(self):
        """Fold all eight branches plus the 1x1 projection into the weight
        and bias of a single equivalent 3x3 convolution."""
        def fold_bn(branch, weight):
            # Same BatchNorm folding as in ConvRep5.slim.
            conv, bn = branch[0], branch[1]
            k = 1 / (bn.running_var + bn.eps) ** .5
            w = weight * (k * bn.weight).reshape(-1, 1, 1, 1)
            bias = (conv.bias * k - bn.running_mean * k) * bn.weight + bn.bias
            return w, bias

        # Zero-pad the smaller kernels to 3x3; F.pad takes
        # (left, right, top, bottom).
        conv_weight = self.conv.weight
        conv_bias = self.conv.bias
        conv1_weight = nn.functional.pad(self.conv1.weight, (1, 1, 1, 1))
        conv1_bias = self.conv1.bias
        conv_crossv_weight = nn.functional.pad(self.conv_crossv.weight, (0, 0, 1, 1))
        conv_crossv_bias = self.conv_crossv.bias
        conv_crossh_weight = nn.functional.pad(self.conv_crossh.weight, (1, 1, 0, 0))
        conv_crossh_bias = self.conv_crossh.bias

        conv_bn_weight, conv_bn_bias = fold_bn(self.conv_bn, self.conv_bn[0].weight)
        conv1_bn_weight, conv1_bn_bias = fold_bn(
            self.conv1_bn, nn.functional.pad(self.conv1_bn[0].weight, (1, 1, 1, 1)))
        conv_crossv_bn_weight, conv_crossv_bn_bias = fold_bn(
            self.conv_crossv_bn, nn.functional.pad(self.conv_crossv_bn[0].weight, (0, 0, 1, 1)))
        conv_crossh_bn_weight, conv_crossh_bn_bias = fold_bn(
            self.conv_crossh_bn, nn.functional.pad(self.conv_crossh_bn[0].weight, (1, 1, 0, 0)))

        # Stack branch kernels in forward()'s concatenation order, then apply
        # the 1x1 projection as a matrix product over the channel axis.
        weight = torch.cat(
            [conv_weight, conv1_weight,
             conv_crossh_weight, conv_crossv_weight,
             conv_bn_weight, conv1_bn_weight,
             conv_crossh_bn_weight, conv_crossv_bn_weight],
            0
        )
        weight_compress = self.conv_out.weight.squeeze()
        weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
        bias = torch.cat(
            [conv_bias, conv1_bias,
             conv_crossh_bias, conv_crossv_bias,
             conv_bn_bias, conv1_bn_bias,
             conv_crossh_bn_bias, conv_crossv_bn_bias],
            0
        )
        bias = torch.matmul(weight_compress, bias)
        if self.conv_out.bias is not None:
            bias = bias + self.conv_out.bias
        return weight, bias
|
|
class ConvRepPoint(nn.Module):
    """Re-parameterizable 1x1 (pointwise) conv block with two branches."""

    def __init__(self, in_channels, out_channels, rep_scale=4):
        super(ConvRepPoint, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels * rep_scale, 1)
        self.conv_bn = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * rep_scale, 1),
            nn.BatchNorm2d(out_channels * rep_scale)
        )
        self.conv_out = nn.Conv2d(out_channels * rep_scale * 2, out_channels, 1)
|
    def forward(self, inp):
        x = torch.cat([self.conv(inp), self.conv_bn(inp)], 1)
        return self.conv_out(x)
|
    def slim(self):
        """Fold both 1x1 branches and the projection into one 1x1 conv."""
        conv_weight = self.conv.weight
        conv_bias = self.conv.bias
        # Fold the BatchNorm into its conv, as in ConvRep5.slim.
        bn = self.conv_bn[1]
        k = 1 / (bn.running_var + bn.eps) ** .5
        conv_bn_weight = self.conv_bn[0].weight * (k * bn.weight).reshape(-1, 1, 1, 1)
        conv_bn_bias = (self.conv_bn[0].bias * k - bn.running_mean * k) * bn.weight + bn.bias
        weight = torch.cat([conv_weight, conv_bn_weight], 0)
        weight_compress = self.conv_out.weight.squeeze()
        weight = torch.matmul(weight_compress, weight.permute([2, 3, 0, 1])).permute([2, 3, 0, 1])
        bias = torch.cat([conv_bias, conv_bn_bias], 0)
        bias = torch.matmul(weight_compress, bias)
        if self.conv_out.bias is not None:
            bias = bias + self.conv_out.bias
        return weight, bias
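
# Hedged helper (not part of the original file): materialize any rep block
# above as a plain nn.Conv2d from its slim() output. Kernel size and padding
# are read off the fused weight, which assumes the stride-1, "same"-padding
# layout these blocks use.
def fuse_rep_block(block):
    weight, bias = block.slim()
    out_channels, in_channels, kh, kw = weight.shape
    fused = nn.Conv2d(in_channels, out_channels, (kh, kw), 1, (kh // 2, kw // 2))
    fused.weight.data.copy_(weight)
    fused.bias.data.copy_(bias)
    return fused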
|
|
|
|
|
class QuadraticConnectionUnit(nn.Module):
    """Second-order fusion: a scaled elementwise product of two branches
    plus a learned per-channel bias."""

    def __init__(self, block1, block2, channels):
        super(QuadraticConnectionUnit, self).__init__()
        self.block1 = block1
        self.block2 = block2
        self.scale = 0.1  # constant damping factor for the product term
        self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))

    def forward(self, x):
        return self.scale * self.block1(x) * self.block2(x) + self.bias

|
|
class QuadraticConnectionUnitS(nn.Module):
    """Slim variant of QuadraticConnectionUnit, without the 0.1 damping."""

    def __init__(self, block1, block2, channels):
        super(QuadraticConnectionUnitS, self).__init__()
        self.block1 = block1
        self.block2 = block2
        self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))

    def forward(self, x):
        return self.block1(x) * self.block2(x) + self.bias

|
|
class AdditionFusion(nn.Module):
    """Fuse two branches by elementwise addition plus a learned bias."""

    def __init__(self, addend1, addend2, channels):
        super(AdditionFusion, self).__init__()
        self.addend1 = addend1
        self.addend2 = addend2
        self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))

    def forward(self, x):
        return self.addend1(x) + self.addend2(x) + self.bias

|
|
class AdditionFusionS(nn.Module):
    """Slim counterpart of AdditionFusion; the computation is identical,
    presumably kept as a separate class so the train and slim model
    definitions stay symmetric."""

    def __init__(self, addend1, addend2, channels):
        super(AdditionFusionS, self).__init__()
        self.addend1 = addend1
        self.addend2 = addend2
        self.bias = nn.Parameter(torch.randn((1, channels, 1, 1)))

    def forward(self, x):
        return self.addend1(x) + self.addend2(x) + self.bias

|
|
class DropBlock(nn.Module):
    """DropBlock regularization: zeroes contiguous block_size x block_size
    regions. p is the target drop fraction; it is divided by the block area
    so the seed density matches it after dilation. block_size is assumed odd
    so the max-pool padding preserves spatial size."""

    def __init__(self, block_size, p=0.5):
        super(DropBlock, self).__init__()
        self.block_size = block_size
        self.p = p / block_size / block_size

    def forward(self, x):
        # DropBlock is a training-time regularizer; act as identity at
        # inference.
        if not self.training:
            return x
        # Sample drop seeds on a single channel, then dilate each seed into
        # a square block shared across channels.
        mask = (torch.rand_like(x[:, :1]) < self.p).float()
        mask = nn.functional.max_pool2d(mask, self.block_size, 1, self.block_size // 2)
        return x * (1 - mask)
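
# Hedged sanity sketch (not part of the original file): in training mode the
# dropped fraction should land near p (a bit under, since blocks overlap);
# in eval mode the input passes through unchanged.
def _check_dropblock():
    torch.manual_seed(0)
    drop = DropBlock(block_size=3, p=0.5)
    x = torch.ones(2, 4, 64, 64)
    y = drop.train()(x)
    print(1.0 - y.count_nonzero().item() / y.numel())  # dropped fraction
    assert torch.equal(drop.eval()(x), x)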
|
|
|
|
|
class ResBlock(nn.Module):
    """Residual block built from two re-parameterizable 3x3 convs."""

    def __init__(self, num_feat=4, rep_scale=4):
        super(ResBlock, self).__init__()
        self.conv1 = ConvRep3(num_feat, num_feat, rep_scale=rep_scale)
        self.conv2 = ConvRep3(num_feat, num_feat, rep_scale=rep_scale)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out

|
|
class ResBlockS(nn.Module):
    """Slim counterpart of ResBlock using plain 3x3 convs; its weights can
    be loaded from a trained ResBlock via ConvRep3.slim()."""

    def __init__(self, num_feat=4):
        super(ResBlockS, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out
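
# Hedged end-to-end sketch (not part of the original file): fold a trained
# ResBlock into a ResBlockS for inference by copying each ConvRep3's fused
# weights; in eval mode the two blocks should then agree numerically, e.g.
# torch.allclose(block(x), slim_resblock(block)(x), atol=1e-5).
def slim_resblock(block, num_feat=4):
    block = block.eval()
    slim_block = ResBlockS(num_feat).eval()
    for src, dst in [(block.conv1, slim_block.conv1),
                     (block.conv2, slim_block.conv2)]:
        weight, bias = src.slim()
        dst.weight.data.copy_(weight)
        dst.bias.data.copy_(bias)
    return slim_block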
|
|