# Page header captured with the file (not part of the module):
# yangheng's picture | init | 9842c28 | raw | history blame | 13.3 kB
import json
from collections import OrderedDict
from math import exp
from .Common import *
# +++++++++++++++++++++++++++++++++++++
# FP16 Training
# -------------------------------------
# Modified from Nvidia/Apex
# https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/fp16util.py
class tofp16(nn.Module):
    """Input-casting layer placed in front of a half-precision network.

    CUDA tensors are cast to fp16; any other tensor is cast to fp32.
    The module holds no parameters or buffers of its own.
    """

    def forward(self, input):
        """Return ``input`` as fp16 when it lives on CUDA, otherwise as fp32."""
        # PyTorch 1.0 doesn't support fp16 on CPU, so CPU tensors stay fp32.
        return input.half() if input.is_cuda else input.float()
def BN_convert_float(module):
    """Cast every batch-norm layer inside ``module`` to fp32, in place.

    ``module`` itself and all of its descendants are inspected; only
    instances of the private ``_BatchNorm`` base class are converted.
    All other layers keep their current dtype.

    Returns:
        The same ``module`` object, to allow call chaining.
    """
    batchnorm_base = torch.nn.modules.batchnorm._BatchNorm
    # ``modules()`` yields the root first and then every nested submodule,
    # which covers the same set an explicit recursion over children would.
    for submodule in module.modules():
        if isinstance(submodule, batchnorm_base):
            submodule.float()
    return module
def network_to_half(network):
    """Wrap ``network`` for fp16 execution.

    The network is converted to half precision in place, its batch-norm
    layers are then restored to fp32, and a ``tofp16`` layer is put in
    front so that inputs are cast before they reach the network.

    Returns:
        An ``nn.Sequential`` of ``(tofp16(), converted network)``.
    """
    input_cast = tofp16()
    half_network = network.half()
    mixed_network = BN_convert_float(half_network)
    return nn.Sequential(input_cast, mixed_network)
