import math

import torch
import torch.nn as nn


def default_conv(in_channels, out_channels, kernel_size, bias=True, groups=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias, groups=groups)


def default_norm(n_feats):
    return nn.BatchNorm2d(n_feats)


def default_act():
    return nn.ReLU(inplace=True)


def empty_h(x, n_feats):
    '''
    Create an empty (zero) hidden state.

    input
        x: B x T x 3 x H x W
        n_feats: number of channels C in the hidden state

    output
        h: B x C x H/4 x W/4
    '''
    b = x.size(0)
    h, w = x.size()[-2:]

    return x.new_zeros((b, n_feats, h // 4, w // 4))
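
# Illustrative usage (shapes follow the docstring above):
#     frames = torch.randn(2, 5, 3, 64, 64)  # B x T x 3 x H x W
#     h = empty_h(frames, n_feats=32)        # 2 x 32 x 16 x 16

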
class Normalization(nn.Conv2d):
    """Normalize an input tensor with a fixed 1x1 convolution,
    computing (x - mean) / std per channel."""
    def __init__(self, mean=(0, 0, 0), std=(1, 1, 1)):
        super(Normalization, self).__init__(3, 3, kernel_size=1)
        tensor_mean = torch.Tensor(mean)
        tensor_inv_std = torch.Tensor(std).reciprocal()

        # weight = diag(1 / std), bias = -mean / std
        self.weight.data = torch.eye(3).mul(tensor_inv_std).view(3, 3, 1, 1)
        self.bias.data = -tensor_mean.mul(tensor_inv_std)

        # The statistics are fixed, so keep them out of the optimizer.
        for params in self.parameters():
            params.requires_grad = False
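
# Illustrative check: the layer should match channel-wise standardization.
#     norm = Normalization(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2))
#     x = torch.rand(1, 3, 8, 8)
#     assert torch.allclose(norm(x), (x - 0.5) / 0.2, atol=1e-5)

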
class BasicBlock(nn.Sequential):
    """Convolution layer + activation layer"""
    def __init__(
            self, in_channels, out_channels, kernel_size, bias=True,
            conv=default_conv, norm=False, act=default_act):

        modules = []
        modules.append(
            conv(in_channels, out_channels, kernel_size, bias=bias))
        if norm: modules.append(norm(out_channels))
        if act: modules.append(act())

        super(BasicBlock, self).__init__(*modules)
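
# Illustrative usage: a 3x3 convolution followed by ReLU.
#     block = BasicBlock(3, 64, 3)
#     y = block(torch.randn(1, 3, 32, 32))  # 1 x 64 x 32 x 32

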
class ResBlock(nn.Module):
    def __init__(
            self, n_feats, kernel_size, bias=True,
            conv=default_conv, norm=False, act=default_act):

        super(ResBlock, self).__init__()

        modules = []
        for i in range(2):
            modules.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if norm: modules.append(norm(n_feats))
            # Activate only after the first convolution.
            if act and i == 0: modules.append(act())

        self.body = nn.Sequential(*modules)

    def forward(self, x):
        res = self.body(x)
        res += x

        return res
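
# Illustrative usage: the residual skip keeps input and output shapes equal.
#     rb = ResBlock(n_feats=64, kernel_size=3)
#     y = rb(torch.randn(1, 64, 32, 32))  # 1 x 64 x 32 x 32

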
class ResBlock_mobile(nn.Module):
    def __init__(
            self, n_feats, kernel_size, bias=True,
            conv=default_conv, norm=False, act=default_act, dropout=False):

        super(ResBlock_mobile, self).__init__()

        modules = []
        for i in range(2):
            # Depthwise convolution followed by a pointwise (1x1) convolution.
            modules.append(
                conv(n_feats, n_feats, kernel_size, bias=False, groups=n_feats))
            modules.append(conv(n_feats, n_feats, 1, bias=False))
            if dropout and i == 0: modules.append(nn.Dropout2d(dropout))
            if norm: modules.append(norm(n_feats))
            if act and i == 0: modules.append(act())

        self.body = nn.Sequential(*modules)

    def forward(self, x):
        res = self.body(x)
        res += x

        return res
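
# Illustrative note: each depthwise-separable pair above replaces one dense
# k x k convolution, reducing parameters from k*k*C*C to about C*(k*k + C).
#     rb = ResBlock_mobile(n_feats=64, kernel_size=3, dropout=0.1)
#     y = rb(torch.randn(1, 64, 32, 32))  # 1 x 64 x 32 x 32

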
class Upsampler(nn.Sequential):
    def __init__(
            self, scale, n_feats, bias=True,
            conv=default_conv, norm=False, act=False):

        modules = []
        if (scale & (scale - 1)) == 0:    # scale is a power of 2
            for _ in range(int(math.log2(scale))):
                modules.append(conv(n_feats, 4 * n_feats, 3, bias=bias))
                modules.append(nn.PixelShuffle(2))
                if norm: modules.append(norm(n_feats))
                if act: modules.append(act())
        elif scale == 3:
            modules.append(conv(n_feats, 9 * n_feats, 3, bias=bias))
            modules.append(nn.PixelShuffle(3))
            if norm: modules.append(norm(n_feats))
            if act: modules.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*modules)
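
# Illustrative usage: x4 upsampling runs two conv + PixelShuffle(2) stages.
#     up = Upsampler(scale=4, n_feats=64)
#     y = up(torch.randn(1, 64, 16, 16))  # 1 x 64 x 64 x 64

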
class PixelSort(nn.Module):
    """The inverse operation of PixelShuffle.

    Reduces the spatial resolution, increasing the number of channels.
    Currently, only scale 0.5 is supported. Recent PyTorch versions ship
    an equivalent operation as nn.PixelUnshuffle.

    Reference:
        http://pytorch.org/docs/0.3.0/_modules/torch/nn/modules/pixelshuffle.html#PixelShuffle
        http://pytorch.org/docs/0.3.0/_modules/torch/nn/functional.html#pixel_shuffle
    """
    def __init__(self, upscale_factor=0.5):
        super(PixelSort, self).__init__()
        self.upscale_factor = upscale_factor

    def forward(self, x):
        b, c, h, w = x.size()
        # Split each spatial dimension into (coarse, fine) factors, then move
        # the fine 2 x 2 factors into the channel dimension.
        x = x.view(b, c, h // 2, 2, w // 2, 2)
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
        x = x.view(b, 4 * c, h // 2, w // 2)

        return x
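
# Illustrative check: PixelSort should exactly invert nn.PixelShuffle(2).
#     x = torch.randn(1, 16, 8, 8)
#     assert torch.equal(PixelSort()(nn.PixelShuffle(2)(x)), x)

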
class Downsampler(nn.Sequential):
    def __init__(
            self, scale, n_feats, bias=True,
            conv=default_conv, norm=False, act=False):

        modules = []
        if scale == 0.5:
            modules.append(PixelSort())
            modules.append(conv(4 * n_feats, n_feats, 3, bias=bias))
            if norm: modules.append(norm(n_feats))
            if act: modules.append(act())
        else:
            raise NotImplementedError

        super(Downsampler, self).__init__(*modules)
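
# Illustrative usage: halves the spatial resolution at a fixed channel width.
#     down = Downsampler(scale=0.5, n_feats=64)
#     y = down(torch.randn(1, 64, 32, 32))  # 1 x 64 x 16 x 16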