|
import torch
|
|
import torch.nn as nn
|
|
|
|
from . import common
|
|
from .ResNet import ResNet
|
|
from lambda_network import LambdaLayer
|
|
|
|
def build_model(args):
|
|
return MSResNet(args)
|
|
|
|
class conv_end(nn.Module):
|
|
def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2):
|
|
super(conv_end, self).__init__()
|
|
|
|
modules = [
|
|
common.default_conv(in_channels, out_channels, kernel_size),
|
|
nn.PixelShuffle(ratio)
|
|
]
|
|
|
|
self.uppath = nn.Sequential(*modules)
|
|
|
|
def forward(self, x):
|
|
return self.uppath(x)
|
|
|
|
class MSResNet(nn.Module):
|
|
def __init__(self, args):
|
|
super(MSResNet, self).__init__()
|
|
|
|
self.rgb_range = args.rgb_range
|
|
self.mean = self.rgb_range / 2
|
|
|
|
self.n_resblocks = args.n_resblocks
|
|
self.n_feats = args.n_feats
|
|
self.kernel_size = args.kernel_size
|
|
|
|
self.n_scales = args.n_scales
|
|
|
|
self.body_models = nn.ModuleList([
|
|
ResNet(args, 3, 3, mean_shift=False),
|
|
])
|
|
self.lambda_models = nn.ModuleList([
|
|
LambdaLayer(
|
|
dim = 32,
|
|
dim_out = 32,
|
|
n = 64 * 64,
|
|
dim_k = 16,
|
|
heads = 4,
|
|
dim_u = 1
|
|
)
|
|
])
|
|
for _ in range(1, self.n_scales):
|
|
self.body_models.insert(0, ResNet(args, 6, 3, mean_shift=False))
|
|
|
|
self.conv_end_models = nn.ModuleList([None])
|
|
for _ in range(1, self.n_scales):
|
|
self.conv_end_models += [conv_end(3, 12)]
|
|
|
|
def forward(self, input_pyramid):
|
|
|
|
scales = range(self.n_scales-1, -1, -1)
|
|
|
|
for s in scales:
|
|
input_pyramid[s] = input_pyramid[s] - self.mean
|
|
|
|
output_pyramid = [None] * self.n_scales
|
|
|
|
input_s = input_pyramid[-1]
|
|
for s in scales:
|
|
output_pyramid[s] = self.body_models[s](input_s)
|
|
if s > 0:
|
|
up_feat = self.conv_end_models[s](output_pyramid[s])
|
|
input_s = torch.cat((input_pyramid[s-1], up_feat), 1)
|
|
|
|
for s in scales:
|
|
output_pyramid[s] = output_pyramid[s] + self.mean
|
|
|
|
return output_pyramid
|
|
|