# image2sketch/models/layer.py
# Uploaded by sharazAhm890 — commit b4f7b8c ("init")
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNR2d(nn.Module):
    """Conv2d followed by optional normalization, activation and dropout.

    Pipeline: ``Conv2d`` -> ``Norm2d`` (unless ``norm == []``) -> ``ReLU``
    (unless ``relu == []``) -> ``nn.Dropout2d`` (unless ``drop == []``).
    The stages are kept in ``self.cbr`` (an ``nn.Sequential``).

    Args:
        nch_in, nch_out: input / output channel counts of the convolution.
        kernel_size, stride, padding: forwarded to the convolution.
        norm: mode string for ``Norm2d``, or ``[]`` for no normalization.
        relu: slope for ``ReLU`` (0 -> ReLU, >0 -> LeakyReLU), or ``[]`` for none.
        drop: dropout probability for ``nn.Dropout2d``, or ``[]`` for none.
        bias: conv bias flag; ``[]`` means "off when norm is 'bnorm', on otherwise".
    """

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()
        # Default bias: disabled only in front of BatchNorm (presumably because
        # BatchNorm's own shift makes a conv bias redundant).
        use_bias = (norm != 'bnorm') if bias == [] else bias

        stages = [Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=use_bias)]
        if norm != []:
            stages.append(Norm2d(nch_out, norm))
        if relu != []:
            stages.append(ReLU(relu))
        if drop != []:
            stages.append(nn.Dropout2d(drop))
        self.cbr = nn.Sequential(*stages)

    def forward(self, x):
        out = self.cbr(x)
        return out
class DECNR2d(nn.Module):
    """Transposed Conv2d followed by optional normalization, activation and dropout.

    Pipeline: ``Deconv2d`` -> ``Norm2d`` (unless ``norm == []``) -> ``ReLU``
    (unless ``relu == []``) -> ``nn.Dropout2d`` (unless ``drop == []``).
    The stages are kept in ``self.decbr`` (an ``nn.Sequential``).

    Args:
        nch_in, nch_out: input / output channel counts of the transposed conv.
        kernel_size, stride, padding, output_padding: forwarded to ``Deconv2d``.
        norm: mode string for ``Norm2d``, or ``[]`` for no normalization.
        relu: slope for ``ReLU`` (0 -> ReLU, >0 -> LeakyReLU), or ``[]`` for none.
        drop: dropout probability for ``nn.Dropout2d``, or ``[]`` for none.
        bias: deconv bias flag; ``[]`` means "off when norm is 'bnorm', on otherwise".
    """

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, norm='bnorm', relu=0.0, drop=[], bias=[]):
        super().__init__()
        has_bias = (norm != 'bnorm') if bias == [] else bias

        pipeline = [
            Deconv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding,
                     output_padding=output_padding, bias=has_bias)
        ]
        if norm != []:
            pipeline.append(Norm2d(nch_out, norm))
        if relu != []:
            pipeline.append(ReLU(relu))
        if drop != []:
            pipeline.append(nn.Dropout2d(drop))
        self.decbr = nn.Sequential(*pipeline)

    def forward(self, x):
        out = self.decbr(x)
        return out
class ResBlock(nn.Module):
    """Residual block: ``x + body(x)``.

    ``body`` is: pad -> CNR2d (with activation) -> optional Dropout2d ->
    pad -> CNR2d (no activation). The convolutions use ``padding=0`` and the
    explicit ``Padding`` layers supply the border instead.

    Args:
        nch_in: channels of ``x``.
        nch_out: channels produced by both convolutions. The body's output is
            added to ``x``, so this must equal ``nch_in`` (or broadcast with it).
        kernel_size, stride: forwarded to both convolutions. A stride other
            than 1 shrinks the feature map, so the residual sum will generally
            fail.
        padding, padding_mode: explicit padding applied before each conv.
        norm: normalization mode forwarded to ``CNR2d``.
        relu: activation slope for the first conv only.
        drop: if not ``[]``, a ``nn.Dropout2d(drop)`` between the two convs.
        bias: conv bias; ``[]`` means "off for 'bnorm', on otherwise". It is now
            actually forwarded to both convs (it used to be computed and ignored).
    """

    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()
        if bias == []:
            if norm == 'bnorm':
                bias = False
            else:
                bias = True
        layers = []
        # 1st conv
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu, bias=bias)]
        if drop != []:
            layers += [nn.Dropout2d(drop)]
        # 2nd conv: its input is the nch_out-channel output of the 1st conv.
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_out, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[], bias=bias)]
        self.resblk = nn.Sequential(*layers)

    def forward(self, x):
        return x + self.resblk(x)
