import torch import torch.nn as nn import torch.nn.functional as F class CNR2d(nn.Module): def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, norm='bnorm', relu=0.0, drop=[], bias=[]): super().__init__() if bias == []: if norm == 'bnorm': bias = False else: bias = True layers = [] layers += [Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)] if norm != []: layers += [Norm2d(nch_out, norm)] if relu != []: layers += [ReLU(relu)] if drop != []: layers += [nn.Dropout2d(drop)] self.cbr = nn.Sequential(*layers) def forward(self, x): return self.cbr(x) class DECNR2d(nn.Module): def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, norm='bnorm', relu=0.0, drop=[], bias=[]): super().__init__() if bias == []: if norm == 'bnorm': bias = False else: bias = True layers = [] layers += [Deconv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias)] if norm != []: layers += [Norm2d(nch_out, norm)] if relu != []: layers += [ReLU(relu)] if drop != []: layers += [nn.Dropout2d(drop)] self.decbr = nn.Sequential(*layers) def forward(self, x): return self.decbr(x) class ResBlock(nn.Module): def __init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]): super().__init__() if bias == []: if norm == 'bnorm': bias = False else: bias = True layers = [] # 1st conv layers += [Padding(padding, padding_mode=padding_mode)] layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu)] if drop != []: layers += [nn.Dropout2d(drop)] # 2nd conv layers += [Padding(padding, padding_mode=padding_mode)] layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[])] self.resblk = nn.Sequential(*layers) def forward(self, x): return x + self.resblk(x) class ResBlock_cat(nn.Module): def 
__init__(self, nch_in, nch_out, kernel_size=3, stride=1, padding=1, padding_mode='reflection', norm='inorm', relu=0.0, drop=[], bias=[]): super().__init__() if bias == []: if norm == 'bnorm': bias = False else: bias = True layers = [] # 1st conv layers += [Padding(padding, padding_mode=padding_mode)] layers += [CNR2d(nch_in*2, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=relu)] if drop != []: layers += [nn.Dropout2d(drop)] # 2nd conv layers += [Padding(padding, padding_mode=padding_mode)] layers += [CNR2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=0, norm=norm, relu=[])] self.resblk = nn.Sequential(*layers) def forward(self,x,y): output = x + self.resblk(torch.cat([x,y],dim=1)) return output class LinearBlock(nn.Module): def __init__(self, input_dim, output_dim, norm='none', activation='relu'): super(LinearBlock, self).__init__() use_bias = True # initialize fully connected layer if norm == 'sn': self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias)) else: self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) # initialize normalization norm_dim = output_dim if norm == 'bn': self.norm = nn.BatchNorm1d(norm_dim) elif norm == 'in': self.norm = nn.InstanceNorm1d(norm_dim) elif norm == 'ln': self.norm = LayerNorm(norm_dim) elif norm == 'none' or norm == 'sn': self.norm = None else: assert 0, "Unsupported normalization: {}".format(norm) # initialize activation if activation == 'relu': self.activation = nn.ReLU(inplace=True) elif activation == 'lrelu': self.activation = nn.LeakyReLU(0.2, inplace=True) elif activation == 'prelu': self.activation = nn.PReLU() elif activation == 'selu': self.activation = nn.SELU(inplace=True) elif activation == 'tanh': self.activation = nn.Tanh() elif activation == 'none': self.activation = None else: assert 0, "Unsupported activation: {}".format(activation) def forward(self, x): out = self.fc(x) if self.norm: out = self.norm(out) if self.activation: out = 
self.activation(out) return out class MLP(nn.Module): def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'): super(MLP, self).__init__() self.model = [] self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)] for i in range(n_blk - 2): self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)] self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations self.model = nn.Sequential(*self.model) def forward(self, x): return self.model(x.view(x.size(0), -1)) class CNR1d(nn.Module): def __init__(self, nch_in, nch_out, norm='bnorm', relu=0.0, drop=[]): super().__init__() if norm == 'bnorm': bias = False else: bias = True layers = [] layers += [nn.Linear(nch_in, nch_out, bias=bias)] if norm != []: layers += [Norm2d(nch_out, norm)] if relu != []: layers += [ReLU(relu)] if drop != []: layers += [nn.Dropout2d(drop)] self.cbr = nn.Sequential(*layers) def forward(self, x): return self.cbr(x) class Conv2d(nn.Module): def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, bias=True): super(Conv2d, self).__init__() self.conv = nn.Conv2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) def forward(self, x): return self.conv(x) class Deconv2d(nn.Module): def __init__(self, nch_in, nch_out, kernel_size=4, stride=1, padding=1, output_padding=0, bias=True): super(Deconv2d, self).__init__() self.deconv = nn.ConvTranspose2d(nch_in, nch_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=bias) # layers = [nn.Upsample(scale_factor=2, mode='bilinear'), # nn.ReflectionPad2d(1), # nn.Conv2d(nch_in , nch_out, kernel_size=3, stride=1, padding=0)] # # self.deconv = nn.Sequential(*layers) def forward(self, x): return self.deconv(x) class Linear(nn.Module): def __init__(self, nch_in, nch_out): super(Linear, self).__init__() self.linear = nn.Linear(nch_in, nch_out) def forward(self, x): return self.linear(x) 
class Norm2d(nn.Module):
    """Select a 2-D normalization layer by name.

    Args:
        nch: number of channels (``num_features``).
        norm_mode: ``'bnorm'`` -> ``nn.BatchNorm2d``,
            ``'inorm'`` -> ``nn.InstanceNorm2d``.

    Raises:
        ValueError: for any other ``norm_mode``.
    """

    def __init__(self, nch, norm_mode):
        super().__init__()
        if norm_mode == 'bnorm':
            self.norm = nn.BatchNorm2d(nch)
        elif norm_mode == 'inorm':
            self.norm = nn.InstanceNorm2d(nch)
        else:
            # Fail fast: otherwise self.norm is never set and forward() would
            # die with an opaque AttributeError.
            raise ValueError("Unsupported normalization: {}".format(norm_mode))

    def forward(self, x):
        return self.norm(x)


