import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.nn import init

from .normalization import get_nonspade_norm_layer
from .blocks import encoder5  # note: shadowed by the local encoder5 class defined below

class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).'
              % (type(self).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # use pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)
        # propagate to children that define their own init_weights
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)
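
# Usage sketch (illustrative, not part of the original training code): any
# BaseNetwork subclass can report its size and re-initialise its weights, e.g.
#   net = NLayerDiscriminator()
#   net.print_network()
#   net.init_weights(init_type='xavier', gain=0.02)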

class NLayerDiscriminator(BaseNetwork):
    def __init__(self):
        super().__init__()
        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        nf = 64
        n_layers_D = 4
        input_nc = 3
        norm_layer = get_nonspade_norm_layer('spectralinstance')

        sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
                     nn.LeakyReLU(0.2, False)]]
        for n in range(1, n_layers_D):
            nf_prev = nf
            nf = min(nf * 2, 512)
            stride = 1 if n == n_layers_D - 1 else 2
            sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw,
                                               stride=stride, padding=padw)),
                          nn.LeakyReLU(0.2, False)
                          ]]
        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        # We divide the layers into groups to extract intermediate layer outputs.
        for n in range(len(sequence)):
            self.add_module('model' + str(n), nn.Sequential(*sequence[n]))

    def forward(self, input, get_intermediate_features=True):
        results = [input]
        for submodel in self.children():
            intermediate_output = submodel(results[-1])
            results.append(intermediate_output)
        if get_intermediate_features:
            return results[1:]
        else:
            return results[-1]
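
# Shape note (added for clarity): this is a 4-layer PatchGAN-style discriminator
# on 3-channel images. With get_intermediate_features=True, forward() returns the
# five per-group outputs of model0..model4; the last one is the 1-channel patch
# prediction map, while the earlier ones are used for feature matching in
# SPADE_LOSS further down.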

class VGG19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
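
# Note (added for clarity): the five slices above end at the relu1_1, relu2_1,
# relu3_1, relu4_1 and relu5_1 activations of torchvision's VGG-19, so forward()
# returns the standard set of perceptual-loss features. This class is kept here
# but is not used by VGGLoss below, which relies on the encoder5 network instead.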

class encoder5(nn.Module):
    def __init__(self):
        super(encoder5, self).__init__()
        # vgg
        # 224 x 224
        self.conv1 = nn.Conv2d(3, 3, 1, 1, 0)
        self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1))
        # 226 x 226
        self.conv2 = nn.Conv2d(3, 64, 3, 1, 0)
        self.relu2 = nn.ReLU(inplace=True)
        # 224 x 224
        self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
        self.relu3 = nn.ReLU(inplace=True)
        # 224 x 224
        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 112 x 112
        self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv4 = nn.Conv2d(64, 128, 3, 1, 0)
        self.relu4 = nn.ReLU(inplace=True)
        # 112 x 112
        self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv5 = nn.Conv2d(128, 128, 3, 1, 0)
        self.relu5 = nn.ReLU(inplace=True)
        # 112 x 112
        self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 56 x 56
        self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv6 = nn.Conv2d(128, 256, 3, 1, 0)
        self.relu6 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv7 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu7 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv8 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu8 = nn.ReLU(inplace=True)
        # 56 x 56
        self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv9 = nn.Conv2d(256, 256, 3, 1, 0)
        self.relu9 = nn.ReLU(inplace=True)
        # 56 x 56
        self.maxPool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 28 x 28
        self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv10 = nn.Conv2d(256, 512, 3, 1, 0)
        self.relu10 = nn.ReLU(inplace=True)
        self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv11 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu11 = nn.ReLU(inplace=True)
        self.reflecPad12 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv12 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu12 = nn.ReLU(inplace=True)
        self.reflecPad13 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv13 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu13 = nn.ReLU(inplace=True)
        self.maxPool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reflecPad14 = nn.ReflectionPad2d((1, 1, 1, 1))
        self.conv14 = nn.Conv2d(512, 512, 3, 1, 0)
        self.relu14 = nn.ReLU(inplace=True)

    def forward(self, x):
        output = []
        out = self.conv1(x)
        out = self.reflecPad1(out)
        out = self.conv2(out)
        out = self.relu2(out)
        output.append(out)
        out = self.reflecPad3(out)
        out = self.conv3(out)
        out = self.relu3(out)
        out = self.maxPool(out)
        out = self.reflecPad4(out)
        out = self.conv4(out)
        out = self.relu4(out)
        output.append(out)
        out = self.reflecPad5(out)
        out = self.conv5(out)
        out = self.relu5(out)
        out = self.maxPool2(out)
        out = self.reflecPad6(out)
        out = self.conv6(out)
        out = self.relu6(out)
        output.append(out)
        out = self.reflecPad7(out)
        out = self.conv7(out)
        out = self.relu7(out)
        out = self.reflecPad8(out)
        out = self.conv8(out)
        out = self.relu8(out)
        out = self.reflecPad9(out)
        out = self.conv9(out)
        out = self.relu9(out)
        out = self.maxPool3(out)
        out = self.reflecPad10(out)
        out = self.conv10(out)
        out = self.relu10(out)
        output.append(out)
        out = self.reflecPad11(out)
        out = self.conv11(out)
        out = self.relu11(out)
        out = self.reflecPad12(out)
        out = self.conv12(out)
        out = self.relu12(out)
        out = self.reflecPad13(out)
        out = self.conv13(out)
        out = self.relu13(out)
        out = self.maxPool4(out)
        out = self.reflecPad14(out)
        out = self.conv14(out)
        out = self.relu14(out)
        output.append(out)
        return output
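
# Note (added for clarity): encoder5 is a reflection-padded VGG-style encoder whose
# weights are loaded from 'vgg_r51.pth' in VGGLoss below. forward() returns five
# feature maps, taken after relu2, relu4, relu6, relu10 and relu14, i.e. roughly
# the relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1 levels of VGG-19.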

class VGGLoss(nn.Module):
    def __init__(self, model_path):
        super(VGGLoss, self).__init__()
        self.vgg = encoder5().cuda()
        self.vgg.load_state_dict(torch.load(os.path.join(model_path, 'vgg_r51.pth')))
        self.criterion = nn.MSELoss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        # iterate over all five feature levels so that every weight is used
        for i in range(len(x_vgg)):
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss
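
# Minimal usage sketch (illustrative; the 'checkpoints' directory is a placeholder):
#   perceptual = VGGLoss(model_path='checkpoints')   # expects checkpoints/vgg_r51.pth
#   loss = perceptual(fake_images.cuda(), real_images.cuda())
# The target features are detached, so the loss back-propagates only through x
# (and through the encoder itself, whose weights are not frozen by this class).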

class GANLoss(nn.Module):
    def __init__(self, gan_mode='hinge', target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.cuda.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.gan_mode = gan_mode
        if gan_mode not in ('ls', 'original', 'w', 'hinge'):
            raise ValueError('Unexpected gan_mode {}'.format(gan_mode))

    def get_target_tensor(self, input, target_is_real):
        if target_is_real:
            if self.real_label_tensor is None:
                self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
                self.real_label_tensor.requires_grad_(False)
            return self.real_label_tensor.expand_as(input)
        else:
            if self.fake_label_tensor is None:
                self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
                self.fake_label_tensor.requires_grad_(False)
            return self.fake_label_tensor.expand_as(input)

    def get_zero_tensor(self, input):
        if self.zero_tensor is None:
            self.zero_tensor = self.Tensor(1).fill_(0)
            self.zero_tensor.requires_grad_(False)
        return self.zero_tensor.expand_as(input)

    def loss(self, input, target_is_real, for_discriminator=True):
        if self.gan_mode == 'original':  # cross entropy loss
            target_tensor = self.get_target_tensor(input, target_is_real)
            loss = F.binary_cross_entropy_with_logits(input, target_tensor)
            return loss
        elif self.gan_mode == 'ls':
            target_tensor = self.get_target_tensor(input, target_is_real)
            return F.mse_loss(input, target_tensor)
        elif self.gan_mode == 'hinge':
            if for_discriminator:
                if target_is_real:
                    minval = torch.min(input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
                else:
                    minval = torch.min(-input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
            else:
                assert target_is_real, "The generator's hinge loss must be aiming for real"
                loss = -torch.mean(input)
            return loss
        else:
            # wgan
            if target_is_real:
                return -input.mean()
            else:
                return input.mean()

    def __call__(self, input, target_is_real, for_discriminator=True):
        # computing the loss is slightly involved because |input| may not be a
        # single tensor but a list of tensors, as with a multiscale discriminator
        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    pred_i = pred_i[-1]
                loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
                bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
                new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
                loss += new_loss
            return loss / len(input)
        else:
            return self.loss(input, target_is_real, for_discriminator)
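
# For reference, with gan_mode='hinge' the code above implements the standard
# hinge GAN objective:
#   D loss: E[max(0, 1 - D(real))] + E[max(0, 1 + D(fake))]
#   G loss: -E[D(fake)]
# with the real and fake terms computed by separate calls to loss().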

class SPADE_LOSS(nn.Module):
    def __init__(self, model_path, lambda_feat=1):
        super(SPADE_LOSS, self).__init__()
        self.criterionVGG = VGGLoss(model_path)
        self.criterionGAN = GANLoss('hinge')
        self.criterionL1 = nn.L1Loss()
        self.discriminator = NLayerDiscriminator()
        self.lambda_feat = lambda_feat

    def forward(self, x, y, for_discriminator=False):
        pred_real = self.discriminator(y)
        if not for_discriminator:
            pred_fake = self.discriminator(x)
            vgg_loss = self.criterionVGG(x, y)
            gan_loss = self.criterionGAN(pred_fake, True, for_discriminator=False)
            # feature matching loss over the intermediate discriminator outputs;
            # the last output is the final prediction, so we exclude it
            num_intermediate_outputs = len(pred_fake) - 1
            gan_feat_loss = 0
            for j in range(num_intermediate_outputs):  # for each layer output
                unweighted_loss = self.criterionL1(pred_fake[j], pred_real[j].detach())
                gan_feat_loss += unweighted_loss * self.lambda_feat
            l1_loss = self.criterionL1(x, y)
            return vgg_loss, gan_loss, gan_feat_loss, l1_loss
        else:
            pred_fake = self.discriminator(x.detach())
            gan_loss = self.criterionGAN(pred_fake, False, for_discriminator=True)
            gan_loss += self.criterionGAN(pred_real, True, for_discriminator=True)
            return gan_loss
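
# Typical usage sketch (illustrative; variable names below are placeholders):
#   loss_fn = SPADE_LOSS(model_path='checkpoints').cuda()
#   # generator update: all four terms back-propagate into the generator
#   vgg_l, gan_l, feat_l, l1_l = loss_fn(fake, real, for_discriminator=False)
#   (vgg_l + gan_l + feat_l + l1_l).backward()
#   # discriminator update: fake is detached inside the loss
#   d_loss = loss_fn(fake, real, for_discriminator=True)
#   d_loss.backward()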

class ContrastiveLoss(nn.Module):
    """
    Contrastive loss.
    Takes the embeddings of two samples and a target label: 1 if the samples
    come from the same class, 0 otherwise.
    """

    def __init__(self, margin):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.eps = 1e-9

    def forward(self, out1, out2, target, size_average=True, norm=True):
        if norm:
            # L2-normalise each embedding before measuring distances
            output1 = out1 / out1.pow(2).sum(1, keepdim=True).sqrt()
            output2 = out2 / out2.pow(2).sum(1, keepdim=True).sqrt()
        else:
            output1, output2 = out1, out2
        distances = (output2 - output1).pow(2).sum(1)  # squared distances
        losses = 0.5 * (target.float() * distances +
                        (1 + -1 * target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2))
        return losses.mean() if size_average else losses.sum()
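
# For reference (added note): this is the standard contrastive objective,
#   L = 0.5 * [ y * d^2 + (1 - y) * max(0, margin - d)^2 ],  d = ||f1 - f2||_2,
# where y is 1 for a positive (same-class) pair and 0 for a negative pair;
# eps is added under the square root for numerical stability.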