|
import torch |
|
import torch.nn as nn |
|
from WT.transform import DWT, IWT |
|
|
|
|
|
def conv3x3(in_chn, out_chn, bias=True):
    """Build a 3x3, stride-1, padding-1 convolution.

    With this kernel/stride/padding combination an (N, in_chn, H, W) input
    produces an (N, out_chn, H, W) output.
    """
    return nn.Conv2d(
        in_chn,
        out_chn,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=bias,
    )
|
|
|
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """Build a Conv2d padded by ``kernel_size // 2`` on each side.

    For an odd ``kernel_size`` and ``stride == 1`` this keeps H and W
    unchanged. Unlike ``conv3x3``, bias is off by default.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=half_kernel,
        bias=bias,
    )
|
|
|
def bili_resize(factor):
    """Return a bilinear ``nn.Upsample`` scaling H and W by ``factor``.

    ``factor`` below 1 (e.g. 0.5) downsamples. ``align_corners`` is False.
    """
    resize = nn.Upsample(
        scale_factor=factor,
        mode='bilinear',
        align_corners=False,
    )
    return resize
|
|
|
|
|
class UNetConvBlock(nn.Module):
    """One U-Net stage: an HWB block followed by a 1x1 projection.

    Args:
        in_size: channels of the input tensor (also the HWB width).
        out_size: channels produced by the 1x1 ``tail`` projection.
        downsample: if truthy, a ``PS_down`` (factor 2) is applied after the
            tail and ``forward`` returns ``(downsampled, full_res)``;
            otherwise ``forward`` returns only the projected tensor.
    """

    def __init__(self, in_size, out_size, downsample):
        super().__init__()
        # Start with the raw flag; it is swapped for the PS_down module below
        # when downsampling is requested. The attribute assignment order is
        # kept as-is so parameter registration and init order do not change.
        self.downsample = downsample
        self.body = nn.Sequential(
            HWB(n_feat=in_size, o_feat=in_size, kernel_size=3,
                reduction=16, bias=False, act=nn.PReLU())
        )
        if downsample:
            self.downsample = PS_down(out_size, out_size, downscale=2)
        self.tail = nn.Conv2d(in_size, out_size, kernel_size=1)

    def forward(self, x):
        features = self.tail(self.body(x))
        if not self.downsample:
            return features
        return self.downsample(features), features
|
|
|
class UNetUpBlock(nn.Module):
    """Decoder stage: upsample x2, concatenate the skip tensor, refine.

    ``self.up`` maps the low-resolution input to ``out_size`` channels at
    twice the resolution. That result is concatenated with ``bridge`` and fed
    to a non-downsampling ``UNetConvBlock`` built for ``in_size`` input
    channels, so ``bridge`` is expected to supply ``in_size - out_size``
    channels.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        self.up = PS_up(in_size, out_size, upscale=2)
        self.conv_block = UNetConvBlock(in_size, out_size, downsample=False)

    def forward(self, x, bridge):
        merged = torch.cat((self.up(x), bridge), dim=1)
        return self.conv_block(merged)
|
|
|
|
|
class PS_down(nn.Module):
    """Downsample by ``downscale`` using pixel-unshuffle then a 1x1 conv.

    ``PixelUnshuffle`` moves each ``downscale x downscale`` spatial patch into
    channels (``in_size * downscale**2`` of them); the 1x1 convolution then
    projects to ``out_size`` channels.
    """

    def __init__(self, in_size, out_size, downscale):
        super().__init__()
        self.UnPS = nn.PixelUnshuffle(downscale)
        folded_channels = in_size * downscale ** 2
        self.conv1 = nn.Conv2d(folded_channels, out_size, 1, 1, 0)

    def forward(self, x):
        return self.conv1(self.UnPS(x))