# warnings.simplefilter('ignore')
# +++++++++++++++++++++++++++++++++++++
# DCSCN
# -------------------------------------
class DCSCN(BaseModule):
    """DCSCN super-resolution network with a residual (skip) output.

    Reference implementation:
    https://github.com/jiny2001/dcscn-super-resolution

    The model has three stages:

    * a chain of 3x3 conv blocks whose widths decay exponentially from
      ``first_feature_filters`` towards ``last_feature_filters``; the
      output of every block is kept and concatenated,
    * two parallel reconstruction branches ("A" and "B") applied to the
      concatenated features, whose outputs are concatenated again,
    * an up-sampler (conv block -> ``PixelShuffle`` -> 3x3 conv).

    Args:
        color_channel: channels of the input image and of the output.
        up_scale: factor handed to ``nn.PixelShuffle``.
        feature_layers: number of conv blocks in the feature stage
            (must be >= 1).
        first_feature_filters: output width of the first feature block.
        last_feature_filters: target width of the last feature block.
        reconstruction_filters: output width of each reconstruction branch.
        up_sampler_filters: channel count after the pixel shuffle.

    Raises:
        ValueError: if ``feature_layers`` is smaller than 1.
    """

    # https://github.com/jiny2001/dcscn-super-resolution
    def __init__(
        self,
        color_channel=3,
        up_scale=2,
        feature_layers=12,
        first_feature_filters=196,
        last_feature_filters=48,
        reconstruction_filters=128,
        up_sampler_filters=32,
    ):
        super(DCSCN, self).__init__()
        # Both counters are filled in by the make_* helpers below; the
        # helpers must therefore run in exactly this order.
        self.total_feature_channels = 0
        self.total_reconstruct_filters = 0
        self.upscale = up_scale

        # One activation instance is shared by every conv block.
        self.act_fn = nn.SELU(inplace=False)
        self.feature_block = self.make_feature_extraction_block(
            color_channel, feature_layers, first_feature_filters, last_feature_filters
        )
        self.reconstruction_block = self.make_reconstruction_block(
            reconstruction_filters
        )
        self.up_sampler = self.make_upsampler(up_sampler_filters, color_channel)
        self.selu_init_params()

    def selu_init_params(self):
        """Re-initialise every ``nn.Conv2d``: N(0, 1/sqrt(numel)) weights, zero bias."""
        for i in self.modules():
            if isinstance(i, nn.Conv2d):
                i.weight.data.normal_(0.0, 1.0 / sqrt(i.weight.numel()))
                if i.bias is not None:
                    i.bias.data.fill_(0)

    def conv_block(self, in_channel, out_channel, kernel_size):
        """Return ``Conv2d`` (with ``(kernel_size - 1) // 2`` padding) + shared activation."""
        m = OrderedDict(
            [
                (
                    "Conv2d",
                    nn.Conv2d(
                        in_channel,
                        out_channel,
                        kernel_size=kernel_size,
                        padding=(kernel_size - 1) // 2,
                    ),
                ),
                ("Activation", self.act_fn),
            ]
        )
        return nn.Sequential(m)

    def make_feature_extraction_block(
        self, color_channel, num_layers, first_filters, last_filters
    ):
        """Build the feature stage and record its total channel count.

        Sets ``self.total_feature_channels`` to the sum of all block widths,
        i.e. the channel count of the concatenated feature maps.

        Raises:
            ValueError: if ``num_layers`` is smaller than 1.
        """
        if num_layers < 1:
            raise ValueError(
                "feature_layers must be at least 1, got {}".format(num_layers)
            )

        # input layer
        feature_block = [
            ("Feature 1", self.conv_block(color_channel, first_filters, 3))
        ]

        # exponential decay
        # A single layer has nothing to decay towards; using a zero rate
        # avoids dividing by (num_layers - 1) == 0 and yields [first_filters].
        if num_layers > 1:
            alpha_rate = log(first_filters / last_filters) / (num_layers - 1)
        else:
            alpha_rate = 0.0
        filter_nums = [
            round(first_filters * exp(-alpha_rate * i)) for i in range(num_layers)
        ]
        self.total_feature_channels = sum(filter_nums)

        # rest layers: block k maps filter_nums[k] -> filter_nums[k + 1]
        layer_filters = [
            [filter_nums[i], filter_nums[i + 1], 3] for i in range(num_layers - 1)
        ]
        feature_block.extend(
            [
                ("Feature {}".format(index + 2), self.conv_block(*x))
                for index, x in enumerate(layer_filters)
            ]
        )
        return nn.Sequential(OrderedDict(feature_block))

    def make_reconstruction_block(self, num_filters):
        """Build the two reconstruction branches.

        Branch "A" is one 1x1 block; branch "B" is a 1x1 block to
        ``num_filters // 2`` followed by a 3x3 block to ``num_filters``.
        Sets ``self.total_reconstruct_filters`` to ``num_filters * 2``
        (both branch outputs are concatenated in ``forward``).
        """
        B1 = self.conv_block(self.total_feature_channels, num_filters // 2, 1)
        B2 = self.conv_block(num_filters // 2, num_filters, 3)
        m = OrderedDict(
            [
                ("A", self.conv_block(self.total_feature_channels, num_filters, 1)),
                ("B", nn.Sequential(*[B1, B2])),
            ]
        )
        self.total_reconstruct_filters = num_filters * 2
        return nn.Sequential(m)

    def make_upsampler(self, out_channel, color_channel):
        """Build conv block -> ``PixelShuffle(self.upscale)`` -> bias-free 3x3 conv."""
        # PixelShuffle divides the channel count by upscale**2, so the conv
        # block in front of it has to produce that many more channels.
        out = out_channel * self.upscale**2
        m = OrderedDict(
            [
                (
                    "Conv2d_block",
                    self.conv_block(self.total_reconstruct_filters, out, kernel_size=3),
                ),
                ("PixelShuffle", nn.PixelShuffle(self.upscale)),
                (
                    "Conv2d",
                    nn.Conv2d(
                        out_channel, color_channel, kernel_size=3, padding=1, bias=False
                    ),
                ),
            ]
        )
        return nn.Sequential(m)

    def forward(self, x):
        """Run the network on a ``(lr, lr_up)`` pair.

        ``x`` is unpacked into two tensors: ``lr`` goes through the network
        and ``lr_up`` is added to the result.
        NOTE(review): ``lr_up`` is presumably ``lr`` already up-scaled by
        ``up_scale`` so that the shapes match -- confirm against callers.
        """
        # residual learning
        lr, lr_up = x
        feature = []
        # The two blocks are iterated layer by layer (not called as a whole)
        # because every intermediate output is needed for the concatenation.
        for layer in self.feature_block.children():
            lr = layer(lr)
            feature.append(lr)
        feature = torch.cat(feature, dim=1)

        reconstruction = [
            layer(feature) for layer in self.reconstruction_block.children()
        ]
        reconstruction = torch.cat(reconstruction, dim=1)

        lr = self.up_sampler(reconstruction)
        return lr + lr_up
# +++++++++++++++++++++++++++++++++++++
# CARN
# -------------------------------------
class CARN_Block(BaseModule):
    """Cascading block used by the CARN models.

    Holds a sequence of residual units (``self.blocks``) and one fusion
    conv per cascade step (``self.singles``). In ``forward`` the output of
    each unit is concatenated onto everything gathered so far and fused
    back down to ``channels``.

    ``ResidualFixBlock``, ``SpatialChannelSqueezeExcitation`` and
    ``ConvBlock`` are not defined in this file (presumably they come from
    ``.Common``).
    """

    def __init__(
        self,
        channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        groups=1,
        activation=nn.SELU(),
        repeat=3,
        SEBlock=False,
        conv=nn.Conv2d,
        single_conv_size=1,
        single_conv_group=1,
    ):
        super(CARN_Block, self).__init__()

        # NOTE(review): the squeeze-excitation append is assumed to sit
        # inside the loop (one SE unit after every residual unit) -- confirm
        # against upstream. With SEBlock=True ``self.blocks`` then holds
        # more entries than ``self.singles`` and the ``zip`` in ``forward``
        # only consumes the first ``repeat`` of them.
        units = []
        for _ in range(repeat):
            residual_unit = ResidualFixBlock(
                channels,
                channels,
                kernel_size=kernel_size,
                padding=padding,
                dilation=dilation,
                groups=groups,
                activation=activation,
                conv=conv,
            )
            units.append(residual_unit)
            if SEBlock:
                units.append(
                    SpatialChannelSqueezeExcitation(channels, reduction=channels)
                )
        self.blocks = nn.Sequential(*units)

        # Cascade step k sees the block input plus k + 1 unit outputs,
        # i.e. channels * (k + 2) input channels.
        fuse_padding = (single_conv_size - 1) // 2
        fusers = []
        for step in range(repeat):
            fusers.append(
                ConvBlock(
                    channels * (step + 2),
                    channels,
                    kernel_size=single_conv_size,
                    padding=fuse_padding,
                    groups=single_conv_group,
                    activation=activation,
                    conv=conv,
                )
            )
        self.singles = nn.Sequential(*fusers)

    def forward(self, x):
        """Run the cascade and return the output of the last fusion conv."""
        gathered = x
        for unit, fuse in zip(self.blocks, self.singles):
            gathered = torch.cat([gathered, unit(x)], dim=1)
            x = fuse(gathered)
        return x
class CARN(BaseModule):
    """Cascading Residual Network for super-resolution.

    Fast, Accurate, and Lightweight Super-Resolution with Cascading
    Residual Network
    https://github.com/nmhkahn/CARN-pytorch

    Pipeline: entry conv -> ``num_blocks`` cascaded ``CARN_Block`` units,
    each followed by a 1x1 fusion conv over the concatenated history ->
    up-sampler -> 3x3 exit conv back to ``color_channels``.

    ``ConvBlock`` and ``UpSampleBlock`` are not defined in this file
    (presumably they come from ``.Common``).
    """

    def __init__(
        self,
        color_channels=3,
        mid_channels=64,
        scale=2,
        activation=nn.SELU(),
        num_blocks=3,
        conv=nn.Conv2d,
    ):
        super(CARN, self).__init__()

        self.color_channels = color_channels
        self.mid_channels = mid_channels
        self.scale = scale

        self.entry_block = ConvBlock(
            color_channels,
            mid_channels,
            kernel_size=3,
            padding=1,
            activation=activation,
            conv=conv,
        )

        cascade_units = []
        for _ in range(num_blocks):
            cascade_units.append(
                CARN_Block(
                    mid_channels,
                    kernel_size=3,
                    padding=1,
                    activation=activation,
                    conv=conv,
                    single_conv_size=1,
                    single_conv_group=1,
                )
            )
        self.blocks = nn.Sequential(*cascade_units)

        # Fusion conv k receives the entry features plus k + 1 block outputs.
        fusers = []
        for step in range(num_blocks):
            fusers.append(
                ConvBlock(
                    mid_channels * (step + 2),
                    mid_channels,
                    kernel_size=1,
                    padding=0,
                    activation=activation,
                    conv=conv,
                )
            )
        self.singles = nn.Sequential(*fusers)

        self.upsampler = UpSampleBlock(
            mid_channels, scale=scale, activation=activation, conv=conv
        )
        self.exit_conv = conv(mid_channels, color_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the exit-conv output for input ``x``."""
        x = self.entry_block(x)
        gathered = x
        for unit, fuse in zip(self.blocks, self.singles):
            gathered = torch.cat([gathered, unit(x)], dim=1)
            x = fuse(gathered)
        upscaled = self.upsampler(x)
        return self.exit_conv(upscaled)
class CARN_V2(CARN):
    """CARN variant with configurable cascade blocks and a long skip.

    The parent constructor builds the entry block, up-sampler and exit
    conv (and a default cascade); the cascade (``self.blocks`` and
    ``self.singles``) is then replaced with one configured from this
    class's own arguments.
    """

    def __init__(
        self,
        color_channels=3,
        mid_channels=64,
        scale=2,
        activation=nn.LeakyReLU(0.1),
        SEBlock=True,
        conv=nn.Conv2d,
        atrous=(1, 1, 1),
        repeat_blocks=3,
        single_conv_size=3,
        single_conv_group=1,
    ):
        super(CARN_V2, self).__init__(
            color_channels=color_channels,
            mid_channels=mid_channels,
            scale=scale,
            activation=activation,
            conv=conv,
        )

        # NOTE(review): only the length of ``atrous`` is used; the dilation
        # handed to every block is the constant 1.
        num_blocks = len(atrous)

        self.blocks = nn.Sequential(
            *[
                CARN_Block(
                    mid_channels,
                    kernel_size=3,
                    padding=1,
                    dilation=1,
                    activation=activation,
                    SEBlock=SEBlock,
                    conv=conv,
                    repeat=repeat_blocks,
                    single_conv_size=single_conv_size,
                    single_conv_group=single_conv_group,
                )
                for _ in range(num_blocks)
            ]
        )

        fuse_padding = (single_conv_size - 1) // 2
        self.singles = nn.Sequential(
            *[
                ConvBlock(
                    mid_channels * (step + 2),
                    mid_channels,
                    kernel_size=single_conv_size,
                    padding=fuse_padding,
                    groups=single_conv_group,
                    activation=activation,
                    conv=conv,
                )
                for step in range(num_blocks)
            ]
        )

    def forward(self, x):
        """Run the cascade, add the entry features back, then up-sample."""
        x = self.entry_block(x)
        skip = x
        gathered = x
        for unit, fuse in zip(self.blocks, self.singles):
            gathered = torch.cat([gathered, unit(x)], dim=1)
            x = fuse(gathered)
        # NOTE(review): the residual add is assumed to happen once, after
        # the whole cascade -- confirm against upstream.
        x = x + skip
        upscaled = self.upsampler(x)
        return self.exit_conv(upscaled)
# +++++++++++++++++++++++++++++++++++++
# original Waifu2x model
# -------------------------------------
class UpConv_7(BaseModule):
    """Original waifu2x "UpConv 7" model.

    https://github.com/nagadomi/waifu2x/blob/3c46906cb78895dbd5a25c3705994a1b2e873199/lib/srcnn.lua#L311

    Six unpadded 3x3 convolutions with LeakyReLU(0.1), followed by a
    stride-2 transposed convolution without bias. The input is zero-padded
    by ``self.offset`` pixels on every side before the convolutions.
    """

    def __init__(self):
        super(UpConv_7, self).__init__()
        self.act_fn = nn.LeakyReLU(0.1, inplace=False)
        self.offset = 7  # because of 0 padding
        from torch.nn import ZeroPad2d

        self.pad = ZeroPad2d(self.offset)

        # (in_channels, out_channels) of the six 3x3 convolutions, in order.
        conv_channels = [
            (3, 16),
            (16, 32),
            (32, 64),
            (64, 128),
            (128, 128),
            (128, 256),
        ]
        m = []
        for in_channels, out_channels in conv_channels:
            m.append(nn.Conv2d(in_channels, out_channels, 3, 1, 0))
            m.append(self.act_fn)
        # in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=
        m.append(
            nn.ConvTranspose2d(256, 3, kernel_size=4, stride=2, padding=3, bias=False)
        )
        self.Sequential = nn.Sequential(*m)

    def load_pre_train_weights(self, json_file):
        """Copy weights from a JSON file into this model, in place.

        The file must contain a list of objects with ``"weight"`` and
        ``"bias"`` entries. These are flattened to
        ``[weight0, bias0, weight1, bias1, ...]`` and matched to the
        entries of ``state_dict()`` purely by position; surplus values in
        the file are ignored.

        Raises:
            IndexError: if the file provides fewer tensors than the model has.
        """
        with open(json_file, encoding="utf-8") as f:
            weights = json.load(f)
        box = []
        for i in weights:
            box.append(i["weight"])
            box.append(i["bias"])
        own_state = self.state_dict()
        for index, name in enumerate(own_state):
            own_state[name].copy_(torch.FloatTensor(box[index]))

    def forward(self, x):
        """Zero-pad ``x`` and run it through the layer stack."""
        x = self.pad(x)
        # Call the module rather than ``.forward`` so that hooks registered
        # on the Sequential container are honoured.
        return self.Sequential(x)
class Vgg_7(UpConv_7):
    """waifu2x "VGG 7" model: seven unpadded 3x3 convolutions.

    Reuses the padding layer, ``forward`` and weight loader of
    ``UpConv_7``; only the layer stack in ``self.Sequential`` is replaced.
    Every convolution except the last is followed by LeakyReLU(0.1).
    """

    def __init__(self):
        super(Vgg_7, self).__init__()
        self.act_fn = nn.LeakyReLU(0.1, inplace=False)
        self.offset = 7

        # Channel count at the input of each activated convolution,
        # followed by the input width of the final convolution.
        widths = [3, 32, 32, 64, 64, 128, 128]
        layers = []
        for in_channels, out_channels in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(in_channels, out_channels, 3, 1, 0))
            layers.append(self.act_fn)
        # Output convolution back to 3 channels, with no activation after it.
        layers.append(nn.Conv2d(widths[-1], 3, 3, 1, 0))
        self.Sequential = nn.Sequential(*layers)