# https://github.com/XPixelGroup/ClassSR
"""Reusable network building blocks: weight init, residual blocks, upsamplers."""
import math

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F


def initialize_weights(net_l, scale=1):
    """Initialize the weights of one module (or a list of modules) in place.

    Conv2d and Linear weights get Kaiming-normal init (``a=0``, ``fan_in``)
    and are then multiplied by ``scale``; their biases are zeroed.
    BatchNorm2d layers with affine parameters get weight 1 and bias 0.

    Args:
        net_l: An ``nn.Module`` or a list of modules; every submodule reached
            through ``.modules()`` is visited.
        scale: Multiplier applied to Conv2d/Linear weights after init
            (``ResidualBlock_noBN`` uses 0.1 for its residual branch).
    """
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d) and m.affine:
                # affine=False BatchNorm has weight/bias == None; skip it.
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0.0)


def make_layer(block, n_layers):
    """Return an ``nn.Sequential`` of ``n_layers`` fresh modules built by ``block()``."""
    return nn.Sequential(*(block() for _ in range(n_layers)))


class ResidualBlock_noBN(nn.Module):
    '''Residual block w/o BN
    ---Conv-ReLU-Conv-+-
     |________________|

    Both convs are 3x3, stride 1, padding 1, with ``nf`` in/out channels, so
    the spatial size and channel count are preserved. The convs are
    initialized with ``initialize_weights(..., 0.1)``.
    '''

    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # initialization
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = F.relu(self.conv1(x), inplace=True)
        out = self.conv2(out)
        return identity + out


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Conv2d with ``kernel_size // 2`` padding (size-preserving for odd kernels)."""
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias)