class ResBlock_cat(nn.Module):
    """Residual block conditioned on a second tensor: ``x + body(cat([x, y], dim=1))``.

    ``body`` is: pad -> CNR2d (with activation) -> optional Dropout2d ->
    pad -> CNR2d (no activation), with ``padding=0`` convs fed by explicit
    ``Padding`` layers.

    Args:
        nch_in: the first conv expects ``nch_in * 2`` channels, so ``x`` and ``y``
            together must carry that many channels (e.g. ``nch_in`` each).
        nch_out: channels produced by both convolutions. The result is added to
            ``x``, so this must equal the channel count of ``x`` (or broadcast).
        kernel_size, stride: forwarded to both convolutions.
        padding, padding_mode: explicit padding applied before each conv.
        norm: normalization mode forwarded to ``CNR2d``.
        relu: activation slope for the first conv only.
        drop: if not ``[]``, a ``nn.Dropout2d(drop)`` between the two convs.
        bias: conv bias; ``[]`` means "off for 'bnorm', on otherwise". It is now
            actually forwarded to both convs (it used to be computed and ignored).
    """

    def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]):
        super().__init__()
        if bias == []:
            if norm == 'bnorm':
                bias = False
            else:
                bias = True
        layers = []
        # 1st conv over the channel-wise concatenation of x and y
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_in * 2, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu, bias=bias)]
        if drop != []:
            layers += [nn.Dropout2d(drop)]
        # 2nd conv: its input is the nch_out-channel output of the 1st conv.
        layers += [Padding(padding, padding_mode=padding_mode)]
        layers += [CNR2d(nch_out, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[], bias=bias)]
        self.resblk = nn.Sequential(*layers)

    def forward(self, x, y):
        return x + self.resblk(torch.cat([x, y], dim=1))
class LinearBlock(nn.Module):
    """Fully connected layer followed by optional normalization and activation.

    Args:
        input_dim: number of input features (size of the last input dim).
        output_dim: number of output features.
        norm: 'none', 'bn' (BatchNorm1d), 'in' (InstanceNorm1d), 'ln'
            (``nn.LayerNorm`` over the output features) or 'sn' (spectral
            normalization of the linear weight; no output normalization).
        activation: 'relu', 'lrelu' (slope 0.2), 'prelu', 'selu', 'tanh' or 'none'.

    Raises:
        ValueError: for an unsupported ``norm`` or ``activation``.
    """

    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True
        # Fully connected layer. This file defines no SpectralNorm / LayerNorm
        # classes of its own, so use PyTorch's built-in equivalents.
        fc = nn.Linear(input_dim, output_dim, bias=use_bias)
        self.fc = nn.utils.spectral_norm(fc) if norm == 'sn' else fc

        # Output normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            self.norm = nn.LayerNorm(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            raise ValueError("Unsupported normalization: {}".format(norm))

        # Activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            raise ValueError("Unsupported activation: {}".format(activation))

    def forward(self, x):
        out = self.fc(x)
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out
class MLP(nn.Module):
    """Stack of ``LinearBlock`` layers mapping ``input_dim`` to ``output_dim``.

    Layout: one ``input_dim -> dim`` block, ``max(n_blk - 2, 0)`` ``dim -> dim``
    blocks, and a final ``dim -> output_dim`` block with no norm or activation.
    The input is flattened to ``(batch, -1)`` before the first block.
    """

    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super().__init__()
        blocks = [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        blocks.extend(
            LinearBlock(dim, dim, norm=norm, activation=activ) for _ in range(n_blk - 2)
        )
        # The head is left linear: no normalization and no output activation.
        blocks.append(LinearBlock(dim, output_dim, norm='none', activation='none'))
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        flat = x.view(x.size(0), -1)
        return self.model(flat)
class _Norm1d(nn.Module):
    """1-D normalization selected by name, used after ``nn.Linear`` in ``CNR1d``.

    The wrapped layer is stored as ``self.norm`` (as ``Norm2d`` does), so state-dict
    keys stay ``<prefix>.norm.*``. BatchNorm1d and BatchNorm2d share parameter
    and buffer names.

    Raises:
        ValueError: for a mode other than 'bnorm' or 'inorm'.
    """

    def __init__(self, nch, norm_mode):
        super().__init__()
        if norm_mode == 'bnorm':
            self.norm = nn.BatchNorm1d(nch)
        elif norm_mode == 'inorm':
            # NOTE(review): InstanceNorm1d accepts (N, C, L) or unbatched (C, L);
            # a 2-D (N, nch) input is therefore treated as unbatched, i.e. each
            # row is normalized over its features. Confirm this is intended.
            self.norm = nn.InstanceNorm1d(nch)
        else:
            raise ValueError("Unsupported normalization: {}".format(norm_mode))

    def forward(self, x):
        return self.norm(x)


class CNR1d(nn.Module):
    """Linear layer followed by optional 1-D normalization, activation and dropout.

    Pipeline: ``nn.Linear`` -> 1-D norm (unless ``norm == []``) -> ``ReLU``
    (unless ``relu == []``) -> ``nn.Dropout`` (unless ``drop == []``).
    Presumably the input is 2-D ``(N, nch_in)``; confirm with callers.
    ``nn.Linear`` acts on the last dim, which would not line up with
    BatchNorm1d's channel dim for 3-D inputs.

    Previously the norm stage used the 2-D layers (BatchNorm2d/InstanceNorm2d),
    which reject the 2-D output of ``nn.Linear``. Dropout2d was also used, and it
    zeroes whole channel planes rather than individual features.

    Args:
        nch_in, nch_out: input / output feature counts.
        norm: 'bnorm', 'inorm', or ``[]`` for no normalization.
        relu: slope for ``ReLU`` (0 -> ReLU, >0 -> LeakyReLU), or ``[]`` for none.
        drop: dropout probability, or ``[]`` for none.
    """

    def __init__(self, nch_in, nch_out, norm='bnorm', relu=0.0, drop=[]):
        super().__init__()
        if norm == 'bnorm':
            bias = False
        else:
            bias = True
        layers = []
        layers += [nn.Linear(nch_in, nch_out, bias=bias)]
        if norm != []:
            layers += [_Norm1d(nch_out, norm)]
        if relu != []:
            layers += [ReLU(relu)]
        if drop != []:
            layers += [nn.Dropout(drop)]
        self.cbr = nn.Sequential(*layers)

    def forward(self, x):
        return self.cbr(x)
class Conv2d(nn.Module):
    """Thin wrapper around ``nn.Conv2d`` that keeps the layer under ``self.conv``."""

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, bias=True):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.conv = nn.Conv2d(nch_in, nch_out, **conv_kwargs)

    def forward(self, x):
        out = self.conv(x)
        return out