class ReLU(nn.Module):
    """In-place ReLU when ``relu == 0``; in-place LeakyReLU with slope ``relu`` when ``relu > 0``.

    Raises:
        ValueError: for a negative (or NaN) slope.
    """

    def __init__(self, relu):
        super().__init__()
        if relu > 0:
            self.relu = nn.LeakyReLU(relu, True)
        elif relu == 0:
            self.relu = nn.ReLU(True)
        else:
            raise ValueError("Unsupported relu slope: {}".format(relu))

    def forward(self, x):
        return self.relu(x)


class Padding(nn.Module):
    """Select a 2-D padding layer by name.

    Args:
        padding: pad size, forwarded to the chosen layer.
        padding_mode: ``'reflection'``, ``'replication'``, ``'constant'``
            (fills with ``value``) or ``'zeros'``.

    Raises:
        ValueError: for any other ``padding_mode``.
    """

    def __init__(self, padding, padding_mode='zeros', value=0):
        super().__init__()
        if padding_mode == 'reflection':
            self.padding = nn.ReflectionPad2d(padding)
        elif padding_mode == 'replication':
            self.padding = nn.ReplicationPad2d(padding)
        elif padding_mode == 'constant':
            self.padding = nn.ConstantPad2d(padding, value)
        elif padding_mode == 'zeros':
            self.padding = nn.ZeroPad2d(padding)
        else:
            raise ValueError("Unsupported padding mode: {}".format(padding_mode))

    def forward(self, x):
        return self.padding(x)


class Pooling2d(nn.Module):
    """Downsample by ``pool``.

    ``type``: ``'avg'`` / ``'max'`` pooling, or ``'conv'`` (a learned
    ``nch -> nch`` conv with kernel = stride = ``pool``; requires ``nch``).

    Raises:
        ValueError: for any other ``type``.
    """

    def __init__(self, nch=[], pool=2, type='avg'):
        super().__init__()
        pool_type = type  # ``type`` shadows the builtin but is a public kwarg.
        if pool_type == 'avg':
            self.pooling = nn.AvgPool2d(pool)
        elif pool_type == 'max':
            self.pooling = nn.MaxPool2d(pool)
        elif pool_type == 'conv':
            self.pooling = nn.Conv2d(nch, nch, kernel_size=pool, stride=pool)
        else:
            raise ValueError("Unsupported pooling type: {}".format(pool_type))

    def forward(self, x):
        return self.pooling(x)


class UnPooling2d(nn.Module):
    """Upsample by ``pool``.

    ``type``: ``'nearest'`` or ``'bilinear'`` (``align_corners=True``)
    interpolation, or ``'conv'`` (a learned ``nch -> nch`` transposed conv with
    kernel = stride = ``pool``; requires ``nch``).

    Raises:
        ValueError: for any other ``type``.
    """

    def __init__(self, nch=[], pool=2, type='nearest'):
        super().__init__()
        unpool_type = type  # ``type`` shadows the builtin but is a public kwarg.
        if unpool_type == 'nearest':
            # align_corners must stay unset: F.interpolate raises a ValueError
            # if it is given together with mode='nearest'.
            self.unpooling = nn.Upsample(scale_factor=pool, mode='nearest')
        elif unpool_type == 'bilinear':
            self.unpooling = nn.Upsample(scale_factor=pool, mode='bilinear', align_corners=True)
        elif unpool_type == 'conv':
            self.unpooling = nn.ConvTranspose2d(nch, nch, kernel_size=pool, stride=pool)
        else:
            raise ValueError("Unsupported unpooling type: {}".format(unpool_type))

    def forward(self, x):
        return self.unpooling(x)


class Concat(nn.Module):
    """Skip-connection concat: fit ``x1`` to ``x2``'s spatial size, then stack channels."""

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        """Return ``cat([x2, x1_resized], dim=1)``.

        ``x1`` is zero-padded (or cropped, when the difference is negative) on
        dims 2 and 3 so that they match ``x2``. Any odd extra pixel goes to the
        right/bottom.
        """
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]
        # F.pad's list starts at the last dim: [left, right, top, bottom].
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        return torch.cat([x2, x1], dim=1)


class TV1dLoss(nn.Module):
    """Total variation along dim 1: mean ``|input[:, i+1] - input[:, i]|``."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return torch.mean(torch.abs(input[:, :-1] - input[:, 1:]))


class TV2dLoss(nn.Module):
    """Anisotropic total variation over the last two dims.

    Returns the sum of the mean absolute neighbour differences along dim 3 and
    along dim 2 of a 4-D ``input``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        loss = torch.mean(torch.abs(input[:, :, :, :-1] - input[:, :, :, 1:])) + \
            torch.mean(torch.abs(input[:, :, :-1, :] - input[:, :, 1:, :]))
        return loss


class SSIM2dLoss(nn.Module):
    """Placeholder for an SSIM loss.

    NOTE(review): not implemented - ``forward`` ignores its inputs and always
    returns the int ``0``. ``targer`` (sic) is kept as-is because callers may
    pass it by keyword.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, targer):
        loss = 0
        return loss