class MeanShift(nn.Conv2d):
    """Frozen 1x1 conv that shifts and scales each of 3 channels.

    Output channel ``c`` is ``(x_c + sign * rgb_range * rgb_mean[c]) / rgb_std[c]``.
    With the default ``sign=-1`` it subtracts the scaled mean.

    Args:
        rgb_range: Multiplier applied to ``rgb_mean``.
        rgb_mean: Three per-channel means.
        rgb_std: Three per-channel divisors.
        sign: +1 or -1; the sign applied to the mean term.
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        dtype = self.weight.dtype
        std = torch.as_tensor(rgb_std, dtype=dtype)
        mean = torch.as_tensor(rgb_mean, dtype=dtype)
        with torch.no_grad():
            self.weight.copy_(
                torch.eye(3, dtype=dtype).view(3, 3, 1, 1) / std.view(3, 1, 1, 1))
            self.bias.copy_(sign * rgb_range * mean / std)
        # Assigning ``requires_grad`` on a Module only creates a plain attribute;
        # the parameters themselves must be frozen individually.
        for p in self.parameters():
            p.requires_grad = False
        self.requires_grad = False  # kept for backward compatibility only


class BasicBlock(nn.Sequential):
    """Conv2d (``kernel_size // 2`` padding), then optional BatchNorm2d and activation.

    ``act`` is an activation *module instance* that is appended as-is, or
    ``None`` to skip it.
    NOTE: the default ``nn.ReLU(True)`` is created once and shared by every
    BasicBlock that relies on the default (harmless for a stateless ReLU).
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, bias=False,
            bn=True, act=nn.ReLU(True)):
        m = [nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size // 2), stride=stride, bias=bias)
        ]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    """Residual block: ``x + res_scale * body(x)``.

    ``body`` is conv -> [BN] -> act -> conv -> [BN]. ``conv`` is a factory
    called as ``conv(n_feat, n_feat, kernel_size, bias=bias)`` (for example
    ``default_conv``). ``act`` is a module instance; the shared-default
    caveat documented on ``BasicBlock`` applies here too.
    """

    def __init__(
            self, conv, n_feat, kernel_size,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x  # in-place is safe: ``mul`` produced a fresh tensor
        return res


class Upsampler(nn.Sequential):
    """Sub-pixel (PixelShuffle) upsampler for scale ``2**n`` or 3.

    Each stage is ``conv(n_feat, r*r*n_feat, 3, bias)``, then
    ``PixelShuffle(r)``, then optional BatchNorm2d and activation. The stages
    are ``n`` repetitions of ``r=2`` for ``scale = 2**n``, or a single ``r=3``.

    Args:
        conv: Conv factory called as ``conv(in_ch, out_ch, kernel_size, bias)``.
        scale: Integer upscaling factor (a power of two, or 3).
        n_feat: Feature channels; preserved by every stage.
        bn: Append ``BatchNorm2d(n_feat)`` after each shuffle.
        act: Activation *class/factory* called with no arguments after each
            stage, or a falsy value for none.
        bias: Passed through to ``conv``.

    Raises:
        NotImplementedError: If ``scale`` is not a positive power of two or 3.
    """

    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
        if scale > 0 and (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            # Exact integer exponent; scale 0 no longer reaches math.log(0).
            ratios = [2] * (int(scale).bit_length() - 1)
        elif scale == 3:
            ratios = [3]
        else:
            raise NotImplementedError(
                f"Upsampler supports scale 2**n or 3, got {scale!r}")

        m = []
        for r in ratios:
            m.append(conv(n_feat, r * r * n_feat, 3, bias))
            m.append(nn.PixelShuffle(r))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if act:
                m.append(act())
        super(Upsampler, self).__init__(*m)


class EResidualBlock(nn.Module):
    """Block computing ``relu(x + body(x))``.

    ``body`` is conv3x3 (grouped) -> ReLU -> conv3x3 (grouped) -> ReLU ->
    conv1x1. Because of the skip addition, ``in_channels`` must equal
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, group=1):
        super(EResidualBlock, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
        )

    def forward(self, x):
        out = self.body(x)
        out = F.relu(out + x)
        return out


class UpsampleBlock(nn.Module):
    """Upsampler that is either fixed-scale or switchable among x2/x3/x4.

    With ``multi_scale=True`` it holds ``up2``, ``up3`` and ``up4`` and picks
    one by the ``scale`` passed to ``forward``. Otherwise it holds a single
    ``up`` built for ``scale`` and ignores ``forward``'s ``scale``.
    """

    def __init__(self, n_channels, scale, multi_scale, group=1):
        super(UpsampleBlock, self).__init__()

        if multi_scale:
            self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
            self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
            self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
        else:
            self.up = _UpsampleBlock(n_channels, scale=scale, group=group)

        self.multi_scale = multi_scale

    def forward(self, x, scale):
        if not self.multi_scale:
            return self.up(x)
        if scale == 2:
            return self.up2(x)
        if scale == 3:
            return self.up3(x)
        if scale == 4:
            return self.up4(x)
        # Previously fell through and silently returned None.
        raise NotImplementedError(
            f"multi-scale UpsampleBlock supports scale 2, 3 or 4, got {scale!r}")


class _UpsampleBlock(nn.Module):
    """Stack of (grouped conv3x3 -> ReLU -> PixelShuffle) stages for x2/x3/x4/x8.

    Raises:
        NotImplementedError: For any scale other than 2, 3, 4 or 8.
    """

    def __init__(self, n_channels, scale, group=1):
        super(_UpsampleBlock, self).__init__()

        if scale in (2, 4, 8):
            ratios = [2] * (int(scale).bit_length() - 1)
        elif scale == 3:
            ratios = [3]
        else:
            # Previously produced an empty (identity) body without any error.
            raise NotImplementedError(
                f"_UpsampleBlock supports scale 2, 3, 4 or 8, got {scale!r}")

        modules = []
        for r in ratios:
            modules += [
                nn.Conv2d(n_channels, r * r * n_channels, 3, 1, 1, groups=group),
                nn.ReLU(inplace=True),
                nn.PixelShuffle(r),
            ]
        self.body = nn.Sequential(*modules)

    def forward(self, x):
        out = self.body(x)
        return out