import math

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


def initialize_weights(net_l, scale=1):
    """Kaiming-initialize Conv2d/Linear layers and rescale their weights by `scale`.

    A small `scale` (e.g. 0.1) is typically used for layers inside residual
    branches so that each residual path starts close to an identity mapping.
    """
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # scale down for residual blocks
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers):
    """Stack `n_layers` instances produced by the `block` factory into an nn.Sequential."""
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualBlock_noBN(nn.Module):
    '''Residual block w/o BN
    ---Conv-ReLU-Conv-+-
     |________________|
    '''

    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # initialization: scale the residual-branch weights down by 0.1
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = F.relu(self.conv1(x), inplace=True)
        out = self.conv2(out)
        return identity + out


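# Usage sketch (illustrative only, not part of the original module): stack a few
# ResidualBlock_noBN blocks into a trunk with make_layer. The feature width (64)
# and depth (5) below are arbitrary assumptions for the demo.
if __name__ == "__main__":
    import functools

    _trunk = make_layer(functools.partial(ResidualBlock_noBN, nf=64), 5)
    _feat = _trunk(torch.randn(1, 64, 32, 32))
    # Residual blocks preserve both the channel count and the spatial size.
    assert _feat.shape == (1, 64, 32, 32)

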
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias)


class MeanShift(nn.Conv2d):
    """1x1 convolution that subtracts (sign=-1) or adds (sign=+1) the dataset RGB mean."""

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Freeze the layer: setting `self.requires_grad` on the module itself does not
        # affect its parameters, so disable gradients on the parameters explicitly.
        for p in self.parameters():
            p.requires_grad = False


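# Usage sketch (illustrative only): subtract the dataset mean before the network and
# add it back afterwards. The DIV2K-style statistics below are assumptions for the
# demo, not values defined anywhere in this module.
if __name__ == "__main__":
    _rgb_mean, _rgb_std = (0.4488, 0.4371, 0.4040), (1.0, 1.0, 1.0)
    _sub_mean = MeanShift(255, _rgb_mean, _rgb_std)           # sign=-1: subtract mean
    _add_mean = MeanShift(255, _rgb_mean, _rgb_std, sign=1)   # sign=+1: add mean back
    _img = torch.rand(1, 3, 16, 16) * 255
    # Subtracting and re-adding the mean should recover the input (up to float error).
    assert torch.allclose(_add_mean(_sub_mean(_img)), _img, atol=1e-4)

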
class BasicBlock(nn.Sequential):
    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, bias=False,
            bn=True, act=nn.ReLU(True)):

        m = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size // 2), stride=stride, bias=bias)
        ]
        if bn: m.append(nn.BatchNorm2d(out_channels))
        if act is not None: m.append(act)
        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if i == 0: m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:  # scale is a power of 2: 2, 4, 8, ...
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2d(n_feat))
                if act: m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2d(n_feat))
            if act: m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)


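# Usage sketch (illustrative only): a x4 Upsampler built from default_conv. With a
# power-of-two scale the block applies conv + PixelShuffle(2) repeatedly, so an
# assumed 64-channel 32x32 feature map becomes 128x128 with the same channel count.
if __name__ == "__main__":
    _up = Upsampler(default_conv, scale=4, n_feat=64)
    assert _up(torch.randn(1, 64, 32, 32)).shape == (1, 64, 128, 128)

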
class EResidualBlock(nn.Module):
    def __init__(self,
                 in_channels, out_channels,
                 group=1):
        super(EResidualBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
        )

    def forward(self, x):
        # The skip connection assumes in_channels == out_channels.
        out = self.body(x)
        out = F.relu(out + x)
        return out


class UpsampleBlock(nn.Module):
    def __init__(self,
                 n_channels, scale, multi_scale,
                 group=1):
        super(UpsampleBlock, self).__init__()

        if multi_scale:
            self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
            self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
            self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
        else:
            self.up = _UpsampleBlock(n_channels, scale=scale, group=group)

        self.multi_scale = multi_scale

    def forward(self, x, scale):
        if self.multi_scale:
            if scale == 2:
                return self.up2(x)
            elif scale == 3:
                return self.up3(x)
            elif scale == 4:
                return self.up4(x)
        else:
            return self.up(x)


class _UpsampleBlock(nn.Module):
    def __init__(self,
                 n_channels, scale,
                 group=1):
        super(_UpsampleBlock, self).__init__()

        modules = []
        if scale == 2 or scale == 4 or scale == 8:
            for _ in range(int(math.log(scale, 2))):
                modules += [nn.Conv2d(n_channels, 4 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
                modules += [nn.PixelShuffle(2)]
        elif scale == 3:
            modules += [nn.Conv2d(n_channels, 9 * n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
            modules += [nn.PixelShuffle(3)]

        self.body = nn.Sequential(*modules)

    def forward(self, x):
        out = self.body(x)
        return out
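

# Smoke-test sketch (illustrative only; the channel count and scales are assumptions):
# a multi-scale UpsampleBlock selects the x2/x3/x4 branch at call time.
if __name__ == "__main__":
    _msu = UpsampleBlock(64, scale=4, multi_scale=True)
    _x = torch.randn(1, 64, 24, 24)
    for _s in (2, 3, 4):
        assert _msu(_x, _s).shape == (1, 64, 24 * _s, 24 * _s)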