|
|
|
class PS_up(nn.Module):
    """Upsample by ``upscale`` using pixel-shuffle then a 1x1 conv.

    ``PixelShuffle`` trades ``upscale**2`` channels for spatial resolution,
    leaving ``in_size // upscale**2`` channels that the 1x1 convolution
    projects to ``out_size``. ``in_size`` should be divisible by
    ``upscale**2``.
    """

    def __init__(self, in_size, out_size, upscale):
        super().__init__()
        self.PS = nn.PixelShuffle(upscale)
        shuffled_channels = in_size // (upscale ** 2)
        self.conv1 = nn.Conv2d(shuffled_channels, out_size, 1, 1, 0)

    def forward(self, x):
        return self.conv1(self.PS(x))
|
|
|
|
|
class SKFF(nn.Module):
    """Selective-kernel feature fusion of ``height`` same-shaped inputs.

    The inputs are summed, globally average-pooled, squeezed to ``d``
    channels, and expanded back by one 1x1 conv per input. A softmax across
    the ``height`` inputs turns those into per-channel weights, and the output
    is the weighted sum of the inputs.

    Args:
        in_channels: channel count of every input tensor.
        height: number of input tensors fused per call.
        reduction: squeeze ratio; the bottleneck width is
            ``max(int(in_channels / reduction), 4)``.
        bias: whether the 1x1 convolutions use a bias.
    """

    def __init__(self, in_channels, height=3, reduction=8, bias=False):
        super(SKFF, self).__init__()

        self.height = height
        d = max(int(in_channels / reduction), 4)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(nn.Conv2d(in_channels, d, 1, padding=0, bias=bias), nn.PReLU())

        self.fcs = nn.ModuleList(
            [nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias) for _ in range(self.height)]
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, inp_feats):
        """Fuse a sequence of ``height`` tensors of shape (B, C, H, W).

        Returns a single (B, C, H, W) tensor.
        """
        # Take the shape from the first input: the previous ``inp_feats[1]``
        # raised IndexError when height == 1, and all inputs share one shape.
        batch_size, n_feats, rows, cols = inp_feats[0].shape

        stacked = torch.cat(inp_feats, dim=1)
        stacked = stacked.view(batch_size, self.height, n_feats, rows, cols)

        feats_U = torch.sum(stacked, dim=1)
        feats_S = self.avg_pool(feats_U)
        feats_Z = self.conv_du(feats_S)

        attention_vectors = torch.cat([fc(feats_Z) for fc in self.fcs], dim=1)
        attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)
        # Softmax over the `height` axis: per channel, weights sum to 1 across inputs.
        attention_vectors = self.softmax(attention_vectors)

        feats_V = torch.sum(stacked * attention_vectors, dim=1)
        return feats_V
|
|
|
|
|
|
|
|
|
class SALayer(nn.Module):
    """Spatial attention: scale every position of ``x`` by a learned mask.

    The per-position channel max and channel mean are stacked (max first)
    into a 2-channel map. This passes through a ``kernel_size`` conv and a
    sigmoid, and the resulting 1-channel mask multiplies ``x`` by
    broadcasting. For odd ``kernel_size`` the mask keeps H and W of ``x``.
    """

    def __init__(self, kernel_size=5, bias=False):
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.conv_du = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=kernel_size, stride=1, padding=pad, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        mask = self.conv_du(torch.cat((channel_max, channel_mean), dim=1))
        return x * mask