class Deconv2d(nn.Module):
    """Thin wrapper around ``nn.ConvTranspose2d`` that keeps the layer under ``self.deconv``."""

    def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, bias=True):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(
            in_channels=nch_in,
            out_channels=nch_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            bias=bias,
        )

    def forward(self, x):
        upsampled = self.deconv(x)
        return upsampled
class Linear(nn.Module):
    """Thin wrapper around ``nn.Linear`` (with bias) kept under ``self.linear``."""

    def __init__(self, nch_in, nch_out):
        super().__init__()
        self.linear = nn.Linear(in_features=nch_in, out_features=nch_out)

    def forward(self, x):
        projected = self.linear(x)
        return projected
class Norm2d(nn.Module):
    """2-D normalization layer selected by name, stored as ``self.norm``.

    Args:
        nch: number of channels.
        norm_mode: 'bnorm' -> ``nn.BatchNorm2d``, 'inorm' -> ``nn.InstanceNorm2d``.

    Raises:
        ValueError: for any other ``norm_mode``. Previously ``self.norm`` was left
            unset, and the failure only surfaced later as an AttributeError in
            ``forward``.
    """

    def __init__(self, nch, norm_mode):
        super(Norm2d, self).__init__()
        if norm_mode == 'bnorm':
            self.norm = nn.BatchNorm2d(nch)
        elif norm_mode == 'inorm':
            self.norm = nn.InstanceNorm2d(nch)
        else:
            raise ValueError("Unsupported norm_mode: {!r} (expected 'bnorm' or 'inorm')".format(norm_mode))

    def forward(self, x):
        return self.norm(x)
class ReLU(nn.Module):
    """In-place activation selected by slope, stored as ``self.relu``.

    Args:
        relu: ``0`` selects ``nn.ReLU``; a positive value selects ``nn.LeakyReLU``
            with that negative slope. Both are created with ``inplace=True``.

    Raises:
        ValueError: for a negative (or NaN) slope. Previously ``self.relu`` was
            left unset, and ``forward`` failed with an AttributeError.
    """

    def __init__(self, relu):
        super(ReLU, self).__init__()
        if relu > 0:
            self.relu = nn.LeakyReLU(relu, True)
        elif relu == 0:
            self.relu = nn.ReLU(True)
        else:
            raise ValueError("ReLU slope must be >= 0, got {!r}".format(relu))

    def forward(self, x):
        return self.relu(x)
