# lambdanet/backup/deblur/src/model/RecLamResNet.py
# Uploaded by hyliu using huggingface_hub (commit 8ec10cf, verified).
import torch
import torch.nn as nn
from . import common
from .LamResNet import ResNet
def build_model(args):
    """Construct and return a ``RecLamResNet`` configured by ``args``."""
    model = RecLamResNet(args)
    return model
class conv_end(nn.Module):
    """Convolution followed by ``nn.PixelShuffle`` upsampling.

    ``common.default_conv`` maps ``in_channels`` to ``out_channels`` with the
    given ``kernel_size`` (presumably a padded ``nn.Conv2d`` -- confirm in
    ``common``). ``nn.PixelShuffle(ratio)`` then rearranges channels into a
    spatial grid that is ``ratio`` times larger in each dimension.

    NOTE(review): PixelShuffle requires the channel count to be divisible by
    ``ratio ** 2``. The defaults (out_channels=3, ratio=2) do not satisfy
    this, so callers presumably always pass ``out_channels`` explicitly.
    """

    def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2):
        super().__init__()
        conv = common.default_conv(in_channels, out_channels, kernel_size)
        shuffle = nn.PixelShuffle(ratio)
        # Keep the attribute name ``uppath`` and the child order
        # (0: conv, 1: shuffle) so state_dict keys match saved checkpoints.
        self.uppath = nn.Sequential(conv, shuffle)

    def forward(self, x):
        """Apply the conv + pixel-shuffle path to ``x``."""
        upsampled = self.uppath(x)
        return upsampled
class RecLamResNet(nn.Module):
    """Apply one shared ResNet body repeatedly as residual refinement steps.

    Each step computes ``body_model(x) + x`` and feeds the result into the
    next step. The same ``body_model`` (and therefore the same weights) is
    used for every step. No resizing happens in this class, so every step
    runs at the resolution of the first input.

    Options read from ``args``:
        rgb_range: pixel value range. ``rgb_range / 2`` is subtracted from
            the input before the first step and added back to every output.
        detach: if truthy, each step's input is detached from the autograd
            graph, so gradients do not flow into earlier steps.
        n_resblocks, n_feats, kernel_size: only stored as attributes here.
            They are presumably consumed by ``ResNet(args, ...)``; confirm
            there.
        n_scales: number of refinement steps, which is also the length of
            the returned list.
    """

    def __init__(self, args):
        super(RecLamResNet, self).__init__()
        self.rgb_range = args.rgb_range
        # Offset removed before and restored after the body. It centres the
        # data on zero if inputs lie in [0, rgb_range] (assumed).
        self.mean = self.rgb_range / 2
        self.is_detach = args.detach
        self.n_resblocks = args.n_resblocks
        self.n_feats = args.n_feats
        self.kernel_size = args.kernel_size
        self.n_scales = args.n_scales
        # Single body shared by all steps. mean_shift=False presumably turns
        # off ResNet's own mean handling because this wrapper does it.
        self.body_model = ResNet(args, 3, 3, mean_shift=False)

    def forward(self, input_lst):
        """Run ``n_scales`` residual refinement steps on ``input_lst[0]``.

        Args:
            input_lst: indexable sequence (list or tuple) whose first element
                is the input tensor. Other elements are ignored, and the
                sequence is not modified.

        Returns:
            A list of ``n_scales`` tensors in reverse step order.
            ``output_lst[0]`` is the output of the final step and
            ``output_lst[-1]`` is the output of the first step. Every entry
            has ``self.mean`` added back. The list is empty when
            ``n_scales`` is 0.
        """
        # Shift into a local variable instead of assigning back into
        # input_lst, so the caller's list keeps its original tensor and
        # tuple inputs also work.
        last_output = input_lst[0] - self.mean
        output_lst = [None] * self.n_scales
        for i in range(self.n_scales):
            if self.is_detach:
                # Cut the graph so this step's loss does not backpropagate
                # into earlier steps.
                last_output = last_output.detach()
            output = self.body_model(last_output) + last_output
            # Fill from the back so output_lst[0] ends up as the final step.
            output_lst[self.n_scales - i - 1] = output + self.mean
            last_output = output
        return output_lst