|
|
|
|
|
|
|
class CALayer(nn.Module):
    """Channel attention (squeeze-and-excitation style).

    Global average pooling yields one value per channel. A 1x1-conv
    bottleneck (ReLU in between) followed by a sigmoid turns these into
    per-channel weights in (0, 1) that rescale ``x``.

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck ratio; the hidden width is
            ``max(channel // reduction, 1)``.
        bias: whether the 1x1 convolutions use a bias.
    """

    def __init__(self, channel, reduction=16, bias=False):
        super(CALayer, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Keep at least one hidden channel. Plain ``channel // reduction`` is
        # 0 when channel < reduction, producing a zero-width (degenerate)
        # bottleneck. Identical to before whenever channel >= reduction.
        hidden = max(channel // reduction, 1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=bias),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y
|
|
|
|
|
|
|
class HWB(nn.Module):
    """Half-wavelet block.

    The input is split in half along channels:

    * The first half goes through a DWT and a conv ``body``. Parallel spatial
      (``SALayer``) and channel (``CALayer``) attention branches follow, then
      a 1x1 fusion with a residual back to the DWT output, then an IWT.
    * The second half is passed through unchanged.

    The halves are concatenated, passed through a 3x3 conv and ``act``, and a
    1x1 projection of the block input is added as a residual.

    Args:
        n_feat: input channels. Should be even so ``torch.chunk`` produces
            two equal halves.
        o_feat: output channels.
        kernel_size: kernel size of the two convolutions in ``body``.
        reduction: reduction ratio of the channel-attention layer.
        bias: whether the convolutions use a bias.
        act: activation module. NOTE(review): the same instance is placed in
            ``body`` and reused as ``self.activate``, so any parameters it
            has (e.g. a PReLU slope) are shared by both uses.
    """

    def __init__(self, n_feat, o_feat, kernel_size, reduction, bias, act):
        super(HWB, self).__init__()
        self.dwt = DWT()
        self.iwt = IWT()

        # NOTE(review): body takes n_feat*2 channels, which assumes DWT maps
        # C channels to 4C (here C = n_feat/2) — confirm against WT.transform.
        modules_body = [
            conv(n_feat*2, n_feat, kernel_size, bias=bias),
            act,
            conv(n_feat, n_feat*2, kernel_size, bias=bias)
        ]
        self.body = nn.Sequential(*modules_body)

        self.WSA = SALayer()
        self.WCA = CALayer(n_feat*2, reduction, bias=bias)

        self.conv1x1 = nn.Conv2d(n_feat*4, n_feat*2, kernel_size=1, bias=bias)
        self.conv3x3 = nn.Conv2d(n_feat, o_feat, kernel_size=3, padding=1, bias=bias)
        self.activate = act
        self.conv1x1_final = nn.Conv2d(n_feat, o_feat, kernel_size=1, bias=bias)

    def forward(self, x):
        residual = x

        # Split channels into a wavelet half and an identity half.
        wavelet_path_in, identity_path = torch.chunk(x, 2, dim=1)

        # Wavelet domain: conv body + dual (spatial and channel) attention.
        x_dwt = self.dwt(wavelet_path_in)
        res = self.body(x_dwt)
        branch_sa = self.WSA(res)
        branch_ca = self.WCA(res)
        res = torch.cat([branch_sa, branch_ca], dim=1)
        res = self.conv1x1(res) + x_dwt
        wavelet_path = self.iwt(res)

        out = torch.cat([wavelet_path, identity_path], dim=1)
        out = self.activate(self.conv3x3(out))
        # Out-of-place add: activations such as ReLU/Sigmoid keep their output
        # for backward, and an in-place ``+=`` on it would make autograd
        # raise. Numerically identical to the former ``out += ...``.
        out = out + self.conv1x1_final(residual)

        return out
|
|
|
|
|
|
|
|
|
class HWMNet(nn.Module):
    """Half Wavelet Mixture Network: a multi-scale U-Net built from HWBs.

    Encoder: ``depth`` ``UNetConvBlock`` stages.
    * Stage 0 processes ``conv_01(x)``.
    * Each later stage bilinearly halves a copy of the input image, embeds it
      with the same ``conv_01``, and concatenates it with the previous
      stage's downsampled output.
    * Every stage except the last also returns its pre-downsampling tensor
      as a skip connection.

    Decoder: ``depth - 1`` ``UNetUpBlock`` stages consume the skips
    deepest-first. The bottom output and every decoder output are projected
    to ``wf`` channels, bilinearly resized by the matching power of two, and
    fused with ``SKFF``. A final 3x3 conv maps back to ``in_chn`` channels
    and the input image is added.

    Args:
        in_chn: channels of the input (and output) image.
        wf: base feature width.
        depth: number of encoder stages; must be >= 1.

    Raises:
        ValueError: if ``depth < 1``.

    NOTE(review): input H and W presumably must be divisible by
    ``2 ** depth`` (``depth - 1`` stride-2 downsamples plus the DWT inside
    the deepest HWB) — confirm against WT.transform.
    """

    def __init__(self, in_chn=3, wf=64, depth=4):
        super(HWMNet, self).__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        self.depth = depth
        self.down_path = nn.ModuleList()
        # Shared across scales: 0.5x bilinear resize and image embedding conv.
        self.bili_down = bili_resize(0.5)
        self.conv_01 = nn.Conv2d(in_chn, wf, 3, 1, 1)

        prev_channels = 0
        for i in range(depth):
            downsample = (i + 1) < depth
            # Stage i receives the previous stage's output plus the wf
            # channels from conv_01 (prev_channels is 0 for stage 0).
            self.down_path.append(UNetConvBlock(prev_channels + wf, (2 ** i) * wf, downsample))
            prev_channels = (2 ** i) * wf

        self.up_path = nn.ModuleList()
        self.skip_conv = nn.ModuleList()
        self.conv_up = nn.ModuleList()
        self.bottom_conv = nn.Conv2d(prev_channels, wf, 3, 1, 1)
        self.bottom_up = bili_resize(2 ** (depth-1))

        for i in reversed(range(depth - 1)):
            self.up_path.append(UNetUpBlock(prev_channels, (2 ** i) * wf))
            self.skip_conv.append(nn.Conv2d((2 ** i) * wf, (2 ** i) * wf, 3, 1, 1))
            # Resize stage i's output by 2**i and project it to wf channels for SKFF.
            self.conv_up.append(nn.Sequential(*[bili_resize(2 ** i), nn.Conv2d((2 ** i) * wf, wf, 3, 1, 1)]))
            prev_channels = (2 ** i) * wf

        self.final_ff = SKFF(in_channels=wf, height=depth)
        # prev_channels equals wf at this point, matching the SKFF output width.
        self.last = conv3x3(prev_channels, in_chn, bias=True)

    def forward(self, x):
        img = x
        scale_img = img

        x1 = self.conv_01(img)
        encs = []

        for i, down in enumerate(self.down_path):
            if i > 0:
                # Inject a further-halved image copy, embedded by the shared
                # conv_01, next to the previous stage's output.
                scale_img = self.bili_down(scale_img)
                left_bar = self.conv_01(scale_img)
                x1 = torch.cat([x1, left_bar], dim=1)
            # Decided independently of ``i > 0`` so that stage 0 is treated as
            # the last stage when depth == 1; the last stage has no
            # downsample and returns a single tensor instead of a pair.
            if (i + 1) < self.depth:
                x1, x1_up = down(x1)
                encs.append(x1_up)
            else:
                x1 = down(x1)

        ms_result = [self.bottom_up(self.bottom_conv(x1))]
        for i, up in enumerate(self.up_path):
            # Skips are consumed deepest-first.
            x1 = up(x1, self.skip_conv[i](encs[-i - 1]))
            ms_result.append(self.conv_up[i](x1))

        msff_result = self.final_ff(ms_result)

        out_1 = self.last(msff_result) + img
        return out_1
|
|
|
if __name__ == "__main__":
    # Smoke test: one forward pass plus a rough size/cost report.
    # The previous version called ``profile`` without importing it
    # (NameError); this version relies on torch only.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Named ``inp`` to avoid shadowing the builtin ``input``.
    inp = torch.ones(1, 3, 400, 592, dtype=torch.float, device=device)

    model = HWMNet(in_chn=3, wf=96, depth=4).to(device)
    params = sum(p.numel() for p in model.parameters())

    try:
        from torch.utils.flop_counter import FlopCounterMode  # torch >= 2.1
    except ImportError:
        FlopCounterMode = None

    with torch.no_grad():
        if FlopCounterMode is None:
            out = model(inp)
            flops = None
        else:
            # FlopCounterMode counts one multiply-add as 2 FLOPs, so the figure
            # is roughly twice a MAC-based count such as thop's.
            with FlopCounterMode(display=False) as flop_counter:
                out = model(inp)
            flops = flop_counter.get_total_flops()

    print('input shape:', inp.shape)
    print('parameters:', params / 1e6)
    print('flops', flops / 1e9 if flops is not None else 'unavailable (needs torch>=2.1)')
    print('output shape', out.shape)
|
|