import torch
import torch.nn as nn
import torch.nn.functional as F


class CNR2d(nn.Module):
    # Conv2d -> Norm2d -> ReLU (-> Dropout2d) block.
    # An empty list ([]) acts as a "disabled" sentinel for norm, relu, drop and bias.
    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        # Bias is redundant when the convolution is followed by batch normalization.
        if bias == []:
            if norm == 'bnorm':
                bias = False
            else:
                bias = True

        layers = []
        layers += [Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)]

        if norm != []:
            layers += [Norm2d(nch_out, norm)]

        if relu != []:
            layers += [ReLU(relu)]

        if drop != []:
            layers += [nn.Dropout2d(drop)]

        self.cbr = nn.Sequential(*layers)

    def forward(self, x):
        return self.cbr(x)
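
# Usage sketch (illustrative, not from the original source): with the default
# kernel_size=4 and padding=1, CNR2d(3, 64, stride=2, relu=0.2) maps a
# [N, 3, H, W] tensor to [N, 64, H/2, W/2] for even H, W
# (conv output size (H + 2*1 - 4)//2 + 1 = H/2).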


class DECNR2d(nn.Module):
    # ConvTranspose2d -> Norm2d -> ReLU (-> Dropout2d) block; the transposed
    # convolution counterpart of CNR2d.
    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            if norm == 'bnorm':
                bias = False
            else:
                bias = True

        layers = []
        layers += [Deconv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)]

        if norm != []:
            layers += [Norm2d(nch_out, norm)]

        if relu != []:
            layers += [ReLU(relu)]

        if drop != []:
            layers += [nn.Dropout2d(drop)]

        self.decbr = nn.Sequential(*layers)

    def forward(self, x):
        return self.decbr(x)
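
# Usage sketch (illustrative): with kernel_size=4, stride=2, padding=1,
# DECNR2d(64, 32, stride=2) maps [N, 64, H, W] to [N, 32, 2H, 2W]
# (transposed-conv output size (H - 1)*2 - 2*1 + 4 = 2H).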


class ResBlock(nn.Module):
    # Two padded CNR2d blocks with an identity skip connection.
    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            if norm == 'bnorm':
                bias = False
            else:
                bias = True

        layers = []

        # 1st conv block: explicit padding layer, then conv with padding=0.
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu)]

        if drop != []:
            layers += [nn.Dropout2d(drop)]

        # 2nd conv block: no activation before the residual sum.
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[])]

        self.resblk = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.resblk(x)
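
# Usage sketch (illustrative): ResBlock(64, 64) keeps the spatial size
# (reflection pad 1 followed by a 3x3 conv with padding=0) and adds the input
# back, so it assumes nch_in == nch_out for the residual sum to be valid.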


class ResBlock_cat(nn.Module):
    # Residual block that first concatenates a second feature map y with x
    # along the channel dimension (hence nch_in*2 input channels).
    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()

        if bias == []:
            if norm == 'bnorm':
                bias = False
            else:
                bias = True

        layers = []

        # 1st conv block: consumes the concatenated [x, y] tensor.
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_in*2, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu)]

        if drop != []:
            layers += [nn.Dropout2d(drop)]

        # 2nd conv block: no activation before the residual sum
        # (assumes nch_in == nch_out).
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[])]

        self.resblk = nn.Sequential(*layers)

    def forward(self, x, y):
        return x + self.resblk(torch.cat([x, y], dim=1))


class LinearBlock(nn.Module):
    # Linear -> optional normalization -> optional activation.
    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True

        if norm == 'sn':
            # Spectral normalization is applied to the linear layer itself.
            # torch.nn.utils.spectral_norm stands in here for the SpectralNorm
            # wrapper, which is not defined in this file.
            self.fc = nn.utils.spectral_norm(nn.Linear(input_dim, output_dim, bias=use_bias))
        else:
            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)

        # Normalization applied to the layer output.
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            # nn.LayerNorm stands in here for the LayerNorm helper, which is
            # not defined in this file.
            self.norm = nn.LayerNorm(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Activation function.
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x):
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out


class MLP(nn.Module):
    # Multi-layer perceptron built from n_blk LinearBlocks:
    # input_dim -> dim -> ... -> dim -> output_dim.
    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super(MLP, self).__init__()
        self.model = []
        self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        for i in range(n_blk - 2):
            self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)]
        self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        # Flatten everything except the batch dimension before the first Linear.
        return self.model(x.view(x.size(0), -1))
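
# Usage sketch (illustrative): MLP(input_dim=8, output_dim=2048, dim=256, n_blk=3)
# builds Linear 8->256, 256->256, 256->2048, with the final block left without
# normalization or activation.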


class CNR1d(nn.Module):
    # Linear -> normalization -> ReLU (-> Dropout) block for 1-D
    # (fully connected) features.
    def __init__(self, nch_in, nch_out, norm='bnorm', relu=0.0, drop=[]):
        super().__init__()

        if norm == 'bnorm':
            bias = False
        else:
            bias = True

        layers = []
        layers += [nn.Linear(nch_in, nch_out, bias=bias)]

        if norm != []:
            # Use 1-D normalization; the 2-D Norm2d module would reject [N, C] inputs.
            if norm == 'bnorm':
                layers += [nn.BatchNorm1d(nch_out)]
            elif norm == 'inorm':
                layers += [nn.InstanceNorm1d(nch_out)]

        if relu != []:
            layers += [ReLU(relu)]

        if drop != []:
            layers += [nn.Dropout(drop)]

        self.cbr = nn.Sequential(*layers)

    def forward(self, x):
        return self.cbr(x)


class Conv2d(nn.Module):
    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, bias=True):
        super(Conv2d, self).__init__()
        self.conv = nn.Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)

    def forward(self, x):
        return self.conv(x)


class Deconv2d(nn.Module):
    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, bias=True):
        super(Deconv2d, self).__init__()
        self.deconv = nn.ConvTranspose2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)

    def forward(self, x):
        return self.deconv(x)


class Linear(nn.Module):
    def __init__(self, nch_in, nch_out):
        super(Linear, self).__init__()
        self.linear = nn.Linear(nch_in, nch_out)

    def forward(self, x):
        return self.linear(x)


class Norm2d(nn.Module):
    def __init__(self, nch, norm_mode):
        super(Norm2d, self).__init__()
        if norm_mode == 'bnorm':
            self.norm = nn.BatchNorm2d(nch)
        elif norm_mode == 'inorm':
            self.norm = nn.InstanceNorm2d(nch)
        else:
            assert 0, "Unsupported normalization: {}".format(norm_mode)

    def forward(self, x):
        return self.norm(x)


class ReLU(nn.Module):
    # relu > 0: LeakyReLU with the given negative slope; relu == 0: plain ReLU.
    def __init__(self, relu):
        super(ReLU, self).__init__()
        if relu > 0:
            self.relu = nn.LeakyReLU(relu, True)
        elif relu == 0:
            self.relu = nn.ReLU(True)

    def forward(self, x):
        return self.relu(x)


class Padding(nn.Module):
    def __init__(self, padding, padding_mode='zeros', value=0):
        super(Padding, self).__init__()
        if padding_mode == 'reflection':
            self.padding = nn.ReflectionPad2d(padding)
        elif padding_mode == 'replication':
            self.padding = nn.ReplicationPad2d(padding)
        elif padding_mode == 'constant':
            self.padding = nn.ConstantPad2d(padding, value)
        elif padding_mode == 'zeros':
            self.padding = nn.ZeroPad2d(padding)

    def forward(self, x):
        return self.padding(x)


class Pooling2d(nn.Module):
    def __init__(self, nch=[], pool=2, type='avg'):
        super().__init__()

        if type == 'avg':
            self.pooling = nn.AvgPool2d(pool)
        elif type == 'max':
            self.pooling = nn.MaxPool2d(pool)
        elif type == 'conv':
            self.pooling = nn.Conv2d(nch, nch, kernel_size=pool, stride=pool)

    def forward(self, x):
        return self.pooling(x)


class UnPooling2d(nn.Module):
    def __init__(self, nch=[], pool=2, type='nearest'):
        super().__init__()

        if type == 'nearest':
            # align_corners is only valid for interpolating modes, not 'nearest'.
            self.unpooling = nn.Upsample(scale_factor=pool, mode='nearest')
        elif type == 'bilinear':
            self.unpooling = nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True)
        elif type == 'conv':
            self.unpooling = nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool)

    def forward(self, x):
        return self.unpooling(x)


class Concat(nn.Module):
    # Channel-wise concatenation of two feature maps, padding x1 so that its
    # spatial size matches x2 (as in U-Net skip connections).
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        diffy = x2.size()[2] - x1.size()[2]
        diffx = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffx // 2, diffx - diffx // 2,
                        diffy // 2, diffy - diffy // 2])

        return torch.cat([x2, x1], dim=1)
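
# Usage sketch (illustrative): for x1 of shape [N, C1, 30, 30] and x2 of shape
# [N, C2, 32, 32], Concat()(x1, x2) pads x1 by 1 pixel on each side and returns
# a [N, C2 + C1, 32, 32] tensor.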


class TV1dLoss(nn.Module):
    def __init__(self):
        super(TV1dLoss, self).__init__()

    def forward(self, input):
        # Mean absolute difference between neighbouring samples along the last axis.
        loss = torch.mean(torch.abs(input[:, :-1] - input[:, 1:]))
        return loss


class TV2dLoss(nn.Module):
    def __init__(self):
        super(TV2dLoss, self).__init__()

    def forward(self, input):
        # Anisotropic total variation: mean absolute horizontal differences
        # plus mean absolute vertical differences.
        loss = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + \
               torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]))
        return loss


class SSIM2dLoss(nn.Module):
    def __init__(self):
        super(SSIM2dLoss, self).__init__()

    def forward(self, input, target):
        # Placeholder: SSIM is not implemented yet, so the loss is always 0.
        loss = 0
        return loss
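

# Minimal smoke test (illustrative sketch; the shapes and hyperparameters below
# are assumptions for demonstration, not taken from the original training code).
if __name__ == "__main__":
    x = torch.randn(1, 3, 64, 64)

    enc = CNR2d(3, 64, kernel_size=4, stride=2, padding=1, norm='bnorm', relu=0.2)
    res = ResBlock(64, 64, kernel_size=3, stride=1, padding=1, norm='inorm', relu=0.0)
    dec = DECNR2d(64, 3, kernel_size=4, stride=2, padding=1, norm='bnorm', relu=0.0)

    h = enc(x)      # [1, 64, 32, 32]
    h = res(h)      # [1, 64, 32, 32]
    y = dec(h)      # [1, 3, 64, 64]

    print(h.shape, y.shape)
    print(TV2dLoss()(y))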