class Padding(nn.Module):
    """2-D padding layer selected by mode, stored as ``self.padding``.

    Args:
        padding: padding size, forwarded to the chosen ``nn.*Pad2d`` layer.
        padding_mode: 'reflection', 'replication', 'constant' (fills with
            ``value``) or 'zeros'.
        value: fill value, used only by 'constant'.

    Raises:
        ValueError: for an unsupported ``padding_mode``. Previously
            ``self.padding`` was left unset, and ``forward`` failed later.
    """

    def __init__(self, padding, padding_mode='zeros', value=0):
        super(Padding, self).__init__()
        if padding_mode == 'reflection':
            self.padding = nn.ReflectionPad2d(padding)
        elif padding_mode == 'replication':
            self.padding = nn.ReplicationPad2d(padding)
        elif padding_mode == 'constant':
            self.padding = nn.ConstantPad2d(padding, value)
        elif padding_mode == 'zeros':
            self.padding = nn.ZeroPad2d(padding)
        else:
            raise ValueError("Unsupported padding_mode: {!r}".format(padding_mode))

    def forward(self, x):
        return self.padding(x)
class Pooling2d(nn.Module):
    """Downsample by ``pool`` using average, max, or a strided convolution.

    Args:
        nch: channel count; required (and only used) for ``type='conv'``.
        pool: pooling window / conv kernel size and stride.
        type: 'avg', 'max' or 'conv'.

    Raises:
        ValueError: for an unsupported ``type``. Previously ``self.pooling`` was
            left unset, and ``forward`` failed later.
    """

    def __init__(self, nch=[], pool=2, type='avg'):
        super().__init__()
        if type == 'avg':
            self.pooling = nn.AvgPool2d(pool)
        elif type == 'max':
            self.pooling = nn.MaxPool2d(pool)
        elif type == 'conv':
            self.pooling = nn.Conv2d(nch, nch, kernel_size=pool, stride=pool)
        else:
            raise ValueError("Unsupported pooling type: {!r}".format(type))

    def forward(self, x):
        return self.pooling(x)
class UnPooling2d(nn.Module):
    """Upsample a feature map by the integer factor ``pool``.

    Args:
        nch: channel count; required (and only used) for ``type='conv'``.
        pool: scale factor (also the kernel size and stride of the transposed conv).
        type: 'nearest', 'bilinear' (``align_corners=True``) or 'conv'.

    Raises:
        ValueError: for an unsupported ``type``.
    """

    def __init__(self, nch=[], pool=2, type='nearest'):
        super().__init__()
        if type == 'nearest':
            # align_corners is only valid for interpolating modes. Passing it with
            # mode='nearest' makes F.interpolate raise a ValueError in forward.
            self.unpooling = nn.Upsample(scale_factor=pool, mode='nearest')
        elif type == 'bilinear':
            self.unpooling = nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True)
        elif type == 'conv':
            self.unpooling = nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool)
        else:
            raise ValueError("Unsupported unpooling type: {!r}".format(type))

    def forward(self, x):
        return self.unpooling(x)
class Concat(nn.Module):
    """Pad (or, for negative differences, crop) ``x1`` to ``x2``'s spatial size, then concatenate.

    The size difference in dims 2 and 3 is split between the two sides, with the
    extra pixel going to the bottom/right. The result is ``cat([x2, x1], dim=1)``.
    """

    def forward(self, x1, x2):
        dh = x2.size(2) - x1.size(2)
        dw = x2.size(3) - x1.size(3)
        top = dh // 2
        left = dw // 2
        # F.pad takes (left, right, top, bottom) for the last two dims.
        aligned = F.pad(x1, [left, dw - left, top, dh - top])
        return torch.cat([x2, aligned], dim=1)
class TV1dLoss(nn.Module):
    """Mean absolute difference between neighbouring entries along dim 1."""

    def forward(self, input):
        left = input[:, :-1]
        right = input[:, 1:]
        return (left - right).abs().mean()
class TV2dLoss(nn.Module):
    """Anisotropic total variation over the last two dims.

    Returns the mean absolute horizontal difference (dim 3) plus the mean
    absolute vertical difference (dim 2). The indexing requires at least 4 dims.
    """

    def forward(self, input):
        tv_w = (input[:, :, :, :-1] - input[:, :, :, 1:]).abs().mean()
        tv_h = (input[:, :, :-1, :] - input[:, :, 1:, :]).abs().mean()
        return tv_w + tv_h
class SSIM2dLoss(nn.Module):
    """Placeholder for an SSIM-based loss; not implemented.

    ``forward`` ignores both arguments and returns the Python int ``0``, not a
    tensor, so it contributes no gradient.
    NOTE(review): ``targer`` looks like a typo for ``target``. It is kept as-is
    because callers may pass it by keyword.
    """

    def forward(self, input, targer):
        return 0